diff --git a/.flake8 b/.flake8 new file mode 100644 index 00000000..cb0de790 --- /dev/null +++ b/.flake8 @@ -0,0 +1,11 @@ +[flake8] +max-line-length = 88 +filename = *.py +exclude = + .git, + __pycache__, + docs/* +# F403: unable to detect undefined names +# F405: undefined, or defined from star imports +# W503: Linebreak before binary operator +ignore = F403, F405, W503 diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 586911b1..f7018e62 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -6,22 +6,31 @@ labels: bug assignees: '' --- - + ... - + + ```python -... +# paste your code here, if applicable ``` - + +
- Error + + Error output + ```pytb -... +# paste the error output here, if applicable ```
-#### Versions: - -> + + +
Versions + +```pytb +# paste the output of scv.logging.print_versions() here +``` +
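The template above asks reporters for the output of `scv.logging.print_versions()`. For reference, a minimal snippet to produce that report (assuming the conventional `scv` import alias used throughout the scVelo docs):

```python
# Produce the environment report requested by the issue template.
import scvelo as scv

scv.logging.print_versions()
```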
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 72ea68ca..47a36de4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,9 +2,9 @@ name: CI on: push: - branches: [master, develop] + branches: [master, develop, 'release/**'] pull_request: - branches: [master, develop] + branches: [master, develop, 'release/**'] jobs: # Skip CI if commit message contains `[ci skip]` in the subject @@ -19,7 +19,7 @@ jobs: - id: ci-skip-step uses: mstachniuk/ci-skip@master - # Check if code agrees with `black` and if README.rst can be converted to HTML + # Check if pre-commit hooks pass and if README.rst can be converted to HTML linting: needs: init if: ${{ needs.init.outputs.skip == 'false' }} @@ -33,9 +33,9 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install black>=20.8b1 docutils - - name: Check code style - run: black --check --diff --color . + pip install pre-commit docutils + - name: Check pre-commit compatibility + run: pre-commit run --all-files --show-diff-on-failure - name: Run rst2html.py run: rst2html.py --halt=2 README.rst >/dev/null @@ -55,6 +55,6 @@ jobs: - name: Install dependencies run: | pip install -e . - pip install pytest pytest-cov + pip install hypothesis pytest pytest-cov - name: Unit tests run: python -m pytest --cov=scvelo diff --git a/.gitignore b/.gitignore index 166a74fa..56a467b5 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ scripts/ write/ scanpy*/ benchmarking/ +htmlcov/ scvelo.egg-info/ @@ -37,4 +38,7 @@ docs/source/scvelo* /dist/ .coverage -.eggs \ No newline at end of file +.eggs + +# Files generated by unit tests +.hypothesis/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 14aa0ee4..3c5e1dd8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,3 +3,18 @@ repos: rev: 20.8b1 hooks: - id: black +- repo: https://gitlab.com/pycqa/flake8 + rev: 3.8.4 + hooks: + - id: flake8 +- repo: https://github.com/pycqa/isort + rev: 5.7.0 + hooks: + - id: isort + name: isort (python) + - id: isort + name: isort (cython) + types: [cython] + - id: isort + name: isort (pyi) + types: [pyi] diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst new file mode 100644 index 00000000..06b0d014 --- /dev/null +++ b/CONTRIBUTING.rst @@ -0,0 +1,78 @@ +Contributing guide +================== + + +Getting started +^^^^^^^^^^^^^^^ + +Contributing to scVelo requires a developer installation. As a first step, we suggest creating a new environment + +.. code:: bash + + conda create -n ENV_NAME python=PYTHON_VERSION && conda activate ENV_NAME + + +Next, fork the scVelo repo on GitHub `here `. +If you are unsure how to do so, please check out the corresponding +`GitHub docs `. +You can now clone your fork of scVelo and install it in development mode + +.. code:: bash + + git clone https://github.com/YOUR-USER-NAME/scvelo.git + cd scvelo + git checkout --track origin/develop + pip install -e '.[dev]' + +The last line can, alternatively, be replaced by + +.. code:: bash + + pip install -r requirements-dev.txt + + +Finally, to make sure your code follows our code style guidelines, install pre-commit: + +.. code:: bash + + pre-commit install + + +Coding style +^^^^^^^^^^^^ + +Our code follows the `black` and `flake8` coding styles. Code formatting (`black`, `isort`) is automated through pre-commit hooks. In addition, we require that + +- functions are fully type-annotated. +- variables referred to in an error/warning message or a docstring are enclosed in \`\`.
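To illustrate the two conventions above, a minimal, hypothetical sketch (the function and all names are invented for demonstration): it is fully type-annotated, and the variable referenced in its error message is enclosed in \`\`.

```python
from typing import Optional

import numpy as np
from numpy import ndarray


def clip_counts(counts: ndarray, upper_bound: Optional[float] = None) -> ndarray:
    """Clip counts to the interval [0, upper_bound]."""
    if upper_bound is not None and upper_bound < 0:
        # Variables referenced in error messages are enclosed in ``.
        raise ValueError("`upper_bound` must be non-negative.")
    return np.clip(counts, 0, upper_bound)
```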
+ + +Testing +^^^^^^^ + +To run the implemented unit tests locally, simply run + +.. code:: bash + + python -m pytest + + +Documentation +^^^^^^^^^^^^^ + +The docstrings of scVelo largely follow the `numpy` style. New docstrings should + +- include neither type hints nor return types. +- reference an argument within the same docstring using \`\`. + + +Submitting pull requests +^^^^^^^^^^^^^^^^^^^^^^^^ + +New features and bug fixes are added to the code base through a pull request (PR). To implement a feature or bug fix, create a branch from `develop`. For hotfixes, use `master` as the base. The existence of bugs suggests insufficient test coverage. As such, bug fixes should, ideally, include a unit test or extend an existing one. Please ensure that + +- branch names have the prefix `feat/`, `bug/` or `hotfix/`. +- your code follows the project conventions. +- newly added functions are unit tested. +- all tests pass locally. +- if there is no issue solved by the PR, create one outlining what you intend to add or solve, and reference it in the PR description. diff --git a/README.rst b/README.rst index 46be45ab..e317e25b 100644 --- a/README.rst +++ b/README.rst @@ -34,6 +34,7 @@ patients and dynamic processes in human lung regeneration. Find out more in this Latest news ^^^^^^^^^^^ +- Aug/2021: `Perspectives paper out in MSB `_ - Feb/2021: scVelo goes multi-core - Dec/2020: Cover of `Nature Biotechnology `_ - Nov/2020: Talk at `Single Cell Biology `_ @@ -42,11 +43,15 @@ Latest news - Sep/2020: Talk at `Single Cell Omics `_ - Aug/2020: `scVelo out in Nature Biotech `_ -Reference -^^^^^^^^^ +References +^^^^^^^^^^ +La Manno *et al.* (2018), RNA velocity of single cells, `Nature `_. + Bergen *et al.* (2020), Generalizing RNA velocity to transient cell states through dynamical modeling, `Nature Biotech `_. -|dim| + +Bergen *et al.* (2021), RNA velocity - current challenges and future perspectives, +`Molecular Systems Biology `_.
Support ^^^^^^^ diff --git a/docs/requirements.txt b/docs/requirements.txt index 3f3b35b5..99078be6 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -9,5 +9,5 @@ sphinx_autodoc_typehints<=1.6 # converting notebooks to html ipykernel -sphinx>=1.7 -nbsphinx>=0.7 \ No newline at end of file +sphinx>=1.7,<4.0 +nbsphinx>=0.7,<0.8.7 \ No newline at end of file diff --git a/docs/source/_ext/edit_on_github.py b/docs/source/_ext/edit_on_github.py index 720b69f1..6bd7852a 100644 --- a/docs/source/_ext/edit_on_github.py +++ b/docs/source/_ext/edit_on_github.py @@ -5,7 +5,6 @@ import os import warnings - __licence__ = "BSD (3 clause)" diff --git a/docs/source/api.rst b/docs/source/api.rst index e3656a37..e4215c57 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -153,6 +153,12 @@ Datasets datasets.pancreas datasets.dentategyrus datasets.forebrain + datasets.dentategyrus_lamanno + datasets.gastrulation + datasets.gastrulation_e75 + datasets.gastrulation_erythroid + datasets.bonemarrow + datasets.pbmc68k datasets.simulation diff --git a/docs/source/conf.py b/docs/source/conf.py index 2ee05f04..64e7750b 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1,15 +1,31 @@ -import sys -import os import inspect import logging -from pathlib import Path +import os +import sys from datetime import datetime -from typing import Optional, Union, Mapping +from pathlib import Path, PurePosixPath +from typing import Dict, List, Mapping, Optional, Tuple, Union +from urllib.request import urlretrieve +import sphinx_autodoc_typehints +from docutils import nodes +from jinja2.defaults import DEFAULT_FILTERS +from sphinx import addnodes from sphinx.application import Sphinx +from sphinx.domains.python import PyObject, PyTypedField +from sphinx.environment import BuildEnvironment from sphinx.ext import autosummary +import matplotlib # noqa + +HERE = Path(__file__).parent +sys.path.insert(0, str(HERE.parent.parent)) +sys.path.insert(0, os.path.abspath("_ext")) + +import scvelo # isort:skip + # remove PyCharm’s old six module + if "six" in sys.modules: print(*sys.path, sep="\n") for pypath in list(sys.path): @@ -17,34 +33,45 @@ sys.path.remove(pypath) del sys.modules["six"] -import matplotlib # noqa - matplotlib.use("agg") -HERE = Path(__file__).parent -sys.path.insert(0, f"{HERE.parent.parent}") -sys.path.insert(0, os.path.abspath("_ext")) -import scvelo - logger = logging.getLogger(__name__) -# -- Retrieve notebooks ------------------------------------------------ - -from urllib.request import urlretrieve +# -- Basic notebooks and those stored under /vignettes and /perspectives -- notebooks_url = "https://github.com/theislab/scvelo_notebooks/raw/master/" -notebooks = [ +notebooks = [] +notebook = [ "VelocityBasics.ipynb", "DynamicalModeling.ipynb", "DifferentialKinetics.ipynb", +] +notebooks.extend(notebook) + +notebook = [ "Pancreas.ipynb", "DentateGyrus.ipynb", + "NatureBiotechCover.ipynb", + "Fig1_concept.ipynb", + "Fig2_dentategyrus.ipynb", + "Fig3_pancreas.ipynb", + "FigS9_runtime.ipynb", + "FigSuppl.ipynb", ] +notebooks.extend([f"vignettes/{nb}" for nb in notebook]) + +notebook = ["Perspectives.ipynb", "Perspectives_parameters.ipynb"] +notebooks.extend([f"perspectives/{nb}" for nb in notebook]) + +# -- Retrieve all notebooks -- + for nb in notebooks: + url = notebooks_url + nb try: - urlretrieve(notebooks_url + nb, nb) - except: + urlretrieve(url, nb) + except Exception as e: + logger.error(f"Unable to retrieve notebook: `{url}`. 
Reason: `{e}`") pass @@ -218,7 +245,6 @@ def get_linenos(obj): github_url_read_loom = "https://github.com/theislab/anndata/tree/master/anndata" github_url_read = "https://github.com/theislab/scanpy/tree/master" github_url_scanpy = "https://github.com/theislab/scanpy/tree/master/scanpy" -from pathlib import PurePosixPath def modurl(qualname): @@ -253,14 +279,11 @@ def api_image(qualname: str) -> Optional[str]: # modify the default filters -from jinja2.defaults import DEFAULT_FILTERS DEFAULT_FILTERS.update(modurl=modurl, api_image=api_image) # -- Override some classnames in autodoc -------------------------------------------- -import sphinx_autodoc_typehints - qualname_overrides = { "anndata.base.AnnData": "anndata.AnnData", "scvelo.pl.scatter": "scvelo.plotting.scatter", @@ -292,12 +315,6 @@ def format_annotation(annotation): # -- Prettier Param docs -------------------------------------------- -from typing import Dict, List, Tuple -from docutils import nodes -from sphinx import addnodes -from sphinx.domains.python import PyTypedField, PyObject -from sphinx.environment import BuildEnvironment - class PrettyTypedField(PyTypedField): list_type = nodes.definition_list diff --git a/docs/source/index.rst b/docs/source/index.rst index 8dc7eddf..372bf016 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -32,6 +32,7 @@ patients and dynamic processes in human lung regeneration. Find out more in this Latest news ^^^^^^^^^^^ +- Aug/2021: `Perspectives paper out in MSB `_ - Feb/2021: scVelo goes multi-core - Dec/2020: Cover of `Nature Biotechnology `_ - Nov/2020: Talk at `Single Cell Biology `_ @@ -40,11 +41,15 @@ Latest news - Sep/2020: Talk at `Single Cell Omics `_ - Aug/2020: `scVelo out in Nature Biotech `_ -Reference -^^^^^^^^^ +References +^^^^^^^^^^ +La Manno *et al.* (2018), RNA velocity of single cells, `Nature `_. + Bergen *et al.* (2020), Generalizing RNA velocity to transient cell states through dynamical modeling, `Nature Biotech `_. -|dim| + +Bergen *et al.* (2021), RNA velocity - current challenges and future perspectives, +`Molecular Systems Biology `_. Support ^^^^^^^ @@ -77,14 +82,15 @@ For further information visit `scvelo.org `_. VelocityBasics DynamicalModeling DifferentialKinetics + vignettes/index .. toctree:: - :caption: Example Datasets + :caption: Perspectives :maxdepth: 1 :hidden: - Pancreas - DentateGyrus + perspectives/index + .. |PyPI| image:: https://img.shields.io/pypi/v/scvelo.svg :target: https://pypi.org/project/scvelo diff --git a/docs/source/installation.rst b/docs/source/installation.rst index e1901192..d2245a8b 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -19,12 +19,13 @@ Development Version To work with the latest development version, install from GitHub_ using:: - pip install git+https://github.com/theislab/scvelo + pip install git+https://github.com/theislab/scvelo@develop or:: - git clone https://github.com/theislab/scvelo - pip install -e scvelo + git clone https://github.com/theislab/scvelo && cd scvelo + git checkout --track origin/develop + pip install -e . ``-e`` is short for ``--editable`` and links the package to the original cloned location such that pulled changes are also reflected in the environment. 
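A quick sanity check that the editable install resolves to the cloned sources (illustrative; `scv.__version__` and `scv.logging.print_versions()` are both part of the public API shown elsewhere in this changeset):

```python
# Verify the development installation.
import scvelo as scv

print(scv.__version__)  # should reflect the checked-out develop branch
scv.logging.print_versions()
```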
diff --git a/docs/source/perspectives/index.rst b/docs/source/perspectives/index.rst new file mode 100644 index 00000000..4460ec37 --- /dev/null +++ b/docs/source/perspectives/index.rst @@ -0,0 +1,34 @@ +Challenges and Perspectives +--------------------------- + +This page complements our manuscript +`Bergen et al. (MSB, 2021) RNA velocity - Current challenges and future perspectives `_. + +We provide several examples to discuss potential pitfalls of current RNA velocity +modeling approaches, and provide guidance on how the ensuing challenges may be addressed. +Our aspiration is to suggest promising future directions and to stimulate a community effort on further model extensions. + +In the following, you will find two vignettes with several use cases, as well as an in-depth analysis of time-dependent kinetic rate parameters. + +Potential pitfalls +^^^^^^^^^^^^^^^^^^ +.. image:: https://user-images.githubusercontent.com/31883718/115840357-e354f480-a41b-11eb-8a95-12f7564fd9b0.png + :width: 300px + :align: left + +This notebook reproduces Fig. 2 with several use cases, including multiple kinetics in Dentate Gyrus, +transcriptional boost in the erythroid lineage, and misleading arrow projections in mature PBMCs. + +Notebook: `Perspectives `_ + +| +Kinetic parameter analysis +^^^^^^^^^^^^^^^^^^^^^^^^^^ +.. image:: https://user-images.githubusercontent.com/31883718/130656606-00bd44be-9071-4008-be1b-244fa9c2d244.png + :width: 300px + :align: left + +This notebook reproduces Fig. 3, where we demonstrate how time-variable kinetic rates +shape the curvature patterns of gene activation. + +Notebook: `Kinetic parameter analysis `_ diff --git a/docs/source/release_notes.rst b/docs/source/release_notes.rst index a62bda49..5e812e86 100644 --- a/docs/source/release_notes.rst +++ b/docs/source/release_notes.rst @@ -4,6 +4,30 @@ Release Notes ============= +Version 0.2.4 :small:`Aug 26, 2021` +----------------------------------- + +Perspectives: + +- Landing page and two notebooks accompanying the perspectives manuscript at MSB. +- New datasets: Gastrulation, bone marrow, and PBMCs. + +New capabilities: + +- Added vignettes accompanying the NBT manuscript. +- Kinetic simulations with time-dependent rates. +- New arguments for `tl.velocity_embedding_stream` (`PR 492 `_). +- Introduced automated linting and import sorting with `flake8` and `isort` (`PR 360 `_, `PR 374 `_). +- `tl.velocity_graph` parallelized (`PR 392 `_). +- `legend_align_text` parameter in `pl.scatter` for smart placing of labels without overlapping. +- Save option for `pl.proportions`. + +Bugfixes: + +- Pinned `sphinx<4.0` and `nbsphinx<0.8.7`. +- Fixed IPython import at the CLI. + + Version 0.2.3 :small:`Feb 13, 2021` ----------------------------------- @@ -140,4 +164,4 @@ Version 0.1.5 :small:`Sep 4, 2018` Version 0.1.2 :small:`Aug 21, 2018` ----------------------------------- -First alpha release of scvelo. \ No newline at end of file +First alpha release of scvelo. diff --git a/docs/source/vignettes/index.rst b/docs/source/vignettes/index.rst new file mode 100644 index 00000000..87f6735e --- /dev/null +++ b/docs/source/vignettes/index.rst @@ -0,0 +1,19 @@ +Other Vignettes +--------------- + +Example Datasets +^^^^^^^^^^^^^^^^ +- `Dentate Gyrus `_ +- `Pancreas `_ + +Nature Biotech Figures +^^^^^^^^^^^^^^^^^^^^^^ + +- NBT Cover | `cover `__ | `notebook `__ +- Fig.1 Concept | `figure `__ | `notebook `__ +- Fig.2 Dentate Gyrus | `figure `__ | `notebook `__ +- Fig.3 Pancreas | `figure `__ | `notebook `__ +- Suppl.
Figures | `figure `__ | `notebook `__ | `runtime `__ + +All notebooks are deposited at `GitHub `_. +Found a bug? Feel free to submit an `issue `_. \ No newline at end of file diff --git a/pypi.rst b/pypi.rst index b2833154..9170fed4 100644 --- a/pypi.rst +++ b/pypi.rst @@ -32,6 +32,7 @@ patients and dynamic processes in human lung regeneration. Find out more in this Latest news ^^^^^^^^^^^ +- Aug/2021: `Perspectives paper out in MSB `_ - Feb/2021: scVelo goes multi-core - Dec/2020: Cover of `Nature Biotechnology `_ - Nov/2020: Talk at `Single Cell Biology `_ @@ -40,11 +41,16 @@ Latest news - Sep/2020: Talk at `Single Cell Omics `_ - Aug/2020: `scVelo out in Nature Biotech `_ -Reference -^^^^^^^^^ +References +^^^^^^^^^^ +La Manno *et al.* (2018), RNA velocity of single cells, `Nature `_. + Bergen *et al.* (2020), Generalizing RNA velocity to transient cell states through dynamical modeling, `Nature Biotech `_. +Bergen *et al.* (2021), RNA velocity - current challenges and future perspectives, +`Molecular Systems Biology `_. + Support ^^^^^^^ Found a bug or would like to see a feature implemented? Feel free to submit an diff --git a/pyproject.toml b/pyproject.toml index 381b4762..cbf07ea1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,3 +17,17 @@ exclude = ''' | dist )/ ''' + +[tool.isort] +profile = "black" +use_parentheses = true +known_num = "networkx,numpy,pandas,scipy,sklearn,statsmodels" +known_plot = "matplotlib,mpl_toolkits,seaborn" +known_bio = "anndata,scanpy" +sections = "FUTURE,STDLIB,THIRDPARTY,NUM,PLOT,BIO,FIRSTPARTY,LOCALFOLDER" +no_lines_before = "LOCALFOLDER" +balanced_wrapping = true +length_sort = "0" +indent = " " +float_to_top = true +order_by_type = false diff --git a/pytest.ini b/pytest.ini index a12918d6..67a1c148 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,4 +1,6 @@ [pytest] -python_files = *.py -testpaths = tests +python_files = test_*.py +testpaths = + tests + scvelo xfail_strict = true diff --git a/requirements-dev.txt b/requirements-dev.txt index 1bd89b6c..af20ddc2 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -3,4 +3,12 @@ -e .
black==20.8b1 -pre-commit==2.5.1 +hnswlib +hypothesis +flake8==3.8.4 +isort==5.7.0 +louvain +pre-commit>=2.9.0 +pybind11 +pytest>=6.2.2 +python-igraph diff --git a/requirements.txt b/requirements.txt index 5831c94e..c1ad4bc3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,5 @@ +typing_extensions + # main requirements for single-cell analysis anndata>=0.7.5 # compatible with h5py v3.0 (v0.7.5) scanpy>=1.5 # adapt to new anndata attributes (v1.5) @@ -6,11 +8,12 @@ scanpy>=1.5 # adapt to new anndata attributes (v1.5) loompy>=2.0.12 # introduced sparsity support (v2.0.12) # for computing neighbor graph connectivities -umap-learn>=0.3.10, <0.5 # removed numba warnings (v0.3.10) +umap-learn>=0.3.10 # removed numba warnings (v0.3.10) # standard requirements for data analysis +numba>=0.41.0 numpy>=1.17 # extension/speedup in .nan_to_num, .exp (v1.17) scipy>=1.4.1 # introduced PCA sparsity support (v1.4) pandas>=0.23 # merging/sorting extensions (v0.23) scikit-learn>=0.21.2 # bugfix in .utils.sparsefuncs (v0.21.2) -matplotlib>=3.1.2 # several bugfixes (v3.1.2) \ No newline at end of file +matplotlib>=3.3.0 # normalize in pie (v3.3.0) \ No newline at end of file diff --git a/scvelo/__init__.py b/scvelo/__init__.py index 6df0f0ac..76cc70d1 100644 --- a/scvelo/__init__.py +++ b/scvelo/__init__.py @@ -1,4 +1,17 @@ """scvelo - RNA velocity generalized through dynamical modeling""" +from anndata import AnnData +from scanpy import read, read_loom + +from scvelo import datasets, logging, pl, pp, settings, tl, utils +from scvelo.core import get_df +from scvelo.plotting.gridspec import GridSpec +from scvelo.preprocessing.neighbors import Neighbors +from scvelo.read_load import DataFrame, load, read_csv +from scvelo.settings import set_figure_params +from scvelo.tools.run import run_all, test +from scvelo.tools.utils import round +from scvelo.tools.velocity import Velocity +from scvelo.tools.velocity_graph import VelocityGraph try: from setuptools_scm import get_version @@ -8,24 +21,33 @@ except (LookupError, ImportError): try: from importlib_metadata import version # Python < 3.8 - except: + except Exception: from importlib.metadata import version # Python = 3.8 __version__ = version(__name__) del version -from .read_load import AnnData, read, read_loom, load, read_csv, get_df, DataFrame -from .preprocessing.neighbors import Neighbors -from .tools.run import run_all, test -from .tools.utils import round -from .tools.velocity import Velocity -from .tools.velocity_graph import VelocityGraph -from .plotting.gridspec import GridSpec -from .settings import set_figure_params -from . import pp -from . import tl -from . import pl -from . import utils -from . import datasets -from . import logging -from . 
import settings +__all__ = [ + "AnnData", + "DataFrame", + "datasets", + "get_df", + "GridSpec", + "load", + "logging", + "Neighbors", + "pl", + "pp", + "read", + "read_csv", + "read_loom", + "round", + "run_all", + "set_figure_params", + "settings", + "test", + "tl", + "utils", + "Velocity", + "VelocityGraph", +] diff --git a/scvelo/core/__init__.py b/scvelo/core/__init__.py new file mode 100644 index 00000000..156eeb6e --- /dev/null +++ b/scvelo/core/__init__.py @@ -0,0 +1,43 @@ +from ._anndata import ( + clean_obs_names, + cleanup, + get_df, + get_initial_size, + get_modality, + get_size, + make_dense, + make_sparse, + merge, + set_initial_size, + set_modality, + show_proportions, +) +from ._arithmetic import clipped_log, invert, prod_sum, sum +from ._linear_models import LinearRegression +from ._metrics import l2_norm +from ._models import SplicingDynamics +from ._parallelize import get_n_jobs, parallelize + +__all__ = [ + "clean_obs_names", + "cleanup", + "clipped_log", + "get_df", + "get_initial_size", + "get_modality", + "get_n_jobs", + "get_size", + "invert", + "l2_norm", + "LinearRegression", + "make_dense", + "make_sparse", + "merge", + "parallelize", + "prod_sum", + "set_initial_size", + "set_modality", + "show_proportions", + "SplicingDynamics", + "sum", +] diff --git a/scvelo/core/_anndata.py b/scvelo/core/_anndata.py new file mode 100644 index 00000000..d77bbb7d --- /dev/null +++ b/scvelo/core/_anndata.py @@ -0,0 +1,790 @@ +import re +from typing import List, Optional, Union + +from typing_extensions import Literal + +import numpy as np +import pandas as pd +from numpy import ndarray +from pandas import DataFrame +from pandas.api.types import is_categorical_dtype +from scipy.sparse import csr_matrix, issparse, spmatrix + +from anndata import AnnData + +from scvelo import logging as logg +from ._arithmetic import sum + + +def clean_obs_names( + data: AnnData, + base: str = "[AGTCBDHKMNRSVWY]", + ID_length: int = 12, + copy: bool = False, +) -> Optional[AnnData]: + """Clean up the obs_names. + + For example, an obs_name 'sample1_AGTCdate' is changed to 'AGTC' of the sample + 'sample1_date'. The sample name is then saved in obs['sample_batch']. + The genetic codes are identified according to + https://www.neb.com/tools-and-resources/usage-guidelines/the-genetic-code. + + Arguments + --------- + data + Annotated data matrix. + base + Genetic code letters to be identified. + ID_length + Length of the genetic codes in the samples. + copy + Return a copy instead of writing to adata.
+ + Returns + ------- + Optional[AnnData] + Returns or updates `adata` with the attributes + obs_names: list + updated names of the observations + sample_batch: `.obs` + names of the identified sample batches + """ + + def get_base_list(name, base): + base_list = base + while re.search(base_list + base, name) is not None: + base_list += base + if len(base_list) == 0: + raise ValueError("Encountered an invalid ID in obs_names: ", name) + return base_list + + adata = data.copy() if copy else data + + names = adata.obs_names + base_list = get_base_list(names[0], base) + + if len(np.unique([len(name) for name in adata.obs_names])) == 1: + start, end = re.search(base_list, names[0]).span() + newIDs = [name[start:end] for name in names] + start, end = 0, len(newIDs[0]) + for i in range(end - ID_length): + if np.any([ID[i] not in base for ID in newIDs]): + start += 1 + if np.any([ID[::-1][i] not in base for ID in newIDs]): + end -= 1 + + newIDs = [ID[start:end] for ID in newIDs] + prefixes = [names[i].replace(newIDs[i], "") for i in range(len(names))] + else: + prefixes, newIDs = [], [] + for name in names: + match = re.search(base_list, name) + newID = ( + re.search(get_base_list(name, base), name).group() + if match is None + else match.group() + ) + newIDs.append(newID) + prefixes.append(name.replace(newID, "")) + + adata.obs_names = newIDs + if len(prefixes[0]) > 0 and len(np.unique(prefixes)) > 1: + adata.obs["sample_batch"] = ( + pd.Categorical(prefixes) + if len(np.unique(prefixes)) < adata.n_obs + else prefixes + ) + + adata.obs_names_make_unique() + return adata if copy else None + + +def cleanup( + data: AnnData, + clean: Union[ + Literal["layers", "obs", "var", "uns"], + List[Literal["layers", "obs", "var", "uns"]], + ] = "layers", + keep: Optional[Union[str, List[str]]] = None, + copy: bool = False, +) -> Optional[AnnData]: + """Delete unneeded attributes. + + Arguments + --------- + data + Annotated data matrix. + clean + Which attributes to consider for freeing memory. + keep + Which attributes to keep. + copy + Return a copy instead of writing to adata. + + Returns + ------- + Optional[AnnData] + Returns or updates `adata` with selection of attributes kept. + """ + + adata = data.copy() if copy else data + verify_dtypes(adata) + + keep = list([keep] if isinstance(keep, str) else {} if keep is None else keep) + keep.extend(["spliced", "unspliced", "Ms", "Mu", "clusters", "neighbors"]) + + ann_dict = { + "obs": adata.obs_keys(), + "var": adata.var_keys(), + "uns": adata.uns_keys(), + "layers": list(adata.layers.keys()), + } + + if "all" not in clean: + ann_dict = {ann: values for (ann, values) in ann_dict.items() if ann in clean} + + for (ann, values) in ann_dict.items(): + for value in values: + if value not in keep: + del getattr(adata, ann)[value] + + return adata if copy else None + + +def get_df( + data: AnnData, + keys: Optional[Union[str, List[str]]] = None, + layer: Optional[str] = None, + index: List = None, + columns: List = None, + sort_values: bool = None, + dropna: Literal["all", "any"] = "all", + precision: int = None, +) -> DataFrame: + """Get dataframe for a specified adata key. + + Return values for specified key + (in obs, var, obsm, varm, obsp, varp, uns, or layers) as a dataframe. + + Arguments + --------- + data + AnnData object or a numpy array to get values from. + keys + Keys from `.var_names`, `.obs_names`, `.var`, `.obs`, + `.obsm`, `.varm`, `.obsp`, `.varp`, `.uns`, or `.layers`. + layer + Layer of `adata` to use as expression values.
+ index + List to set as index. + columns + List to set as column names. + sort_values + Whether to sort values by first column (sort_values=True) or a specified column. + dropna + Drop columns/rows that contain NaNs in all ('all') or in any entry ('any'). + precision + Set precision for pandas dataframe. + + Returns + ------- + :class:`pd.DataFrame` + A dataframe. + """ + + if precision is not None: + pd.set_option("precision", precision) + + if isinstance(data, AnnData): + keys, keys_split = ( + keys.split("*") if isinstance(keys, str) and "*" in keys else (keys, None) + ) + keys, key_add = ( + keys.split("/") if isinstance(keys, str) and "/" in keys else (keys, None) + ) + keys = [keys] if isinstance(keys, str) else keys + key = keys[0] + + s_keys = ["obs", "var", "obsm", "varm", "uns", "layers"] + d_keys = [ + data.obs.keys(), + data.var.keys(), + data.obsm.keys(), + data.varm.keys(), + data.uns.keys(), + data.layers.keys(), + ] + + if hasattr(data, "obsp") and hasattr(data, "varp"): + s_keys.extend(["obsp", "varp"]) + d_keys.extend([data.obsp.keys(), data.varp.keys()]) + + if keys is None: + df = data.to_df() + elif key in data.var_names: + df = obs_df(data, keys, layer=layer) + elif key in data.obs_names: + df = var_df(data, keys, layer=layer) + else: + if keys_split is not None: + keys = [ + k + for k in list(data.obs.keys()) + list(data.var.keys()) + if key in k and keys_split in k + ] + key = keys[0] + s_key = [s for (s, d_key) in zip(s_keys, d_keys) if key in d_key] + if len(s_key) == 0: + raise ValueError(f"'{key}' not found in any of {', '.join(s_keys)}.") + if len(s_key) > 1: + logg.warn(f"'{key}' found multiple times in {', '.join(s_key)}.") + + s_key = s_key[-1] + df = getattr(data, s_key)[keys if len(keys) > 1 else key] + if key_add is not None: + df = df[key_add] + if index is None: + index = ( + data.var_names + if s_key == "varm" + else data.obs_names + if s_key in {"obsm", "layers"} + else None + ) + if index is None and s_key == "uns" and hasattr(df, "shape"): + key_cats = np.array( + [ + key + for key in data.obs.keys() + if is_categorical_dtype(data.obs[key]) + ] + ) + num_cats = [ + len(data.obs[key].cat.categories) == df.shape[0] + for key in key_cats + ] + if np.sum(num_cats) == 1: + index = data.obs[key_cats[num_cats][0]].cat.categories + if ( + columns is None + and len(df.shape) > 1 + and df.shape[0] == df.shape[1] + ): + columns = index + elif isinstance(index, str) and index in data.obs.keys(): + index = pd.Categorical(data.obs[index]).categories + if columns is None and s_key == "layers": + columns = data.var_names + elif isinstance(columns, str) and columns in data.obs.keys(): + columns = pd.Categorical(data.obs[columns]).categories + elif isinstance(data, pd.DataFrame): + if isinstance(keys, str) and "*" in keys: + keys, keys_split = keys.split("*") + keys = [k for k in data.columns if keys in k and keys_split in k] + df = data[keys] if keys is not None else data + else: + df = data + + if issparse(df): + df = np.array(df.A) + if columns is None and hasattr(df, "names"): + columns = df.names + + df = pd.DataFrame(df, index=index, columns=columns) + + if dropna: + df.replace("", np.nan, inplace=True) + how = dropna if isinstance(dropna, str) else "any" if dropna is True else "all" + df.dropna(how=how, axis=0, inplace=True) + df.dropna(how=how, axis=1, inplace=True) + + if sort_values: + sort_by = ( + sort_values + if isinstance(sort_values, str) and sort_values in df.columns + else df.columns[0] + ) + df = df.sort_values(by=sort_by, ascending=False) + + if
hasattr(data, "var_names"): + if df.index[0] in data.var_names: + df.var_names = df.index + elif df.columns[0] in data.var_names: + df.var_names = df.columns + if hasattr(data, "obs_names"): + if df.index[0] in data.obs_names: + df.obs_names = df.index + elif df.columns[0] in data.obs_names: + df.obs_names = df.columns + + return df + + +# TODO: Generalize to arbitrary modality +def get_initial_size( + adata: AnnData, layer: Optional[str] = None, by_total_size: bool = False +) -> Optional[ndarray]: + """Get initial counts per observation of a layer. + + Arguments + --------- + adata + Annotated data matrix. + layer + Name of layer for which to retrieve initial size. + by_total_size + Whether or not to return the combined initial size of the spliced and unspliced + layers. + + Returns + ------- + np.ndarray + Initial counts per observation in the specified layer. + """ + + if by_total_size: + sizes = [ + adata.obs[f"initial_size_{layer}"] + for layer in {"spliced", "unspliced"} + if f"initial_size_{layer}" in adata.obs.keys() + ] + return np.sum(sizes, axis=0) + elif layer in adata.layers.keys(): + return ( + np.array(adata.obs[f"initial_size_{layer}"]) + if f"initial_size_{layer}" in adata.obs.keys() + else get_size(adata, layer) + ) + elif layer is None or layer == "X": + return ( + np.array(adata.obs["initial_size"]) + if "initial_size" in adata.obs.keys() + else get_size(adata) + ) + else: + return None + + +def get_modality(adata: AnnData, modality: str) -> Union[ndarray, spmatrix]: + """Extract data of one modality. + + Arguments + --------- + adata + Annotated data to extract modality from. + modality + Modality for which data is needed. + + Returns + ------- + Union[ndarray, spmatrix] + Retrieved modality from :class:`~anndata.AnnData` object. + """ + + if modality == "X": + return adata.X + elif modality in adata.layers.keys(): + return adata.layers[modality] + elif modality in adata.obsm.keys(): + if isinstance(adata.obsm[modality], DataFrame): + return adata.obsm[modality].values + else: + return adata.obsm[modality] + + +# TODO: Generalize to arbitrary modality +def get_size(adata: AnnData, layer: Optional[str] = None) -> ndarray: + """Get counts per observation in a layer. + + Arguments + --------- + adata + Annotated data matrix. + layer + Name of layer for which to retrieve counts. + + Returns + ------- + np.ndarray + Counts per observation in the specified layer. + """ + + X = adata.X if layer is None else adata.layers[layer] + return sum(X, axis=1) + + +def make_dense( + adata: AnnData, modalities: Union[List[str], str], inplace: bool = True +) -> Optional[AnnData]: + """Densify sparse AnnData entry. + + Arguments + --------- + adata + Annotated data object. + modalities + Modalities to make dense. + inplace + Boolean flag to perform operations inplace or not. Defaults to `True`. + + Returns + ------- + Optional[AnnData] + Copy of annotated data `adata` with dense modalities if `inplace=False`. + """ + + if not inplace: + adata = adata.copy() + + if isinstance(modalities, str): + modalities = [modalities] + + # Densify modalities + for modality in modalities: + count_data = get_modality(adata=adata, modality=modality) + if issparse(count_data): + set_modality(adata=adata, modality=modality, new_value=count_data.A) + + return adata if not inplace else None + + +# TODO: Allow choosing format of sparse matrix i.e., csr, csc, ...
+def make_sparse( + adata: AnnData, modalities: Union[List[str], str], inplace: bool = True +) -> Optional[AnnData]: + """Make AnnData entry sparse. + + Arguments + --------- + adata + Annotated data object. + modalities + Modalities to make sparse. + inplace + Boolean flag to perform operations inplace or not. Defaults to `True`. + + Returns + ------- + Optional[AnnData] + Copy of annotated data `adata` with sparse modalities if `inplace=False`. + """ + + if not inplace: + adata = adata.copy() + + if isinstance(modalities, str): + modalities = [modalities] + + # Make modalities sparse + for modality in modalities: + count_data = get_modality(adata=adata, modality=modality) + if modality == "X": + logg.warn("Making `X` sparse is not supported.") + elif not issparse(count_data): + set_modality( + adata=adata, modality=modality, new_value=csr_matrix(count_data) + ) + + return adata if not inplace else None + + +def merge(adata: AnnData, ldata: AnnData, copy: bool = True) -> Optional[AnnData]: + """Merge two annotated data matrices. + + Arguments + --------- + adata + Annotated data matrix (reference data set). + ldata + Annotated data matrix (to be merged into adata). + copy + Boolean flag to manipulate original AnnData or a copy of it. + + Returns + ------- + Optional[:class:`anndata.AnnData`] + Returns a :class:`~anndata.AnnData` object. + """ + + adata.var_names_make_unique() + ldata.var_names_make_unique() + + if ( + "spliced" in ldata.layers.keys() + and "initial_size_spliced" not in ldata.obs.keys() + ): + set_initial_size(ldata) + elif ( + "spliced" in adata.layers.keys() + and "initial_size_spliced" not in adata.obs.keys() + ): + set_initial_size(adata) + + common_obs = pd.unique(adata.obs_names.intersection(ldata.obs_names)) + common_vars = pd.unique(adata.var_names.intersection(ldata.var_names)) + + if len(common_obs) == 0: + clean_obs_names(adata) + clean_obs_names(ldata) + common_obs = adata.obs_names.intersection(ldata.obs_names) + + if copy: + _adata = adata[common_obs].copy() + _ldata = ldata[common_obs].copy() + else: + adata._inplace_subset_obs(common_obs) + _adata, _ldata = adata, ldata[common_obs].copy() + + _adata.var_names_make_unique() + _ldata.var_names_make_unique() + + same_vars = len(_adata.var_names) == len(_ldata.var_names) and np.all( + _adata.var_names == _ldata.var_names + ) + join_vars = len(common_vars) > 0 + + if join_vars and not same_vars: + _adata._inplace_subset_var(common_vars) + _ldata._inplace_subset_var(common_vars) + + for attr in _ldata.obs.keys(): + if attr not in _adata.obs.keys(): + _adata.obs[attr] = _ldata.obs[attr] + for attr in _ldata.obsm.keys(): + if attr not in _adata.obsm.keys(): + _adata.obsm[attr] = _ldata.obsm[attr] + for attr in _ldata.uns.keys(): + if attr not in _adata.uns.keys(): + _adata.uns[attr] = _ldata.uns[attr] + if join_vars: + for attr in _ldata.layers.keys(): + if attr not in _adata.layers.keys(): + _adata.layers[attr] = _ldata.layers[attr] + + if _adata.shape[1] == _ldata.shape[1]: + same_vars = len(_adata.var_names) == len(_ldata.var_names) and np.all( + _adata.var_names == _ldata.var_names + ) + if same_vars: + for attr in _ldata.var.keys(): + if attr not in _adata.var.keys(): + _adata.var[attr] = _ldata.var[attr] + for attr in _ldata.varm.keys(): + if attr not in _adata.varm.keys(): + _adata.varm[attr] = _ldata.varm[attr] + else: + raise ValueError("Variable names are not identical.") + + return _adata if copy else None + + +def obs_df(adata: AnnData, keys: List[str], layer: Optional[str] = None) -> DataFrame: + """Extract
layer as Pandas DataFrame indexed by observation. + + Arguments + --------- + adata + Annotated data matrix (reference data set). + keys + Variables for which to extract data. + layer + Name of layer to turn into a Pandas DataFrame. + + Returns + ------- + DataFrame + DataFrame indexed by observations. Columns correspond to variables of specified + layer. + """ + + lookup_keys = [k for k in keys if k in adata.var_names] + if len(lookup_keys) < len(keys): + logg.warn( + f"Keys {[k for k in keys if k not in adata.var_names]} " + f"were not found in `adata.var_names`." + ) + + df = pd.DataFrame(index=adata.obs_names) + for lookup_key in lookup_keys: + df[lookup_key] = adata.obs_vector(lookup_key, layer=layer) + return df + + +# TODO: Generalize to arbitrary modality +def set_initial_size(adata: AnnData, layers: Optional[str] = None) -> None: + """Set current counts per observation of a layer as its initial size. + + The initial size is only set if it does not already exist. + + Arguments + --------- + adata + Annotated data matrix. + layers + Name of layers for which to calculate initial size. + + Returns + ------- + None + """ + + if layers is None: + layers = ["spliced", "unspliced"] + verify_dtypes(adata) + layers = [ + layer + for layer in layers + if layer in adata.layers.keys() + and f"initial_size_{layer}" not in adata.obs.keys() + ] + for layer in layers: + adata.obs[f"initial_size_{layer}"] = get_size(adata, layer) + if "initial_size" not in adata.obs.keys(): + adata.obs["initial_size"] = get_size(adata) + + +def set_modality( + adata: AnnData, + new_value: Union[ndarray, spmatrix, DataFrame], + modality: Optional[str] = None, + inplace: bool = True, +) -> Optional[AnnData]: + """Set modality of annotated data object to new value. + + Arguments + --------- + adata + Annotated data object. + new_value + New value of modality. + modality + Modality to overwrite with new value. Defaults to `None`. + inplace + Boolean flag to indicate whether setting of modality should be inplace or + not. Defaults to `True`. + + Returns + ------- + Optional[AnnData] + Copy of annotated data `adata` with updated modality if `inplace=True`. + """ + + if not inplace: + adata = adata.copy() + + if (modality == "X") or (modality is None): + adata.X = new_value + elif modality in adata.layers.keys(): + adata.layers[modality] = new_value + elif modality in adata.obsm.keys(): + adata.obsm[modality] = new_value + + if not inplace: + return adata + + +def show_proportions( + adata: AnnData, layers: Optional[str] = None, use_raw: bool = True +) -> None: + """Proportions of abundances of modalities in layers. + + The proportions are printed. + + Arguments + --------- + adata + Annotated data matrix. + layers + Layers to consider. + use_raw + Use initial sizes, i.e., raw data, to determine proportions. 
+ + Returns + ------- + None + """ + + if layers is None: + layers = ["spliced", "unspliced", "ambiguous"] + layers_keys = [key for key in layers if key in adata.layers.keys()] + counts_layers = [sum(adata.layers[key], axis=1) for key in layers_keys] + if use_raw: + size_key, obs = "initial_size_", adata.obs + counts_layers = [ + obs[size_key + layer] if size_key + layer in obs.keys() else c + for layer, c in zip(layers_keys, counts_layers) + ] + + counts_per_cell_sum = np.sum(counts_layers, 0) + counts_per_cell_sum += counts_per_cell_sum == 0 + + mean_abundances = [ + np.mean(counts_per_cell / counts_per_cell_sum) + for counts_per_cell in counts_layers + ] + + print(f"Abundance of {layers_keys}: {np.round(mean_abundances, 2)}") + + +def var_df(adata: AnnData, keys: List[str], layer: Optional[str] = None) -> DataFrame: + """Extract layer as Pandas DataFrame indexed by features. + + Arguments + --------- + adata + Annotated data matrix (reference data set). + keys + Observations for which to extract data. + layer + Name of layer to turn into a Pandas DataFrame. + + Returns + ------- + DataFrame + DataFrame indexed by features. Columns correspond to observations of specified + layer. + """ + + lookup_keys = [k for k in keys if k in adata.obs_names] + if len(lookup_keys) < len(keys): + logg.warn( + f"Keys {[k for k in keys if k not in adata.obs_names]} " + f"were not found in `adata.obs_names`." + ) + + df = pd.DataFrame(index=adata.var_names) + for lookup_key in lookup_keys: + df[lookup_key] = adata.var_vector(lookup_key, layer=layer) + return df + + +# TODO: Find better function name +def verify_dtypes(adata: AnnData) -> None: + """Verify that AnnData object is not corrupted. + + Arguments + --------- + adata + Annotated data matrix to check. + + Returns + ------- + None + """ + + try: + _ = adata[:, 0] + except Exception: + uns = adata.uns + adata.uns = {} + try: + _ = adata[:, 0] + logg.warn( + "Safely deleted unstructured annotations (adata.uns), \n" + "as these do not comply with permissible anndata datatypes." + ) + except Exception: + logg.warn( + "The data might be corrupted. Please verify all annotation datatypes." + ) + adata.uns = uns diff --git a/scvelo/core/_arithmetic.py b/scvelo/core/_arithmetic.py new file mode 100644 index 00000000..037de5f2 --- /dev/null +++ b/scvelo/core/_arithmetic.py @@ -0,0 +1,103 @@ +import warnings +from typing import Optional, Union + +import numpy as np +from numpy import ndarray +from scipy.sparse import issparse, spmatrix + + +def clipped_log(x: ndarray, lb: float = 0, ub: float = 1, eps: float = 1e-6) -> ndarray: + """Logarithmize between [lb + epsilon, ub - epsilon]. + + Arguments + --------- + x + Array to clip and logarithmize. + lb + Lower bound of interval to which array entries are clipped. + ub + Upper bound of interval to which array entries are clipped. + eps + Offset of boundaries of clipping interval. + + Returns + ------- + ndarray + Logarithm of clipped array. + """ + + return np.log(np.clip(x, lb + eps, ub - eps)) + + +def invert(x: ndarray) -> ndarray: + """Invert array and set infinity to NaN. + + Arguments + --------- + x + Array to invert. + + Returns + ------- + ndarray + Inverted array. + """ + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + x_inv = 1 / x * (x != 0) + return x_inv + + +def prod_sum( + a1: Union[ndarray, spmatrix], a2: Union[ndarray, spmatrix], axis: Optional[int] = None +) -> ndarray: + """Take sum of product of two arrays along given axis. + + Arguments + --------- + a1 + First array. + a2 + Second array.
+ axis + Axis along which to sum elements. If `None`, all elements will be summed. + Defaults to `None`. + + Returns + ------- + ndarray + Sum of product of arrays along given axis. + """ + + if issparse(a1): + return a1.multiply(a2).sum(axis=axis).A1 + elif axis == 0: + return np.einsum("ij, ij -> j", a1, a2) if a1.ndim > 1 else (a1 * a2).sum() + elif axis == 1: + return np.einsum("ij, ij -> i", a1, a2) if a1.ndim > 1 else (a1 * a2).sum() + + +def sum(a: Union[ndarray, spmatrix], axis: Optional[int] = None) -> ndarray: + """Sum array elements over a given axis. + + Arguments + --------- + a + Elements to sum. + axis + Axis along which to sum elements. If `None`, all elements will be summed. + Defaults to `None`. + + Returns + ------- + ndarray + Sum of array along given axis. + """ + + if a.ndim == 1: + axis = 0 + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + return a.sum(axis=axis).A1 if issparse(a) else a.sum(axis=axis) diff --git a/scvelo/core/_base.py b/scvelo/core/_base.py new file mode 100644 index 00000000..4990eaaa --- /dev/null +++ b/scvelo/core/_base.py @@ -0,0 +1,54 @@ +from abc import ABC, abstractmethod +from typing import Dict, Tuple, Union + +from numpy import ndarray + + +class DynamicsBase(ABC): + @abstractmethod + def get_solution( + self, t: ndarray, stacked: bool = True, with_keys: bool = False + ) -> Union[Dict, Tuple[ndarray], ndarray]: + """Calculate solution of dynamics. + + Arguments + --------- + t + Time steps at which to evaluate solution. + stacked + Whether to stack states or return them individually. Defaults to `True`. + with_keys + Whether to return solution labelled by variables in form of a dictionary. + Defaults to `False`. + + Returns + ------- + Union[Dict, Tuple[ndarray], ndarray] + Solution of system. If `with_keys=True`, the solution is returned in form of + a dictionary with variables as keys. Otherwise, the solution is given as + a `numpy.ndarray` of form `(n_steps, n_vars)`. + """ + + return + + @abstractmethod + def get_steady_states( + self, stacked: bool = True, with_keys: bool = False + ) -> Union[Dict[str, ndarray], Tuple[ndarray], ndarray]: + """Return steady state of system. + + Arguments + --------- + stacked + Whether to stack states or return them individually. Defaults to `True`. + with_keys + Whether to return solution labelled by variables in form of a dictionary. + Defaults to `False`. + + Returns + ------- + Union[Dict[str, ndarray], Tuple[ndarray], ndarray] + Steady state of system. + """ + + return diff --git a/scvelo/core/_linear_models.py b/scvelo/core/_linear_models.py new file mode 100644 index 00000000..89796761 --- /dev/null +++ b/scvelo/core/_linear_models.py @@ -0,0 +1,150 @@ +from typing import List, Optional, Tuple, Union + +import numpy as np +from numpy import ndarray +from scipy.sparse import csr_matrix, issparse + +from ._arithmetic import prod_sum, sum + + +class LinearRegression: + """Extreme quantile and constrained least-squares linear regression. + + Arguments + --------- + percentile + Percentile of data on which the linear regression line is fit. If `None`, all data + is used; if a single value is given, it is interpreted as the upper quantile. + Defaults to `None`. + fit_intercept + Whether to calculate the intercept for the model. Defaults to `False`. + positive_intercept + Whether the intercept is constrained to positive values. Only plays a role when + `fit_intercept=True`. Defaults to `True`. + constrain_ratio + Ratio to which coefficients are clipped. If `None`, the coefficients are not + constrained.
Defaults to `None`. + + Attributes + ---------- + coef_ + Estimated coefficients of the linear regression line. + + intercept_ + Fitted intercept of linear model. Set to `0.0` if `fit_intercept=False`. + + """ + + def __init__( + self, + percentile: Optional[Union[Tuple, int, float]] = None, + fit_intercept: bool = False, + positive_intercept: bool = True, + constrain_ratio: Optional[Union[Tuple, float]] = None, + ): + if not fit_intercept and isinstance(percentile, (list, tuple)): + self.percentile = percentile[1] + else: + self.percentile = percentile + self.fit_intercept = fit_intercept + self.positive_intercept = positive_intercept + + if constrain_ratio is None: + self.constrain_ratio = [-np.inf, np.inf] + elif len(constrain_ratio) == 1: + self.constrain_ratio = [-np.inf, constrain_ratio] + else: + self.constrain_ratio = constrain_ratio + + def _trim_data(self, data: List) -> List: + """Trim data to extreme values. + + Arguments + --------- + data + Data to be trimmed to extreme quantiles. + + Returns + ------- + List + Number of non-trivial entries per column and trimmed data. + """ + + if not isinstance(data, List): + data = [data] + + data = np.array( + [data_mat.A if issparse(data_mat) else data_mat for data_mat in data] + ) + + # TODO: Add explanatory comment + normalized_data = np.sum( + data / data.max(axis=1, keepdims=True).clip(1e-3, None), axis=0 + ) + + bound = np.percentile(normalized_data, self.percentile, axis=0) + + if bound.ndim == 1: + trimmer = csr_matrix(normalized_data >= bound).astype(bool) + else: + trimmer = csr_matrix( + (normalized_data <= bound[0]) | (normalized_data >= bound[1]) + ).astype(bool) + + return [trimmer.getnnz(axis=0)] + [ + trimmer.multiply(data_mat).tocsr() for data_mat in data + ] + + def fit(self, x: ndarray, y: ndarray): + """Fit linear model per column. + + Arguments + --------- + x + Training data of shape `(n_obs, n_vars)`. + y + Target values of shape `(n_obs, n_vars)`. + + Returns + ------- + self + Returns an instance of self. + """ + + n_obs = x.shape[0] + + if self.percentile is not None: + n_obs, x, y = self._trim_data(data=[x, y]) + + _xx = prod_sum(x, x, axis=0) + _xy = prod_sum(x, y, axis=0) + + if self.fit_intercept: + _x = sum(x, axis=0) / n_obs + _y = sum(y, axis=0) / n_obs + self.coef_ = (_xy / n_obs - _x * _y) / (_xx / n_obs - _x ** 2) + self.intercept_ = _y - self.coef_ * _x + + if self.positive_intercept: + idx = self.intercept_ < 0 + if self.coef_.ndim > 0: + self.coef_[idx] = _xy[idx] / _xx[idx] + else: + self.coef_ = _xy / _xx + self.intercept_ = np.clip(self.intercept_, 0, None) + else: + self.coef_ = _xy / _xx + self.intercept_ = np.zeros(x.shape[1]) if x.ndim > 1 else 0 + + if not np.isscalar(self.coef_): + self.coef_[np.isnan(self.coef_)] = 0 + self.intercept_[np.isnan(self.intercept_)] = 0 + else: + if np.isnan(self.coef_): + self.coef_ = 0 + if np.isnan(self.intercept_): + self.intercept_ = 0 + + self.coef_ = np.clip(self.coef_, *self.constrain_ratio) + + return self diff --git a/scvelo/core/_metrics.py b/scvelo/core/_metrics.py new file mode 100644 index 00000000..539e4489 --- /dev/null +++ b/scvelo/core/_metrics.py @@ -0,0 +1,32 @@ +from typing import Union + +import numpy as np +from numpy import ndarray +from scipy.sparse import issparse, spmatrix + + +# TODO: Add case `axis == None` +def l2_norm(x: Union[ndarray, spmatrix], axis: int = 1) -> Union[float, ndarray]: + """Calculate l2 norm along a given axis. + + Arguments + --------- + x + Array to calculate l2 norm of. 
+ axis + Axis along which to calculate l2 norm. + + Returns + ------- + Union[float, ndarray] + L2 norm along a given axis. + """ + + if issparse(x): + return np.sqrt(x.multiply(x).sum(axis=axis).A1) + elif x.ndim == 1: + return np.sqrt(np.einsum("i, i -> ", x, x)) + elif axis == 0: + return np.sqrt(np.einsum("ij, ij -> j", x, x)) + elif axis == 1: + return np.sqrt(np.einsum("ij, ij -> i", x, x)) diff --git a/scvelo/core/_models.py b/scvelo/core/_models.py new file mode 100644 index 00000000..c38c6688 --- /dev/null +++ b/scvelo/core/_models.py @@ -0,0 +1,143 @@ +from typing import Dict, List, Tuple, Union + +import numpy as np +from numpy import ndarray + +from ._arithmetic import invert +from ._base import DynamicsBase + + +# TODO: Improve parameter names: alpha -> transcription_rate; beta -> splicing_rate; +# gamma -> degradation_rate +# TODO: Handle cases beta = 0, gamma == 0, beta == gamma +class SplicingDynamics(DynamicsBase): + """Splicing dynamics. + + Arguments + --------- + alpha + Transcription rate. + beta + Splicing rate. + gamma + Degradation rate. + initial_state + Initial state of system. Defaults to `[0, 0]`. + + Attributes + ---------- + alpha + Transcription rate. + beta + Splicing rate. + gamma + Degradation rate. + initial_state + Initial state of system. Defaults to `[0, 0]`. + u0 + Initial abundance of unspliced RNA. + s0 + Initial abundance of spliced RNA. + + """ + + def __init__( + self, + alpha: float, + beta: float, + gamma: float, + initial_state: Union[List, ndarray] = [0, 0], + ): + + self.alpha = alpha + self.beta = beta + self.gamma = gamma + + self.initial_state = initial_state + + @property + def initial_state(self): + return self._initial_state + + @initial_state.setter + def initial_state(self, val): + if isinstance(val, list) or (isinstance(val, ndarray) and (val.ndim == 1)): + self.u0 = val[0] + self.s0 = val[1] + else: + self.u0 = val[:, 0] + self.s0 = val[:, 1] + self._initial_state = val + + def get_solution( + self, t: ndarray, stacked: bool = True, with_keys: bool = False + ) -> Union[Dict, ndarray]: + """Calculate solution of dynamics. + + Arguments + --------- + t + Time steps at which to evaluate solution. + stacked + Whether to stack states or return them individually. Defaults to `True`. + with_keys + Whether to return solution labelled by variables in form of a dictionary. + Defaults to `False`. + + Returns + ------- + Union[Dict, ndarray] + Solution of system. If `with_keys=True`, the solution is returned in form of + a dictionary with variables as keys. Otherwise, the solution is given as + a `numpy.ndarray` of form `(n_steps, 2)`. + """ + + expu = np.exp(-self.beta * t) + exps = np.exp(-self.gamma * t) + + unspliced = self.u0 * expu + self.alpha / self.beta * (1 - expu) + c = (self.alpha - self.u0 * self.beta) * invert(self.gamma - self.beta) + spliced = ( + self.s0 * exps + self.alpha / self.gamma * (1 - exps) + c * (exps - expu) + ) + + if with_keys: + return {"u": unspliced, "s": spliced} + elif not stacked: + return unspliced, spliced + else: + if isinstance(t, np.ndarray) and t.ndim == 2: + return np.stack([unspliced, spliced], axis=2) + else: + return np.column_stack([unspliced, spliced]) + + # TODO: Handle cases `beta = 0`, `gamma = 0` + def get_steady_states( + self, stacked: bool = True, with_keys: bool = False + ) -> Union[Dict[str, ndarray], Tuple[ndarray], ndarray]: + """Return steady state of system. + + Arguments + --------- + stacked + Whether to stack states or return them individually.
Defaults to `True`. + with_keys + Whether to return solution labelled by variables in form of a dictionary. + Defaults to `False`. + + Returns + ------- + Union[Dict[str, ndarray], Tuple[ndarray], ndarray] + Steady state of system. + """ + + if (self.beta > 0) and (self.gamma > 0): + unspliced = self.alpha / self.beta + spliced = self.alpha / self.gamma + + if with_keys: + return {"u": unspliced, "s": spliced} + elif not stacked: + return unspliced, spliced + else: + return np.array([unspliced, spliced]) diff --git a/scvelo/core/_parallelize.py b/scvelo/core/_parallelize.py new file mode 100644 index 00000000..85b7120f --- /dev/null +++ b/scvelo/core/_parallelize.py @@ -0,0 +1,169 @@ +import os +from multiprocessing import Manager +from threading import Thread +from typing import Any, Callable, Optional, Sequence, Union + +from joblib import delayed, Parallel + +import numpy as np +from scipy.sparse import issparse, spmatrix + +from scvelo import logging as logg + +_msg_shown = False + + +def get_n_jobs(n_jobs: Optional[int]) -> int: + if n_jobs is None or (n_jobs < 0 and os.cpu_count() + 1 + n_jobs <= 0): + return 1 + elif n_jobs > os.cpu_count(): + return os.cpu_count() + elif n_jobs < 0: + return os.cpu_count() + 1 + n_jobs + else: + return n_jobs + + +def parallelize( + callback: Callable[[Any], Any], + collection: Union[spmatrix, Sequence[Any]], + n_jobs: Optional[int] = None, + n_split: Optional[int] = None, + unit: str = "", + as_array: bool = True, + use_ixs: bool = False, + backend: str = "loky", + extractor: Optional[Callable[[Any], Any]] = None, + show_progress_bar: bool = True, +) -> Union[np.ndarray, Any]: + """ + Parallelize function call over a collection of elements. + + Parameters + ---------- + callback + Function to parallelize. + collection + Sequence of items to chunkify. + n_jobs + Number of parallel jobs. + n_split + Split :paramref:`collection` into :paramref:`n_split` chunks. + If `None`, split into :paramref:`n_jobs` chunks. + unit + Unit of the progress bar. + as_array + Whether to convert the results to :class:`numpy.ndarray`. + use_ixs + Whether to pass indices to the callback. + backend + Which backend to use for multiprocessing. See :class:`joblib.Parallel` for valid + options. + extractor + Function to apply to the result after all jobs have finished. + show_progress_bar + Whether to show a progress bar. + + Returns + ------- + :class:`numpy.ndarray` + Result depending on :paramref:`extractor` and :paramref:`as_array`. + """ + + if show_progress_bar: + try: + try: + from tqdm.notebook import tqdm + except ImportError: + from tqdm import tqdm_notebook as tqdm + import ipywidgets # noqa + except ImportError: + global _msg_shown + tqdm = None + + if not _msg_shown: + logg.warn( + "Unable to create progress bar. " + "Consider installing `tqdm` as `pip install tqdm` " + "and `ipywidgets` as `pip install ipywidgets`,\n" + "or disable the progress bar using `show_progress_bar=False`."
+                )
+                _msg_shown = True
+    else:
+        tqdm = None
+
+    def update(pbar, queue, n_total):
+        n_finished = 0
+        while n_finished < n_total:
+            try:
+                res = queue.get()
+            except EOFError as e:
+                if n_finished != n_total:
+                    raise RuntimeError(
+                        f"Finished only `{n_finished}` out of `{n_total}` tasks."
+                    ) from e
+                break
+            assert res in (None, (1, None), 1)  # (1, None) means only 1 job
+            if res == (1, None):
+                n_finished += 1
+                if pbar is not None:
+                    pbar.update()
+            elif res is None:
+                n_finished += 1
+            elif pbar is not None:
+                pbar.update()
+
+        if pbar is not None:
+            pbar.close()
+
+    def wrapper(*args, **kwargs):
+        if pass_queue and show_progress_bar:
+            pbar = None if tqdm is None else tqdm(total=col_len, unit=unit)
+            queue = Manager().Queue()
+            thread = Thread(target=update, args=(pbar, queue, len(collections)))
+            thread.start()
+        else:
+            pbar, queue, thread = None, None, None
+
+        res = Parallel(n_jobs=n_jobs, backend=backend)(
+            delayed(callback)(
+                *((i, cs) if use_ixs else (cs,)),
+                *args,
+                **kwargs,
+                queue=queue,
+            )
+            for i, cs in enumerate(collections)
+        )
+
+        res = np.array(res) if as_array else res
+        if thread is not None:
+            thread.join()
+
+        return res if extractor is None else extractor(res)
+
+    col_len = collection.shape[0] if issparse(collection) else len(collection)
+
+    if n_split is None:
+        n_split = get_n_jobs(n_jobs=n_jobs)
+
+    if issparse(collection):
+        if n_split == collection.shape[0]:
+            collections = [collection[[ix], :] for ix in range(collection.shape[0])]
+        else:
+            step = collection.shape[0] // n_split
+
+            ixs = [
+                np.arange(i * step, min((i + 1) * step, collection.shape[0]))
+                for i in range(n_split)
+            ]
+            ixs[-1] = np.append(
+                ixs[-1], np.arange(ixs[-1][-1] + 1, collection.shape[0])
+            )
+
+            collections = [collection[ix, :] for ix in filter(len, ixs)]
+    else:
+        collections = list(filter(len, np.array_split(collection, n_split)))
+
+    pass_queue = not hasattr(callback, "py_func")  # we'd be inside a numba function
+
+    return wrapper
diff --git a/scvelo/core/tests/__init__.py b/scvelo/core/tests/__init__.py
new file mode 100644
index 00000000..6d4c3dfd
--- /dev/null
+++ b/scvelo/core/tests/__init__.py
@@ -0,0 +1,3 @@
+from .test_base import get_adata, TestBase
+
+__all__ = ["get_adata", "TestBase"]
diff --git a/scvelo/core/tests/test_anndata.py b/scvelo/core/tests/test_anndata.py
new file mode 100644
index 00000000..b73cc941
--- /dev/null
+++ b/scvelo/core/tests/test_anndata.py
@@ -0,0 +1,136 @@
+import hypothesis.strategies as st
+from hypothesis import given
+
+import numpy as np
+from numpy.testing import assert_array_equal
+from scipy.sparse import issparse
+
+from anndata import AnnData
+
+from scvelo.core import get_modality, make_dense, make_sparse, set_modality
+from .test_base import get_adata, TestBase
+
+
+class TestGetModality(TestBase):
+    @given(adata=get_adata())
+    def test_get_modality(self, adata: AnnData):
+        modality_to_get = self._subset_modalities(adata, 1)[0]
+        modality_retrieved = get_modality(adata=adata, modality=modality_to_get)
+
+        if modality_to_get == "X":
+            assert_array_equal(adata.X, modality_retrieved)
+        elif modality_to_get in adata.layers:
+            assert_array_equal(adata.layers[modality_to_get], modality_retrieved)
+        else:
+            assert_array_equal(adata.obsm[modality_to_get], modality_retrieved)
+
+
+class TestMakeDense(TestBase):
+    @given(
+        adata=get_adata(sparse_entries=True),
+        inplace=st.booleans(),
+        n_modalities=st.integers(min_value=0),
+    )
+    def test_make_dense(self, adata: AnnData, inplace: bool, n_modalities: int):
+        modalities_to_densify
= self._subset_modalities(adata, n_modalities) + + returned_adata = make_dense( + adata=adata, modalities=modalities_to_densify, inplace=inplace + ) + + if inplace: + assert returned_adata is None + assert np.all( + [ + not issparse(get_modality(adata=adata, modality=modality)) + for modality in modalities_to_densify + ] + ) + else: + assert isinstance(returned_adata, AnnData) + assert np.all( + [ + not issparse(get_modality(adata=returned_adata, modality=modality)) + for modality in modalities_to_densify + ] + ) + assert np.all( + [ + issparse(get_modality(adata=adata, modality=modality)) + for modality in modalities_to_densify + ] + ) + + +class TestMakeSparse(TestBase): + @given( + adata=get_adata(), + inplace=st.booleans(), + n_modalities=st.integers(min_value=0), + ) + def test_make_sparse(self, adata: AnnData, inplace: bool, n_modalities: int): + modalities_to_make_sparse = self._subset_modalities(adata, n_modalities) + + returned_adata = make_sparse( + adata=adata, modalities=modalities_to_make_sparse, inplace=inplace + ) + + if inplace: + assert returned_adata is None + assert np.all( + [ + issparse(get_modality(adata=adata, modality=modality)) + for modality in modalities_to_make_sparse + if modality != "X" + ] + ) + else: + assert isinstance(returned_adata, AnnData) + assert np.all( + [ + issparse(get_modality(adata=returned_adata, modality=modality)) + for modality in modalities_to_make_sparse + if modality != "X" + ] + ) + assert np.all( + [ + not issparse(get_modality(adata=adata, modality=modality)) + for modality in modalities_to_make_sparse + if modality != "X" + ] + ) + + +class TestSetModality(TestBase): + @given(adata=get_adata(), inplace=st.booleans()) + def test_set_modality(self, adata: AnnData, inplace: bool): + modality_to_set = self._subset_modalities(adata, 1)[0] + + if (modality_to_set == "X") or (modality_to_set in adata.layers): + new_value = np.random.randn(adata.n_obs, adata.n_vars) + else: + new_value = np.random.randn( + adata.n_obs, np.random.randint(low=1, high=10000) + ) + + returned_adata = set_modality( + adata=adata, new_value=new_value, modality=modality_to_set, inplace=inplace + ) + + if inplace: + assert returned_adata is None + if modality_to_set == "X": + assert_array_equal(adata.X, new_value) + elif modality_to_set in adata.layers: + assert_array_equal(adata.layers[modality_to_set], new_value) + else: + assert_array_equal(adata.obsm[modality_to_set], new_value) + else: + assert isinstance(returned_adata, AnnData) + if modality_to_set == "X": + assert_array_equal(returned_adata.X, new_value) + elif modality_to_set in adata.layers: + assert_array_equal(returned_adata.layers[modality_to_set], new_value) + else: + assert_array_equal(returned_adata.obsm[modality_to_set], new_value) diff --git a/scvelo/core/tests/test_arithmetic.py b/scvelo/core/tests/test_arithmetic.py new file mode 100644 index 00000000..d8f12092 --- /dev/null +++ b/scvelo/core/tests/test_arithmetic.py @@ -0,0 +1,197 @@ +from typing import List + +from hypothesis import given +from hypothesis import strategies as st +from hypothesis.extra.numpy import arrays + +import numpy as np +from numpy import ndarray +from numpy.testing import assert_almost_equal, assert_array_equal + +from scvelo.core import clipped_log, invert, prod_sum, sum + + +class TestClippedLog: + @given( + a=arrays( + float, + shape=st.integers(min_value=1, max_value=100), + elements=st.floats( + min_value=-1e3, max_value=1e3, allow_infinity=False, allow_nan=False + ), + ), + bounds=st.lists( + st.floats( + 
min_value=0, max_value=100, allow_infinity=False, allow_nan=False + ), + min_size=2, + max_size=2, + unique=True, + ), + eps=st.floats( + min_value=1e-6, max_value=1, allow_infinity=False, allow_nan=False + ), + ) + def test_flat_arrays(self, a: ndarray, bounds: List[float], eps: float): + lb = min(bounds) + ub = max(bounds) + 2 * eps + + a_logged = clipped_log(a, lb=lb, ub=ub, eps=eps) + + assert a_logged.shape == a.shape + if (a <= lb).any(): + assert_almost_equal(np.abs(a_logged - np.log(lb + eps)).min(), 0) + else: + assert (a_logged >= np.log(lb + eps)).all() + if (a >= ub).any(): + assert_almost_equal(np.abs(a_logged - np.log(ub - eps)).min(), 0) + else: + assert (a_logged <= np.log(ub - eps)).all() + + @given( + a=arrays( + float, + shape=st.tuples( + st.integers(min_value=1, max_value=100), + st.integers(min_value=1, max_value=100), + ), + elements=st.floats( + min_value=-1e3, max_value=1e3, allow_infinity=False, allow_nan=False + ), + ), + bounds=st.lists( + st.floats( + min_value=0, max_value=100, allow_infinity=False, allow_nan=False + ), + min_size=2, + max_size=2, + unique=True, + ), + eps=st.floats( + min_value=1e-6, max_value=1, allow_infinity=False, allow_nan=False + ), + ) + def test_2d_arrays(self, a: ndarray, bounds: List[float], eps: float): + lb = min(bounds) + ub = max(bounds) + 2 * eps + + a_logged = clipped_log(a, lb=lb, ub=ub, eps=eps) + + assert a_logged.shape == a.shape + if (a <= lb).any(): + assert_almost_equal(np.abs(a_logged - np.log(lb + eps)).min(), 0) + else: + assert (a_logged >= np.log(lb + eps)).all() + if (a >= ub).any(): + assert_almost_equal(np.abs(a_logged - np.log(ub - eps)).min(), 0) + else: + assert (a_logged <= np.log(ub - eps)).all() + + +class TestInvert: + @given( + a=arrays( + float, + shape=st.integers(min_value=1, max_value=100), + elements=st.floats(max_value=1e3, allow_infinity=False, allow_nan=False), + ) + ) + def test_flat_arrays(self, a: ndarray): + a_inv = invert(a) + + if a[a != 0].size == 0: + assert a_inv[a != 0].size == 0 + else: + assert_array_equal(a_inv[a != 0], 1 / a[a != 0]) + + if 0 in a: + assert np.isnan(a_inv[a == 0]).all() + else: + assert set(a_inv[a == 0]) == set() + + @given( + a=arrays( + float, + shape=st.tuples( + st.integers(min_value=1, max_value=100), + st.integers(min_value=1, max_value=100), + ), + elements=st.floats(max_value=1e3, allow_infinity=False, allow_nan=False), + ) + ) + def test_2d_arrays(self, a: ndarray): + a_inv = invert(a) + + if a[a != 0].size == 0: + assert a_inv[a != 0].size == 0 + else: + assert_array_equal(a_inv[a != 0], 1 / a[a != 0]) + + if 0 in a: + assert np.isnan(a_inv[a == 0]).all() + else: + assert set(a_inv[a == 0]) == set() + + +# TODO: Extend test to generate sparse inputs as well +# TODO: Make test to generate two different arrays a1, a2 +# TODO: Check why tests fail with assert_almost_equal +class TestProdSum: + @given( + a=arrays( + float, + shape=st.integers(min_value=1, max_value=100), + elements=st.floats(max_value=1e3, allow_infinity=False, allow_nan=False), + ), + axis=st.integers(min_value=0, max_value=1), + ) + def test_flat_array(self, a: ndarray, axis: int): + assert np.allclose((a * a).sum(axis=0), prod_sum(a, a, axis=axis)) + + @given( + a=arrays( + float, + shape=st.tuples( + st.integers(min_value=1, max_value=100), + st.integers(min_value=1, max_value=100), + ), + elements=st.floats(max_value=1e3, allow_infinity=False, allow_nan=False), + ), + axis=st.integers(min_value=0, max_value=1), + ) + def test_2d_array(self, a: ndarray, axis: int): + assert np.allclose((a 
* a).sum(axis=axis), prod_sum(a, a, axis=axis)) + + +# TODO: Extend test to generate sparse inputs as well +class TestSum: + @given( + a=arrays( + float, + shape=st.integers(min_value=1, max_value=100), + elements=st.floats(max_value=1e3, allow_infinity=False, allow_nan=False), + ), + ) + def test_flat_arrays(self, a: ndarray): + a_summed = sum(a=a, axis=0) + + assert_array_equal(a_summed, a.sum(axis=0)) + + @given( + a=arrays( + float, + shape=st.tuples( + st.integers(min_value=1, max_value=100), + st.integers(min_value=1, max_value=100), + ), + elements=st.floats(max_value=1e3, allow_infinity=False, allow_nan=False), + ), + axis=st.integers(min_value=0, max_value=1), + ) + def test_2d_arrays(self, a: ndarray, axis: int): + a_summed = sum(a=a, axis=axis) + + if a.ndim == 1: + axis = 0 + + assert_array_equal(a_summed, a.sum(axis=axis)) diff --git a/scvelo/core/tests/test_base.py b/scvelo/core/tests/test_base.py new file mode 100644 index 00000000..13095a52 --- /dev/null +++ b/scvelo/core/tests/test_base.py @@ -0,0 +1,196 @@ +import random +from typing import List, Optional, Union + +import hypothesis.strategies as st +from hypothesis import given +from hypothesis.extra.numpy import arrays + +import numpy as np +from scipy.sparse import csr_matrix, issparse + +from anndata import AnnData + + +# TODO: Add possibility to generate adata object with floats as counts +@st.composite +def get_adata( + draw, + n_obs: Optional[int] = None, + n_vars: Optional[int] = None, + min_obs: Optional[int] = 1, + max_obs: Optional[int] = 100, + min_vars: Optional[int] = 1, + max_vars: Optional[int] = 100, + layer_keys: Optional[Union[List, str]] = None, + min_layers: Optional[int] = 2, + max_layers: int = 2, + obsm_keys: Optional[Union[List, str]] = None, + min_obsm: Optional[int] = 2, + max_obsm: Optional[int] = 2, + sparse_entries: bool = False, +) -> AnnData: + """Generate an AnnData object. + + The largest possible value of a numerical entry is `1e5`. + + Arguments + --------- + n_obs: + Number of observations. If set to `None`, a random integer between `1` and + `max_obs` will be drawn. Defaults to `None`. + n_vars: + Number of variables. If set to `None`, a random integer between `1` and + `max_vars` will be drawn. Defaults to `None`. + min_obs: + Minimum number of observations. If set to `None`, there is no lower limit. + Defaults to `1`. + max_obs: + Maximum number of observations. If set to `None`, there is no upper limit. + Defaults to `100`. + min_vars: + Minimum number of variables. If set to `None`, there is no lower limit. + Defaults to `1`. + max_vars: + Maximum number of variables. If set to `None`, there is no upper limit. + Defaults to `100`. + layer_keys: + Names of layers. If set to `None`, layers will be named at random. Defaults + to `None`. + min_layers: + Minimum number of layers. Is set to the number of provided layer names if + `layer_keys` is not `None`. Defaults to `2`. + max_layers: Maximum number of layers. Is set to the number of provided layer + names if `layer_keys` is not `None`. Defaults to `2`. + obsm_keys: + Names of multi-dimensional observations annotation. If set to `None`, names + will be generated at random. Defaults to `None`. + min_obsm: + Minimum number of multi-dimensional observations annotation. Is set to the + number of keys if `obsm_keys` is not `None`. Defaults to `2`. + max_obsm: + Maximum number of multi-dimensional observations annotation. Is set to the + number of keys if `obsm_keys` is not `None`. Defaults to `2`. 
+ sparse_entries: + Whether or not to make AnnData entries sparse. + + Returns + ------- + AnnData + Generated :class:`~anndata.AnnData` object. + """ + + if n_obs is None: + n_obs = draw(st.integers(min_value=min_obs, max_value=max_obs)) + if n_vars is None: + n_vars = draw(st.integers(min_value=min_vars, max_value=max_vars)) + + if isinstance(layer_keys, str): + layer_keys = [layer_keys] + if isinstance(obsm_keys, str): + obsm_keys = [obsm_keys] + + if layer_keys is not None: + min_layers = len(layer_keys) + max_layers = len(layer_keys) + if obsm_keys is not None: + min_obsm = len(obsm_keys) + max_obsm = len(obsm_keys) + + X = draw( + arrays( + dtype=int, + elements=st.integers(min_value=0, max_value=1e2), + shape=(n_obs, n_vars), + ) + ) + + layers = draw( + st.dictionaries( + st.text(min_size=1) if layer_keys is None else st.sampled_from(layer_keys), + arrays( + dtype=int, + elements=st.integers(min_value=0, max_value=1e2), + shape=(n_obs, n_vars), + ), + min_size=min_layers, + max_size=max_layers, + ) + ) + + obsm = draw( + st.dictionaries( + st.text(min_size=1) if obsm_keys is None else st.sampled_from(obsm_keys), + arrays( + dtype=int, + elements=st.integers(min_value=0, max_value=1e2), + shape=st.tuples( + st.integers(min_value=n_obs, max_value=n_obs), + st.integers(min_value=min_vars, max_value=max_vars), + ), + ), + min_size=min_obsm, + max_size=max_obsm, + ) + ) + + # Make keys for layers and obsm unique + for key in set(layers.keys()).intersection(obsm.keys()): + layers[f"{key}_"] = layers.pop(key) + + if sparse_entries: + layers = {key: csr_matrix(val) for key, val in layers.items()} + obsm = {key: csr_matrix(val) for key, val in obsm.items()} + return AnnData(X=csr_matrix(X), layers=layers, obsm=obsm) + else: + return AnnData(X=X, layers=layers, obsm=obsm) + + +class TestAdataGeneration: + @given(adata=get_adata()) + def test_default_adata_generation(self, adata: AnnData): + assert type(adata) is AnnData + + @given(adata=get_adata(sparse_entries=True)) + def test_sparse_adata_generation(self, adata: AnnData): + assert type(adata) is AnnData + assert issparse(adata.X) + assert np.all([issparse(adata.layers[layer]) for layer in adata.layers]) + assert np.all([issparse(adata.obsm[name]) for name in adata.obsm]) + + @given( + adata=get_adata( + n_obs=2, n_vars=2, layer_keys=["unspliced", "spliced"], obsm_keys="X_umap" + ) + ) + def test_custom_adata_generation(self, adata: AnnData): + assert adata.X.shape == (2, 2) + assert len(adata.layers) == 2 + assert len(adata.obsm) == 1 + assert set(adata.layers.keys()) == {"unspliced", "spliced"} + assert set(adata.obsm.keys()) == {"X_umap"} + + +class TestBase: + def _subset_modalities( + self, + adata: AnnData, + n_modalities: int, + from_layers: bool = True, + from_obsm: bool = True, + ): + """Subset modalities of an AnnData object.""" + + modalities = ["X"] + if from_layers: + modalities += list(adata.layers.keys()) + if from_obsm: + modalities += list(adata.obsm.keys()) + return random.sample(modalities, min(len(modalities), n_modalities)) + + def _convert_to_float(self, adata: AnnData): + """Convert AnnData entries in `layer` and `obsm` into floats.""" + + for layer in adata.layers: + adata.layers[layer] = adata.layers[layer].astype(float) + for obs in adata.obsm: + adata.obsm[obs] = adata.obsm[obs].astype(float) diff --git a/scvelo/core/tests/test_linear_models.py b/scvelo/core/tests/test_linear_models.py new file mode 100644 index 00000000..49dcca65 --- /dev/null +++ b/scvelo/core/tests/test_linear_models.py @@ -0,0 +1,86 @@ 
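+# A minimal reference for the fits exercised below, assuming (as the tests
+# suggest, via `lr.intercept_ == 0` for the default constructor) that
+# `LinearRegression` defaults to ordinary least squares without an intercept.
+# The closed-form, per-column solution is
+#
+#     coef = (x * y).sum(axis=0) / (x * x).sum(axis=0)
+#
+# so fitting `y = x * coef` recovers `coef` exactly whenever a column of `x`
+# is not identically zero, which is what the perfect-fit tests assert.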
+import pytest
+from hypothesis import given
+from hypothesis import strategies as st
+from hypothesis.extra.numpy import arrays
+
+import numpy as np
+from numpy import ndarray
+from numpy.testing import assert_almost_equal, assert_array_equal
+
+from scvelo.core import LinearRegression
+
+
+class TestLinearRegression:
+    @given(
+        x=arrays(
+            float,
+            shape=st.integers(min_value=1, max_value=100),
+            elements=st.floats(
+                min_value=-1e3, max_value=1e3, allow_infinity=False, allow_nan=False
+            ),
+        ),
+        coef=st.floats(
+            min_value=-1000, max_value=1000, allow_infinity=False, allow_nan=False
+        ),
+    )
+    def test_perfect_fit(self, x: ndarray, coef: float):
+        lr = LinearRegression()
+        lr.fit(x, x * coef)
+
+        assert lr.intercept_ == 0
+        if set(x) != {0}:  # fit is only unique if x is non-trivial
+            assert_almost_equal(lr.coef_, coef)
+
+    @given(
+        x=arrays(
+            float,
+            shape=st.tuples(
+                st.integers(min_value=1, max_value=100),
+                st.integers(min_value=1, max_value=100),
+            ),
+            elements=st.floats(
+                min_value=-1e3, max_value=1e3, allow_infinity=False, allow_nan=False
+            ),
+        ),
+        coef=arrays(
+            float,
+            shape=100,
+            elements=st.floats(
+                min_value=-1000, max_value=1000, allow_infinity=False, allow_nan=False
+            ),
+        ),
+    )
+    # TODO: Extend test to use `percentile`. Zero columns (after trimming) make the
+    # previous implementation of the unit test fail
+    # TODO: Check why test fails if number of columns is increased to e.g. 1000 (500)
+    def test_perfect_fit_2d(self, x: ndarray, coef: ndarray):
+        coef = coef[: x.shape[1]]
+        lr = LinearRegression()
+        lr.fit(x, x * coef)
+
+        assert lr.coef_.shape == (x.shape[1],)
+        assert lr.intercept_.shape == (x.shape[1],)
+        assert_array_equal(lr.intercept_, np.zeros(x.shape[1]))
+        if set(x.flatten()) != {0}:  # fit is only unique if x is non-trivial
+            assert_almost_equal(lr.coef_, coef)
+
+    # TODO: Use hypothesis
+    # TODO: Integrate into `test_perfect_fit_2d`
+    @pytest.mark.parametrize(
+        "x, coef, intercept",
+        [
+            (np.array([[0], [1], [2], [3]]), 0, 1),
+            (np.array([[0], [1], [2], [3]]), 2, 1),
+            (np.array([[0], [1], [2], [3]]), 2, -1),
+        ],
+    )
+    def test_perfect_fit_with_intercept(
+        self, x: ndarray, coef: float, intercept: float
+    ):
+        lr = LinearRegression(fit_intercept=True, positive_intercept=False)
+        lr.fit(x, x * coef + intercept)
+
+        assert lr.coef_.shape == (x.shape[1],)
+        assert lr.intercept_.shape == (x.shape[1],)
+        assert_array_equal(lr.intercept_, intercept)
+        assert_array_equal(lr.coef_, coef)
diff --git a/scvelo/core/tests/test_metrics.py b/scvelo/core/tests/test_metrics.py
new file mode 100644
index 00000000..81f54070
--- /dev/null
+++ b/scvelo/core/tests/test_metrics.py
@@ -0,0 +1,24 @@
+from hypothesis import given
+from hypothesis import strategies as st
+from hypothesis.extra.numpy import arrays
+
+import numpy as np
+from numpy import ndarray
+
+from scvelo.core import l2_norm
+
+
+# TODO: Extend test to generate sparse inputs as well
+@given(
+    a=arrays(
+        float,
+        shape=st.integers(min_value=1, max_value=100),
+        elements=st.floats(max_value=1e3, allow_infinity=False, allow_nan=False),
+    ),
+    axis=st.integers(min_value=0, max_value=1),
+)
+def test_l2_norm(a: ndarray, axis: int):
+    if a.ndim == 1:
+        assert np.allclose(np.linalg.norm(a), l2_norm(a, axis=axis))
+    else:
+        assert np.allclose(np.linalg.norm(a, axis=axis), l2_norm(a, axis=axis))
diff --git a/scvelo/core/tests/test_models.py b/scvelo/core/tests/test_models.py
new file mode 100644
index 00000000..bca814fb
--- /dev/null
+++ b/scvelo/core/tests/test_models.py
@@ -0,0 +1,90 @@
+from typing import List
+
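+# For reference, the ODE system tested below and the closed form implemented in
+# `SplicingDynamics.get_solution` (valid for `beta != gamma`):
+#
+#     du/dt = alpha - beta * u
+#     ds/dt = beta * u - gamma * s
+#
+#     u(t) = u0 * exp(-beta * t) + alpha / beta * (1 - exp(-beta * t))
+#     s(t) = s0 * exp(-gamma * t) + alpha / gamma * (1 - exp(-gamma * t))
+#            + (alpha - beta * u0) / (gamma - beta) * (exp(-gamma * t) - exp(-beta * t))
+#
+# `test_solution` cross-checks this closed form against a numerical integration
+# with `scipy.integrate.odeint`.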
+import pytest +from hypothesis import given +from hypothesis import strategies as st +from hypothesis.extra.numpy import arrays + +import numpy as np +from numpy import ndarray +from scipy.integrate import odeint + +from scvelo.core import SplicingDynamics + + +class TestSplicingDynamics: + @given( + alpha=st.floats(min_value=0, allow_infinity=False), + beta=st.floats(min_value=0, max_value=1, exclude_min=True), + gamma=st.floats(min_value=0, max_value=1, exclude_min=True), + initial_state=st.lists( + st.floats(min_value=0, allow_infinity=False), min_size=2, max_size=2 + ), + t=arrays( + float, + shape=st.integers(min_value=1, max_value=100), + elements=st.floats( + min_value=0, max_value=1e3, allow_infinity=False, allow_nan=False + ), + ), + with_keys=st.booleans(), + ) + def test_output_form( + self, + alpha: float, + beta: float, + gamma: float, + initial_state: List[float], + t: ndarray, + with_keys: bool, + ): + if beta == gamma: + gamma = gamma + 1e-6 + + splicing_dynamics = SplicingDynamics( + alpha=alpha, beta=beta, gamma=gamma, initial_state=initial_state + ) + solution = splicing_dynamics.get_solution(t=t, with_keys=with_keys) + + if not with_keys: + assert type(solution) == ndarray + assert solution.shape == (len(t), 2) + else: + assert len(solution) == 2 + assert type(solution) == dict + assert list(solution.keys()) == ["u", "s"] + assert all([len(var) == len(t) for var in solution.values()]) + + # TODO: Check how / if hypothesis can be used instead. + @pytest.mark.parametrize( + "alpha, beta, gamma, initial_state", + [ + (5, 0.5, 0.4, [0, 1]), + ], + ) + def test_solution(self, alpha, beta, gamma, initial_state): + def model(y, t, alpha, beta, gamma): + dydt = np.zeros(2) + dydt[0] = alpha - beta * y[0] + dydt[1] = beta * y[0] - gamma * y[1] + + return dydt + + t = np.linspace(0, 20, 10000) + splicing_dynamics = SplicingDynamics( + alpha=alpha, beta=beta, gamma=gamma, initial_state=initial_state + ) + exact_solution = splicing_dynamics.get_solution(t=t) + + numerical_solution = odeint( + model, + np.array(initial_state), + t, + args=( + alpha, + beta, + gamma, + ), + ) + + assert np.allclose(numerical_solution, exact_solution) diff --git a/scvelo/datasets.py b/scvelo/datasets.py index 5dc265fb..484d4bd0 100644 --- a/scvelo/datasets.py +++ b/scvelo/datasets.py @@ -1,11 +1,14 @@ """Builtin Datasets. """ -from .read_load import read, load -from .preprocessing.utils import cleanup -from anndata import AnnData + import numpy as np import pandas as pd +from anndata import AnnData +from scanpy import read + +from scvelo.core import cleanup, SplicingDynamics +from .read_load import load url_datadir = "https://github.com/theislab/scvelo_notebooks/raw/master/" @@ -24,8 +27,6 @@ def toy_data(n_obs=None): Returns `adata` object """ - """Random samples from Dentate Gyrus. - """ adata_dg = dentategyrus() if n_obs is not None: @@ -40,7 +41,7 @@ def toy_data(n_obs=None): def dentategyrus(adjusted=True): """Dentate Gyrus neurogenesis. - Data from `Hochgerner et al. (2018) `_. + Data from `Hochgerner et al. (2018) `_. Dentate gyrus (DG) is part of the hippocampus involved in learning, episodic memory formation and spatial coding. The experiment from the developing DG comprises two @@ -58,7 +59,7 @@ def dentategyrus(adjusted=True): Returns ------- Returns `adata` object - """ + """ # noqa E501 if adjusted: filename = "data/DentateGyrus/10X43_1.h5ad" @@ -89,12 +90,16 @@ def dentategyrus(adjusted=True): def forebrain(): """Developing human forebrain. 
- Forebrain tissue of a week 10 embryo, focusing on glutamatergic neuronal lineage. + From `La Manno et al. (2018) `_. + + Forebrain tissue of a human week 10 embryo, focusing on glutamatergic neuronal + lineage, obtained from elective routine abortions (10 weeks post-conception). Returns ------- Returns `adata` object - """ + """ # noqa E501 + filename = "data/ForebrainGlut/hgForebrainGlut.loom" url = "http://pklab.med.harvard.edu/velocyto/hgForebrainGlut/hgForebrainGlut.loom" adata = read(filename, backup_url=url, cleanup=True, sparse=True, cache=True) @@ -103,9 +108,9 @@ def forebrain(): def pancreas(): - """Pancreatic endocrinogenesis + """Pancreatic endocrinogenesis. - Data from `Bastidas-Ponce et al. (2019) `_. + Data from `Bastidas-Ponce et al. (2019) `_. Pancreatic epithelial and Ngn3-Venus fusion (NVF) cells during secondary transition with transcriptome profiles sampled from embryonic day 15.5. @@ -121,7 +126,8 @@ def pancreas(): Returns ------- Returns `adata` object - """ + """ # noqa E501 + filename = "data/Pancreas/endocrinogenesis_day15.h5ad" url = f"{url_datadir}data/Pancreas/endocrinogenesis_day15.h5ad" adata = read(filename, backup_url=url, sparse=True, cache=True) @@ -132,6 +138,173 @@ def pancreas(): pancreatic_endocrinogenesis = pancreas # restore old conventions +def dentategyrus_lamanno(): + """Dentate Gyrus neurogenesis. + + From `La Manno et al. (2018) `_. + + The experiment from the developing mouse hippocampus comprises two time points + (P0 and P5) and reveals the complex manifold with multiple branching lineages + towards astrocytes, oligodendrocyte precursors (OPCs), granule neurons and pyramidal neurons. + + .. image:: https://user-images.githubusercontent.com/31883718/118401264-49bce380-b665-11eb-8678-e7570ede13d6.png + :width: 600px + + Returns + ------- + Returns `adata` object + """ # noqa E501 + filename = "data/DentateGyrus/DentateGyrus.loom" + url = "http://pklab.med.harvard.edu/velocyto/DentateGyrus/DentateGyrus.loom" + adata = read(filename, backup_url=url, sparse=True, cache=True) + adata.var_names_make_unique() + adata.obsm["X_tsne"] = np.column_stack([adata.obs["TSNE1"], adata.obs["TSNE2"]]) + adata.obs["clusters"] = adata.obs["ClusterName"] + cleanup(adata, clean="obs", keep=["Age", "clusters"]) + + adata.uns["clusters_colors"] = { + "RadialGlia": [0.95, 0.6, 0.1], + "RadialGlia2": [0.85, 0.3, 0.1], + "ImmAstro": [0.8, 0.02, 0.1], + "GlialProg": [0.81, 0.43, 0.72352941], + "OPC": [0.61, 0.13, 0.72352941], + "nIPC": [0.9, 0.8, 0.3], + "Nbl1": [0.7, 0.82, 0.6], + "Nbl2": [0.448, 0.85490196, 0.95098039], + "ImmGranule1": [0.35, 0.4, 0.82], + "ImmGranule2": [0.23, 0.3, 0.7], + "Granule": [0.05, 0.11, 0.51], + "CA": [0.2, 0.53, 0.71], + "CA1-Sub": [0.1, 0.45, 0.3], + "CA2-3-4": [0.3, 0.35, 0.5], + } + return adata + + +def gastrulation(): + """Mouse gastrulation. + + Data from `Pijuan-Sala et al. (2019) `_. + + Gastrulation represents a key developmental event during which embryonic pluripotent + cells diversify into lineage-specific precursors that will generate the adult organism. + + This data contains the erythrocyte lineage from Pijuan-Sala et al. (2019). + The experiment reveals the molecular map of mouse gastrulation and early organogenesis. + It comprises transcriptional profiles of 116,312 single cells from mouse embryos + collected at nine sequential time points ranging from 6.5 to 8.5 days post-fertilization. + It served to explore the complex events involved in the convergence of visceral and primitive streak-derived endoderm. + + .. 
image:: https://user-images.githubusercontent.com/31883718/130636066-3bae153e-1626-4d11-8f38-6efab5b81c1c.png + :width: 600px + + Returns + ------- + Returns `adata` object + """ # noqa E501 + filename = "data/Gastrulation/gastrulation.h5ad" + url = "https://ndownloader.figshare.com/files/28095525" + adata = read(filename, backup_url=url, sparse=True, cache=True) + adata.var_names_make_unique() + return adata + + +def gastrulation_e75(): + """Mouse gastrulation subset to E7.5. + + Data from `Pijuan-Sala et al. (2019) `_. + + Gastrulation represents a key developmental event during which embryonic pluripotent + cells diversify into lineage-specific precursors that will generate the adult organism. + + .. image:: https://user-images.githubusercontent.com/31883718/130636292-7f2a599b-ded4-4616-99d7-604d2f324531.png + :width: 600px + + Returns + ------- + Returns `adata` object + """ # noqa E501 + filename = "data/Gastrulation/gastrulation_e75.h5ad" + url = "https://ndownloader.figshare.com/files/30439878" + adata = read(filename, backup_url=url, sparse=True, cache=True) + adata.var_names_make_unique() + return adata + + +def gastrulation_erythroid(): + """Mouse gastrulation subset to erythroid lineage. + + Data from `Pijuan-Sala et al. (2019) `_. + + Gastrulation represents a key developmental event during which embryonic pluripotent + cells diversify into lineage-specific precursors that will generate the adult organism. + + .. image:: https://user-images.githubusercontent.com/31883718/118402002-40814600-b668-11eb-8bfc-dbece2b2b34e.png + :width: 600px + + Returns + ------- + Returns `adata` object + """ # noqa E501 + filename = "data/Gastrulation/erythroid_lineage.h5ad" + url = "https://ndownloader.figshare.com/files/27686871" + adata = read(filename, backup_url=url, sparse=True, cache=True) + adata.var_names_make_unique() + return adata + + +def bonemarrow(): + """Human bone marrow. + + Data from `Setty et al. (2019) `_. + + The bone marrow is the primary site of new blood cell production or haematopoiesis. + It is composed of hematopoietic cells, marrow adipose tissue, and supportive stromal cells. + + This dataset served to detect important landmarks of hematopoietic differentiation, to + identify key transcription factors that drive lineage fate choice and to closely track when cells lose plasticity. + + .. image:: https://user-images.githubusercontent.com/31883718/118402252-68bd7480-b669-11eb-9ef3-5f992b74a2d3.png + :width: 600px + + Returns + ------- + Returns `adata` object + """ # noqa E501 + filename = "data/BoneMarrow/human_cd34_bone_marrow.h5ad" + url = "https://ndownloader.figshare.com/files/27686835" + adata = read(filename, backup_url=url, sparse=True, cache=True) + adata.var_names_make_unique() + return adata + + +def pbmc68k(): + """Peripheral blood mononuclear cells. + + Data from `Zheng et al. (2017) `_. + + This experiment contains 68k peripheral blood mononuclear cells (PBMC) measured using 10X. + + PBMCs are a diverse mixture of highly specialized immune cells. + They originate from hematopoietic stem cells (HSCs) that reside in the bone marrow + and give rise to all blood cells of the immune system (hematopoiesis). + HSCs give rise to myeloid (monocytes, macrophages, granulocytes, megakaryocytes, dendritic cells, erythrocytes) + and lymphoid (T cells, B cells, NK cells) lineages. + + .. 
image:: https://user-images.githubusercontent.com/31883718/118402351-e1243580-b669-11eb-8256-4a49c299da3d.png + :width: 600px + + Returns + ------- + Returns `adata` object + """ # noqa E501 + filename = "data/PBMC/pbmc68k.h5ad" + url = "https://ndownloader.figshare.com/files/27686886" + adata = read(filename, backup_url=url, sparse=True, cache=True) + adata.var_names_make_unique() + return adata + + def simulation( n_obs=300, n_vars=None, @@ -159,20 +332,23 @@ def simulation( Returns ------- Returns `adata` object - """ - from .tools.dynamical_model_utils import vectorize, mRNA + """ # noqa E501 + + from .tools.dynamical_model_utils import vectorize np.random.seed(random_seed) def draw_poisson(n): - from random import uniform, seed # draw from poisson + from random import seed, uniform # draw from poisson seed(random_seed) t = np.cumsum([-0.1 * np.log(uniform(0, 1)) for _ in range(n - 1)]) return np.insert(t, 0, 0) # prepend t0=0 def simulate_dynamics(tau, alpha, beta, gamma, u0, s0, noise_model, noise_level): - ut, st = mRNA(tau, u0, s0, alpha, beta, gamma) + ut, st = SplicingDynamics( + alpha=alpha, beta=beta, gamma=gamma, initial_state=[u0, s0] + ).get_solution(tau, stacked=False) if noise_model == "normal": # add noise ut += np.random.normal( scale=noise_level * np.percentile(ut, 99) / 10, size=len(ut) @@ -237,10 +413,13 @@ def cycle(array, n_vars=None): U = np.zeros(shape=(len(t), n_vars)) S = np.zeros(shape=(len(t), n_vars)) + def is_list(x): + return isinstance(x, (tuple, list, np.ndarray)) + for i in range(n_vars): - alpha_i = alpha[i] if isinstance(alpha, (tuple, list, np.ndarray)) else alpha - beta_i = beta[i] if isinstance(beta, (tuple, list, np.ndarray)) else beta - gamma_i = gamma[i] if isinstance(gamma, (tuple, list, np.ndarray)) else gamma + alpha_i = alpha[i] if is_list(alpha) and len(alpha) != n_obs else alpha + beta_i = beta[i] if is_list(beta) and len(beta) != n_obs else beta + gamma_i = gamma[i] if is_list(gamma) and len(gamma) != n_obs else gamma tau, alpha_vec, u0_vec, s0_vec = vectorize( t, t_[i], alpha_i, beta_i, gamma_i, alpha_=alpha_, u0=0, s0=0 ) @@ -259,6 +438,13 @@ def cycle(array, n_vars=None): noise_level[i], ) + if is_list(alpha) and len(alpha) == n_obs: + alpha = np.nan + if is_list(beta) and len(beta) == n_obs: + beta = np.nan + if is_list(gamma) and len(gamma) == n_obs: + gamma = np.nan + obs = {"true_t": t.round(2)} var = { "true_t_": t_[:n_vars], diff --git a/scvelo/logging.py b/scvelo/logging.py index f140a48b..7e79c1c1 100644 --- a/scvelo/logging.py +++ b/scvelo/logging.py @@ -1,14 +1,15 @@ """Logging and Profiling """ - -from . 
import settings -from sys import stdout from datetime import datetime -from time import time as get_time from platform import python_version +from sys import stdout +from time import time as get_time + +from packaging.version import parse + from anndata.logging import get_memory_usage -from anndata.logging import print_memory_usage +from scvelo import settings _VERBOSITY_LEVELS_FROM_STRINGS = {"error": 0, "warn": 1, "info": 2, "hint": 3} @@ -89,7 +90,7 @@ def msg( if reset: try: settings._previous_memory_usage, _ = get_memory_usage() - except: + except Exception: pass settings._previous_time = get_time() if time: @@ -164,7 +165,7 @@ def __init__(self): def run(self): try: self.result = func(*args, **kwargs) - except: + except Exception: pass it = InterruptableThread() @@ -174,7 +175,7 @@ def run(self): def get_latest_pypi_version(): - from subprocess import check_output, CalledProcessError + from subprocess import CalledProcessError, check_output try: # needs to work offline as well result = check_output(["pip", "search", "scvelo"]) @@ -189,7 +190,7 @@ def check_if_latest_version(): latest_version = timeout( get_latest_pypi_version, timeout_duration=2, default="0.0.0" ) - if __version__.rsplit(".dev")[0] < latest_version.rsplit(".dev")[0]: + if parse(__version__.rsplit(".dev")[0]) < parse(latest_version.rsplit(".dev")[0]): warn( "There is a newer scvelo version available on PyPI:\n", "Your version: \t\t", @@ -297,7 +298,8 @@ def profiler(command, filename="profile.stats", n_stats=10): n_stats: int or None Number of top stats to show. """ - import cProfile, pstats + import cProfile + import pstats cProfile.run(command, filename) stats = pstats.Stats(filename).strip_dirs().sort_stats("time") diff --git a/scvelo/pl.py b/scvelo/pl.py index 3b819fc4..f6c2497a 100644 --- a/scvelo/pl.py +++ b/scvelo/pl.py @@ -1 +1 @@ -from scvelo.plotting import * +from scvelo.plotting import * # noqa diff --git a/scvelo/plotting/__init__.py b/scvelo/plotting/__init__.py index 13eeb61f..8ca74562 100644 --- a/scvelo/plotting/__init__.py +++ b/scvelo/plotting/__init__.py @@ -1,14 +1,40 @@ -from .scatter import scatter, umap, tsne, diffmap, phate, draw_graph, pca +from scanpy.plotting import paga_compare, rank_genes_groups + from .gridspec import gridspec +from .heatmap import heatmap +from .paga import paga +from .proportions import proportions +from .scatter import diffmap, draw_graph, pca, phate, scatter, tsne, umap +from .simulation import simulation +from .summary import summary +from .utils import hist, plot from .velocity import velocity from .velocity_embedding import velocity_embedding from .velocity_embedding_grid import velocity_embedding_grid from .velocity_embedding_stream import velocity_embedding_stream from .velocity_graph import velocity_graph -from .heatmap import heatmap -from .proportions import proportions -from .utils import hist, plot -from .simulation import simulation -from scanpy.plotting import paga_compare, rank_genes_groups -from .summary import summary -from .paga import paga + +__all__ = [ + "diffmap", + "draw_graph", + "gridspec", + "heatmap", + "hist", + "paga", + "paga_compare", + "pca", + "phate", + "plot", + "proportions", + "rank_genes_groups", + "scatter", + "simulation", + "summary", + "tsne", + "umap", + "velocity", + "velocity_embedding", + "velocity_embedding_grid", + "velocity_embedding_stream", + "velocity_graph", +] diff --git a/scvelo/plotting/docs.py b/scvelo/plotting/docs.py index 015e8d92..3cd0e44d 100644 --- a/scvelo/plotting/docs.py +++ b/scvelo/plotting/docs.py @@ 
-17,7 +17,8 @@ def dec(obj): doc_scatter = """\ basis: `str` or list of `str` (default: `None`) - Key for embedding. If not specified, use 'umap', 'tsne' or 'pca' (ordered by preference). + Key for embedding. If not specified, use 'umap', 'tsne' or 'pca' (ordered by + preference). vkey: `str` or list of `str` (default: `None`) Key for velocity / steady-state ratio to be visualized. color: `str`, list of `str` or `None` (default: `None`) @@ -44,10 +45,10 @@ def dec(obj): Specify percentile for continuous coloring. groups: `str` or list of `str` (default: `all groups`) Restrict to a few categories in categorical observation annotation. - Multiple categories can be passed as list with ['cluster_1', 'cluster_3'], + Multiple categories can be passed as list with ['cluster_1', 'cluster_3'], or as string with 'cluster_1, cluster_3'. sort_order: `bool` (default: `True`) - For continuous annotations used as color parameter, + For continuous annotations used as color parameter, plot data points with higher values on top of others. components: `str` or list of `str` (default: '1,2') For instance, ['1,2', '2,3']. @@ -60,11 +61,15 @@ def dec(obj): Legend font size. legend_fontweight: {'normal', 'bold', ...} (default: `None`) Legend font weight. A numeric value in range 0-1000 or a string. - Defaults to 'bold' if `legend_loc = 'on data'`, otherwise to 'normal'. + Defaults to 'bold' if `legend_loc = 'on data'`, otherwise to 'normal'. Available are `['light', 'normal', 'medium', 'semibold', 'bold', 'heavy', 'black']`. -legend_fontoutline +legend_fontoutline: float (default: `None`) Line width of the legend font outline in pt. Draws a white outline using the path effect :class:`~matplotlib.patheffects.withStroke`. +legend_align_text: bool or str (default: `None`) + Aligns the positions of the legend texts. Set the axis along which the best + alignment should be determined. This can be 'y' or True (vertically), + 'x' (horizontally), or 'xy' (best alignment in both directions). right_margin: `float` or list of `float` (default: `None`) Adjust the width of the space right of each plotting panel. left_margin: `float` or list of `float` (default: `None`) @@ -84,42 +89,46 @@ def dec(obj): ylim: tuple, e.g. [0,1] or `None` (default: `None`) Restrict y-limits of the axis. add_density: `bool` or `str` or `None` (default: `None`) - Whether to show density of values along x and y axes. + Whether to show density of values along x and y axes. Color of the density plot can also be passed as `str`. add_assignments: `bool` or `str` or `None` (default: `None`) - Whether to add assignments to the model curve. + Whether to add assignments to the model curve. Color of the assignments can also be passed as `str`. add_linfit: `bool` or `str` or `None` (default: `None`) - Whether to add linear regression fit to the data points. + Whether to add linear regression fit to the data points. Color of the line can also be passed as `str`. - Fitting with or without an intercept by passing `'intercept'` or `'no_intercept'`. + Fitting with or without an intercept by passing `'intercept'` or `'no_intercept'`. A colored regression line with intercept is obtained with `'intercept, blue'`. add_polyfit: `bool` or `str` or `int` or `None` (default: `None`) - Whether to add polynomial fit to the data points. Color of the polyfit plot can also - be passed as `str`. The degree of the polynomial fit can be passed as `int` - (default is 2 for quadratic fit). - Fitting with or without an intercept by passing `'intercept'` or `'no_intercept'`. 
+    Whether to add polynomial fit to the data points. Color of the polyfit plot can also
+    be passed as `str`. The degree of the polynomial fit can be passed as `int`
+    (default is 2 for quadratic fit).
+    Fitting with or without an intercept by passing `'intercept'` or `'no_intercept'`.
     A colored regression line with intercept is obtained with `'intercept, blue'`.
 add_rug: `str` or `None` (default: `None`)
-    If categorical observation annotation (e.g. 'clusters') is given, a rugplot is
+    If categorical observation annotation (e.g. 'clusters') is given, a rugplot is
     attached to the x-axis showing the data membership to each of the categories.
 add_text: `str` (default: `None`)
     Text to be added to the plot, passed as `str`.
-add_text_pos: `tuple`, e.g. [0.05, 0.95] (defaut: `[0.05, 0.95]`)
+add_text_pos: `tuple`, e.g. [0.05, 0.95] (default: `[0.05, 0.95]`)
     Text position. Default is `[0.05, 0.95]`, positioning the text at top left corner.
+add_margin: `float` (default: `None`)
+    A value between [-1, 1] to add (positive) and reduce (negative) figure margins.
 add_outline: `bool` or `str` (default: `False`)
-    Whether to show an outline around scatter plot dots.
+    Whether to show an outline around scatter plot dots.
     Alternatively a string of cluster names can be passed, e.g. 'cluster_1, clusters_3'.
 outline_width: tuple type `scalar` or `None` (default: `(0.3, 0.05)`)
     Width of the inner and outer outline
 outline_color: tuple of type `str` or `None` (default: `('black', 'white')`)
     Inner and outer matplotlib color of the outline
 n_convolve: `int` or `None` (default: `None`)
-    If `int` is given, data is smoothed by convolution
+    If `int` is given, data is smoothed by convolution
     along the x-axis with kernel size `n_convolve`.
 smooth: `bool` or `int` (default: `None`)
-    Whether to convolve/average the color values over the nearest neighbors.
+    Whether to convolve/average the color values over the nearest neighbors.
     If `int`, it specifies number of neighbors.
+normalize_data: `bool` (default: `None`)
+    Whether to rescale values for x, y to [0,1].
 rescale_color: `tuple` (default: `None`)
     Boundaries for color rescaling, e.g. [0, 1], setting min/max values of the colorbar.
 color_gradients: `str` or `np.ndarray` (default: `None`)
@@ -139,7 +148,7 @@ def dec(obj):
 show: `bool`, optional (default: `None`)
     Show the plot, do not return axis.
 save: `bool` or `str`, optional (default: `None`)
-    If `True` or a `str`, save the figure. A string is appended to the default filename.
+    If `True` or a `str`, save the figure. A string is appended to the default filename.
     Infer the filetype if ending on {'.pdf', '.png', '.svg'}.
 ax: `matplotlib.Axes`, optional (default: `None`)
     A matplotlib axes object. Only works if plotting a single component.\
diff --git a/scvelo/plotting/gridspec.py b/scvelo/plotting/gridspec.py
index ee440dfd..a0285c61 100644
--- a/scvelo/plotting/gridspec.py
+++ b/scvelo/plotting/gridspec.py
@@ -1,13 +1,14 @@
+from functools import partial
+
+import matplotlib.pyplot as pl
+
 # todo: auto-complete and docs wrapper
 from .scatter import scatter
+from .utils import get_figure_params, hist
 from .velocity_embedding import velocity_embedding
 from .velocity_embedding_grid import velocity_embedding_grid
 from .velocity_embedding_stream import velocity_embedding_stream
 from .velocity_graph import velocity_graph
-from .utils import hist, get_figure_params
-
-import matplotlib.pyplot as pl
-from functools import partial
 
 
 def _wraps_plot(wrapper, func):
diff --git a/scvelo/plotting/heatmap.py b/scvelo/plotting/heatmap.py
index c817de19..aeca9968 100644
--- a/scvelo/plotting/heatmap.py
+++ b/scvelo/plotting/heatmap.py
@@ -1,14 +1,16 @@
 import numpy as np
-import matplotlib.pyplot as pl
-from matplotlib import rcParams
-from matplotlib.colors import ColorConverter
 import pandas as pd
-from pandas import unique, isnull
 from scipy.sparse import issparse
-from .. import logging as logg
-from .utils import is_categorical, interpret_colorkey, savefig_or_show, to_list
-from .utils import set_colors_for_categorical_obs, strings_to_categoricals
+from scvelo import logging as logg
+from .utils import (
+    interpret_colorkey,
+    is_categorical,
+    savefig_or_show,
+    set_colors_for_categorical_obs,
+    strings_to_categoricals,
+    to_list,
+)
 
 
 def heatmap(
@@ -25,8 +27,9 @@ def heatmap(
     colorbar=None,
     col_cluster=False,
     row_cluster=False,
-    figsize=(8, 4),
+    context=None,
     font_scale=None,
+    figsize=(8, 4),
     show=None,
     save=None,
     **kwargs,
 ):
@@ -48,6 +51,8 @@ def heatmap(
         String denoting matplotlib color map.
     col_color: `str` or list of `str` (default: `None`)
         String denoting matplotlib color map to use along the columns.
+    palette: list of `str` (default: `'viridis'`)
+        Colors to use for plotting groups (categorical annotation).
     n_convolve: `int` or `None` (default: `30`)
         If `int` is given, data is smoothed by convolution
         along the x-axis with kernel size n_convolve.
@@ -58,9 +63,13 @@ def heatmap(
         Whether to sort the expression values given by xkey.
     colorbar: `bool` or `None` (default: `None`)
         Whether to show colorbar.
-    {row,col}_cluster : bool, optional
+    {row,col}_cluster : `bool` or `None`
        If True, cluster the {rows, columns}.
-    figsize: tuple (default: `(7,5)`)
+    context : `None`, or one of {paper, notebook, talk, poster}
+        A dictionary of parameters or the name of a preconfigured set.
+    font_scale : float, optional
+        Scaling factor to scale the size of the font elements.
+    figsize: tuple (default: `(8,4)`)
         Figure size.
     show: `bool`, optional (default: `None`)
         Show the plot, do not return axis.
@@ -73,7 +82,7 @@ def heatmap(
 
     Returns
     -------
-    If `show==False` a `matplotlib.Axis`
+    If `show==False` a `matplotlib.Axis`
     """
 
     import seaborn as sns
@@ -98,7 +107,7 @@ def heatmap(
     for gene in var_names:
         try:
             df[gene] = np.convolve(df[gene].values, weights, mode="same")
-        except:
+        except Exception:
             pass  # e.g.
all-zero counts or nans cannot be convolved if sort: @@ -118,8 +127,6 @@ def heatmap( set_colors_for_categorical_obs(adata, col, palette) col_color.append(interpret_colorkey(adata, col)[np.argsort(time)]) - if font_scale is not None: - sns.set(font_scale=font_scale) if "dendrogram_ratio" not in kwargs: kwargs["dendrogram_ratio"] = ( 0.1 if row_cluster else 0, @@ -139,454 +146,21 @@ def heatmap( figsize=figsize, ) ) - try: - cm = sns.clustermap(df.T, **kwargs) - except: - logg.warn("Please upgrade seaborn with `pip install -U seaborn`.") - kwargs.pop("dendrogram_ratio") - kwargs.pop("cbar_pos") - cm = sns.clustermap(df.T, **kwargs) + + args = {} + if font_scale is not None: + args = {"font_scale": font_scale} + context = context or "notebook" + + with sns.plotting_context(context=context, **args): + try: + cm = sns.clustermap(df.T, **kwargs) + except Exception: + logg.warn("Please upgrade seaborn with `pip install -U seaborn`.") + kwargs.pop("dendrogram_ratio") + kwargs.pop("cbar_pos") + cm = sns.clustermap(df.T, **kwargs) savefig_or_show("heatmap", save=save, show=show) if show is False: return cm - - -def heatmap_deprecated( - adata, - var_names, - groups=None, - groupby=None, - annotations=None, - use_raw=False, - layers=None, - color_map=None, - color_map_anno=None, - colorbar=True, - row_width=None, - xlabel=None, - title=None, - figsize=None, - dpi=None, - show=True, - save=None, - ax=None, - **kwargs, -): - - """\ - Plot pseudotimeseries for genes as heatmap. - - Arguments - --------- - adata: :class:`~anndata.AnnData` - Annotated data matrix. - var_names: `str`, list of `str` - Names of variables to use for the plot. - groups: `str`, list of `str` or `None` (default: `None`) - Groups selected to plot. Must be an element of adata.obs[groupby]. - groupby: `str` or `None` (default: `None`) - Key in adata.obs. Indicates how to group the plot. - annotations: `str`, list of `str` or `None` (default: `None`) - Key in adata.obs. Annotations are plotted in the last row. - use_raw: `bool` (default: `False`) - If true, moments are used instead of raw data. - layers: `str`, list of `str` or `None` (default: `['X']`) - Selected layers. - color_map: `str`, list of `str` or `None` (default: `None`) - String denoting matplotlib color map for the heat map. - There must be one list entry for each layer. - color_map_anno: `str`, list of `str` or `None` (default: `None`) - String denoting matplotlib color map for the annotations. - There must be one list entry for each annotation. - colorbar: `bool` (default: `True`) - If True, a colormap for each layer is added on the right bottom corner. - row_width: `float` (default: `None`) - Constant width of all rows. - xlabel: - Label for the x-axis. - title: `str` or `None` (default: `None`) - Main plot title. - figsize: tuple (default: `(7,5)`) - Figure size. - dpi: `int` (default: 80) - Figure dpi. - show: `bool`, optional (default: `None`) - Show the plot, do not return axis. - save: `bool` or `str`, optional (default: `None`) - If `True` or a `str`, save the figure. - Infer the filetype if ending on {'.pdf', '.png', '.svg'}. - ax: `matplotlib.Axes`, optional (default: `None`) - A matplotlib axes object. Only works if plotting a single component. 
- - Returns - ------- - If `show==False` a `matplotlib.Axis` - """ - - # catch - if "velocity_pseudotime" not in adata.obs.keys(): - raise ValueError( - "A function requires computation of the pseudotime" - "for ordering at single-cell resolution" - ) - if layers is None: - layers = ["X"] - if annotations is None: - annotations = [] - if isinstance(var_names, str): - var_names = [var_names] - if len(var_names) == 0: - var_names = np.arange(adata.X.shape[1]) - if var_names.ndim == 2: - var_names = var_names[:, 0] - var_names = [name for name in var_names if name in adata.var_names] - if len(var_names) == 0: - raise ValueError( - "The specified var_names are all not" "contained in the adata.var_names." - ) - - if layers is None: - layers = ["X"] - if isinstance(layers, str): - layers = [layers] - layers = [layer for layer in layers if layer in adata.layers.keys() or layer == "X"] - if len(layers) == 0: - raise ValueError( - "The selected layers are not contained" "in adata.layers.keys()." - ) - if not use_raw: - layers = np.array(layers) - if "X" in layers: - layers[np.array([layer == "X" for layer in layers])] = "Ms" - if "spliced" in layers: - layers[np.array([layer == "spliced" for layer in layers])] = "Ms" - if "unspliced" in layers: - layers[np.array([layer == "unspliced" for layer in layers])] = "Ms" - layers = list(layers) - if "Ms" in layers and "Ms" not in adata.layers.keys(): - raise ValueError( - "Moments have to be computed before" "using this plot function." - ) - if "Mu" in layers and "Mu" not in adata.layers.keys(): - raise ValueError( - "Moments have to be computed before" "using this plot function." - ) - layers = unique(layers) - - # Number of rows to plot - tot_len = len(var_names) * len(layers) + len(annotations) - - # init main figure - figsize = rcParams["figure.figsize"] if figsize is None else figsize - if row_width is not None: - figsize[1] = row_width * tot_len - ax = pl.figure(figsize=figsize, dpi=dpi).gca() if ax is None else ax - ax.set_yticks([]) - ax.set_xticks([]) - - # groups bar - ax_bounds = ax.get_position().bounds - if groupby is not None: - # catch - if groupby not in adata.obs_keys(): - raise ValueError( - "The selected groupby is not contained" "in adata.obs_keys()." 
- ) - if groups is None: # Then use everything of that obs - groups = unique(adata.obs.clusters.values) - - imlist = [] - - for igroup, group in enumerate(groups): - for ivar, var in enumerate(var_names): - for ilayer, layer in enumerate(layers): - groups_axis = pl.axes( - [ - ax_bounds[0] + igroup * ax_bounds[2] / len(groups), - ax_bounds[1] - + ax_bounds[3] - * (tot_len - ivar * len(layers) - ilayer - 1) - / tot_len, - ax_bounds[2] / len(groups), - (ax_bounds[3] - ax_bounds[3] / tot_len * len(annotations)) - / (len(var_names) * len(layers)), - ] - ) - - # Get data to fill and reshape - dat = adata[:, var] - - idx_group = [adata.obs[groupby] == group] - idx_group = np.array(idx_group[0].tolist()) - idx_var = [vn in var_names for vn in adata.var_names] - idx_pt = np.array(adata.obs.velocity_pseudotime).argsort() - idx_pt = idx_pt[ - np.array( - isnull(np.array(dat.obs.velocity_pseudotime)[idx_pt]) - == False - ) - ] - - if layer == "X": - laydat = dat.X - else: - laydat = dat.layers[layer] - - t1, t2, t3 = idx_group, idx_var, idx_pt - t1 = t1[t3] - # laydat = laydat[:, t2] # select vars - laydat = laydat[t3] - laydat = laydat[t1] # select ordered groups - - if issparse(laydat): - laydat = laydat.A - - # transpose X for ordering in direction var_names: up->downwards - laydat = laydat.T[::-1] - laydat = laydat.reshape((1, len(laydat))) # ensure 1dimty - - # plot - im = groups_axis.imshow( - laydat, - aspect="auto", - interpolation="nearest", - cmap=color_map[ilayer], - ) - - # Frames - if ilayer == 0: - groups_axis.spines["bottom"].set_visible(False) - elif ilayer == len(layer) - 1: - groups_axis.spines["top"].set_visible(False) - else: - groups_axis.spines["top"].set_visible(False) - groups_axis.spines["bottom"].set_visible(False) - - # Further visuals - if igroup == 0: - if colorbar: - if len(layers) % 2 == 0: - if ilayer == len(layers) / 2 - 1: - pl.yticks([0.5], [var]) - else: - groups_axis.set_yticks([]) - else: - if ilayer == (len(layers) - 1) / 2: - pl.yticks([0], [var]) - else: - groups_axis.set_yticks([]) - else: - pl.yticks([0], [f"{layer} {var}"]) - else: - groups_axis.set_yticks([]) - - groups_axis.set_xticks([]) - if ilayer == 0 and ivar == 0: - groups_axis.set_title(f"{group}") - groups_axis.grid(False) - - # handle needed as mappable for colorbar - if igroup == len(groups) - 1: - imlist.append(im) - - # further annotations for each group - if annotations is not None: - for ianno, anno in enumerate(annotations): - anno_axis = pl.axes( - [ - ax_bounds[0] + igroup * ax_bounds[2] / len(groups), - ax_bounds[1] - + ax_bounds[3] / tot_len * (len(annotations) - ianno - 1), - ax_bounds[2] / len(groups), - ax_bounds[3] / tot_len, - ] - ) - if is_categorical(adata, anno): - colo = interpret_colorkey(adata, anno)[t3][t1] - colo.reshape(1, len(colo)) - mapper = np.vectorize(ColorConverter.to_rgb) - a = mapper(colo) - a = np.array(a).T - Y = a.reshape(1, len(colo), 3) - else: - Y = np.array(interpret_colorkey(adata, anno))[t3][t1] - Y = Y.reshape(1, len(Y)) - img = anno_axis.imshow( - Y, aspect="auto", interpolation="nearest", cmap=color_map_anno - ) - if igroup == 0: - anno_axis.set_yticklabels( - ["", anno, ""] - ) # , fontsize=ytick_fontsize) - anno_axis.tick_params(axis="both", which="both", length=0) - else: - anno_axis.set_yticklabels([]) - anno_axis.set_yticks([]) - anno_axis.set_xticks([]) - anno_axis.set_xticklabels([]) - anno_axis.grid(False) - pl.ylim([0.5, -0.5]) # center ticks - - else: # groupby is False - imlist = [] - for ivar, var in enumerate(var_names): - for ilayer, 
layer in enumerate(layers): - ax_bounds = ax.get_position().bounds - groups_axis = pl.axes( - [ - ax_bounds[0], - ax_bounds[1] - + ax_bounds[3] - * (tot_len - ivar * len(layers) - ilayer - 1) - / tot_len, - ax_bounds[2], - (ax_bounds[3] - ax_bounds[3] / tot_len * len(annotations)) - / (len(var_names) * len(layers)), - ] - ) - # Get data to fill - dat = adata[:, var] - idx = np.array(dat.obs.velocity_pseudotime).argsort() - idx = idx[ - np.array( - isnull(np.array(dat.obs.velocity_pseudotime)[idx]) == False - ) - ] - - if layer == "X": - laydat = dat.X - else: - laydat = dat.layers[layer] - laydat = laydat[idx] - if issparse(laydat): - laydat = laydat.A - - # transpose X for ordering in direction var_names: up->downwards - laydat = laydat.T[::-1] - laydat = laydat.reshape((1, len(laydat))) - - # plot - im = groups_axis.imshow( - laydat, - aspect="auto", - interpolation="nearest", - cmap=color_map[ilayer], - ) - imlist.append(im) - - # Frames - if ilayer == 0: - groups_axis.spines["bottom"].set_visible(False) - elif ilayer == len(layer) - 1: - groups_axis.spines["top"].set_visible(False) - else: - groups_axis.spines["top"].set_visible(False) - groups_axis.spines["bottom"].set_visible(False) - - # Further visuals - groups_axis.set_xticks([]) - groups_axis.grid(False) - pl.ylim([0.5, -0.5]) # center - if colorbar: - if len(layers) % 2 == 0: - if ilayer == len(layers) / 2 - 1: - pl.yticks([0.5], [var]) - else: - groups_axis.set_yticks([]) - else: - if ilayer == (len(layers) - 1) / 2: - pl.yticks([0], [var]) - else: - groups_axis.set_yticks([]) - else: - pl.yticks([0], [f"{layer} {var}"]) - - # further annotations bars - if annotations is not None: - for ianno, anno in enumerate(annotations): - anno_axis = pl.axes( - [ - ax_bounds[0], - ax_bounds[1] - + ax_bounds[3] / tot_len * (len(annotations) - ianno - 1), - ax_bounds[2], - ax_bounds[3] / tot_len, - ] - ) - dat = adata[:, var_names] - if is_categorical(dat, anno): - colo = interpret_colorkey(dat, anno)[idx] - colo.reshape(1, len(colo)) - mapper = np.vectorize(ColorConverter.to_rgb) - a = mapper(colo) - a = np.array(a).T - Y = a.reshape(1, len(idx), 3) - else: - Y = np.array(interpret_colorkey(dat, anno)[idx]).reshape( - 1, len(idx) - ) - img = anno_axis.imshow( - Y, aspect="auto", interpolation="nearest", cmap=color_map_anno - ) - - anno_axis.set_yticklabels(["", anno, ""]) # , fontsize=ytick_fontsize) - anno_axis.tick_params(axis="both", which="both", length=0) - anno_axis.grid(False) - anno_axis.set_xticks([]) - anno_axis.set_xticklabels([]) - pl.ylim([-0.5, +0.5]) - - # Colorbar - if colorbar: - if len(layers) > 1: - # I must admit, this part is chaotic - for ilayer, layer in enumerate(layers): - w = 0.015 * 10 / figsize[0] # 0.02 * ax_bounds[2] - x = ax_bounds[0] + ax_bounds[2] * 0.99 + 1.5 * w + w * 1.2 * ilayer - y = ax_bounds[1] - h = ax_bounds[3] * 0.3 - cbaxes = pl.axes([x, y, w, h]) - cb = pl.colorbar(mappable=imlist[ilayer], cax=cbaxes) - pl.text( - x - 40 * w, - y + h * 4, - layer, - rotation=45, - horizontalalignment="left", - verticalalignment="bottom", - ) - if ilayer == len(layers) - 1: - ext = abs(cb.vmin - cb.vmax) - cb.set_ticks([cb.vmin + 0.07 * ext, cb.vmax - 0.07 * ext]) - cb.ax.set_yticklabels(["Low", "High"]) # vertical colorbar - else: - cb.set_ticks([]) - else: - cbaxes = pl.axes( - [ - ax_bounds[0] + ax_bounds[2] + 0.01, - ax_bounds[1], - 0.02, - ax_bounds[3] * 0.3, - ] - ) - cb = pl.colorbar(mappable=im, cax=cbaxes) - cb.set_ticks([cb.vmin, cb.vmax]) - cb.ax.set_yticklabels(["Low", "High"]) - - if xlabel is None: - 
xlabel = "velocity" + " " + "pseudotime" - if title is not None: - ax.set_title(title, pad=30) - if len(annotations) == 0: - ax.set_xlabel(xlabel) - ax.xaxis.labelpad = 20 - - # set_label(xlabel, None, fontsize, basis) - # set_title(title, None, None, fontsize) - # update_axes(ax, fontsize) - - savefig_or_show("heatmap", dpi=dpi, save=save, show=show) - if not show: - return ax diff --git a/scvelo/plotting/paga.py b/scvelo/plotting/paga.py index 062c52a7..f078c7d6 100644 --- a/scvelo/plotting/paga.py +++ b/scvelo/plotting/paga.py @@ -1,20 +1,27 @@ -from .. import settings -from .. import logging as logg +import collections.abc as cabc +import random +from inspect import signature -from ..tools.utils import groups_to_bool -from ..tools.paga import get_igraph_from_adjacency -from .utils import default_basis, default_size, default_color, get_components -from .utils import make_unique_list, make_unique_valid_list, savefig_or_show -from .scatter import scatter -from .docs import doc_scatter, doc_params +import numpy as np +import matplotlib.pyplot as pl from matplotlib import rcParams from matplotlib.path import get_path_collection_extents -import matplotlib.pyplot as pl -import numpy as np -from inspect import signature -import collections.abc as cabc -import random + +from scvelo import logging as logg +from scvelo import settings +from scvelo.tools.paga import get_igraph_from_adjacency +from .docs import doc_params, doc_scatter +from .scatter import scatter +from .utils import ( + default_basis, + default_color, + default_size, + get_components, + make_unique_list, + make_unique_valid_list, + savefig_or_show, +) @doc_params(scatter=doc_scatter) @@ -280,11 +287,6 @@ def paga( size = default_size(adata) / 2 if size is None else size paga_groups = adata.uns["paga"]["groups"] - _adata = ( - adata[groups_to_bool(adata, groups, groupby=paga_groups)] - if groups is not None and paga_groups in adata.obs.keys() - else adata - ) if isinstance(node_colors, dict): paga_kwargs["colorbar"] = False @@ -376,7 +378,7 @@ def _compute_pos( if layout == "fa": try: import fa2 - except: + except Exception: logg.warn( "Package 'fa2' is not installed, falling back to layout 'fr'." "To use the faster and better ForceAtlas2 layout, " @@ -718,15 +720,18 @@ def _paga_graph( """scanpy/_paga_graph with some adjustments for directional graphs. To be moved back to scanpy once finalized. 
""" + import warnings + from pathlib import Path + import networkx as nx import pandas as pd import scipy - import warnings + from pandas.api.types import is_categorical_dtype + from matplotlib import patheffects from matplotlib.colors import is_color_like - from pathlib import Path + from scanpy.plotting._utils import add_colors_for_categorical_sample_annotation - from pandas.api.types import is_categorical_dtype node_labels = labels # rename for clarity if ( @@ -825,8 +830,10 @@ def _paga_graph( and colors in adata.obs and is_categorical_dtype(adata.obs[colors]) ): - from scanpy._utils import compute_association_matrix_of_groups - from scanpy._utils import get_associated_colors_of_groups + from scanpy._utils import ( + compute_association_matrix_of_groups, + get_associated_colors_of_groups, + ) norm = "reference" if normalize_to_color else "prediction" _, asso_matrix = compute_association_matrix_of_groups( @@ -1042,7 +1049,9 @@ def transform_ax_coords(a, b): color_single.append("grey") fracs.append(1 - sum(fracs)) wedgeprops = dict(linewidth=0, edgecolor="k", antialiased=True) - pie_axs[count].pie(fracs, colors=color_single, wedgeprops=wedgeprops) + pie_axs[count].pie( + fracs, colors=color_single, wedgeprops=wedgeprops, normalize=True + ) if node_labels is not None: text_kwds.update(dict(verticalalignment="center", fontweight=fontweight)) text_kwds.update(dict(horizontalalignment="center", size=fontsize)) @@ -1087,6 +1096,6 @@ def getbb(sc, ax): result = get_path_collection_extents( transform.frozen(), [p], [t], [o], transOffset.frozen() ) - bboxes.append(result.inverse_transformed(ax.transData)) + bboxes.append(result.transformed(ax.transData.inverted())) return bboxes diff --git a/scvelo/plotting/palettes.py b/scvelo/plotting/palettes.py index 5962a923..5bd9b379 100644 --- a/scvelo/plotting/palettes.py +++ b/scvelo/plotting/palettes.py @@ -1,7 +1,10 @@ -"""Color palettes in addition to matplotlib's palettes.""" +from typing import Mapping, Sequence from matplotlib import cm, colors +"""Color palettes in addition to matplotlib's palettes.""" + + # Colorblindness adjusted vega_10 # See https://github.com/theislab/scanpy/issues/387 vega_10 = list(map(colors.to_hex, cm.tab10.colors)) @@ -80,13 +83,11 @@ # fmt: on -from typing import Mapping, Sequence - - def _plot_color_cylce(clists: Mapping[str, Sequence[str]]): import numpy as np + import matplotlib.pyplot as plt - from matplotlib.colors import ListedColormap, BoundaryNorm + from matplotlib.colors import BoundaryNorm, ListedColormap fig, axes = plt.subplots(nrows=len(clists)) # type: plt.Figure, plt.Axes fig.subplots_adjust(top=0.95, bottom=0.01, left=0.3, right=0.99) diff --git a/scvelo/plotting/proportions.py b/scvelo/plotting/proportions.py index 41f5ef66..420a57f7 100644 --- a/scvelo/plotting/proportions.py +++ b/scvelo/plotting/proportions.py @@ -1,7 +1,9 @@ -from ..preprocessing.utils import sum_var +import numpy as np import matplotlib.pyplot as pl -import numpy as np + +from scvelo.core import sum +from .utils import savefig_or_show def proportions( @@ -16,6 +18,7 @@ def proportions( dpi=100, use_raw=True, show=True, + save=None, ): """Plot pie chart of spliced/unspliced proprtions. @@ -43,22 +46,26 @@ def proportions( Use initial cell sizes before normalization and filtering. show: `bool` (default: True) Show the plot, do not return axis. + save: `bool` or `str`, optional (default: `None`) + If `True` or a `str`, save the figure. A string is appended to the default + filename. 
Infer the filetype if ending on {'.pdf', '.png', '.svg'}. Returns ------- Plots the proportions of abundances as pie chart. """ + # get counts per cell for each layer if layers is None: layers = ["spliced", "unspliced", "ambigious"] layers_keys = [key for key in layers if key in adata.layers.keys()] - counts_layers = [sum_var(adata.layers[key]) for key in layers_keys] + counts_layers = [sum(adata.layers[key], axis=1) for key in layers_keys] if use_raw: ikey, obs = "initial_size_", adata.obs counts_layers = [ - obs[ikey + l] if ikey + l in obs.keys() else c - for l, c in zip(layers_keys, counts_layers) + obs[ikey + layer_key] if ikey + layer_key in obs.keys() else c + for layer_key, c in zip(layers_keys, counts_layers) ] counts_total = np.sum(counts_layers, 0) counts_total += counts_total == 0 @@ -71,7 +78,10 @@ def proportions( ax = pl.subplot(gspec[0]) if highlight is None: highlight = "none" - explode = [0.1 if (l == highlight or l in highlight) else 0 for l in layers_keys] + explode = [ + 0.1 if (layer_key == highlight or layer_key in highlight) else 0 + for layer_key in layers_keys + ] autopct = "%1.0f%%" if add_labels_pie else None pie = ax.pie( @@ -151,7 +161,7 @@ def proportions( ax2.set_ylabel(groupby, fontweight="bold", fontsize=fontsize * 1.2) ax2.tick_params(axis="both", which="major", labelsize=fontsize) ax = [ax, ax2] - if show: - pl.show() - else: + + savefig_or_show("proportions", dpi=dpi, save=save, show=show) + if show is False: return ax diff --git a/scvelo/plotting/pseudotime.py b/scvelo/plotting/pseudotime.py index 6f52294d..33067212 100644 --- a/scvelo/plotting/pseudotime.py +++ b/scvelo/plotting/pseudotime.py @@ -1,4 +1,7 @@ -from .utils import * +import numpy as np + +import matplotlib.pyplot as pl +from matplotlib.ticker import MaxNLocator def principal_curve(adata): diff --git a/scvelo/plotting/scatter.py b/scvelo/plotting/scatter.py index 7e1f0562..16aeda8f 100644 --- a/scvelo/plotting/scatter.py +++ b/scvelo/plotting/scatter.py @@ -1,10 +1,19 @@ -from .docs import doc_scatter, doc_params -from .utils import * - from inspect import signature -import matplotlib.pyplot as pl + import numpy as np import pandas as pd +from pandas import unique + +import matplotlib.pyplot as pl +from matplotlib.colors import is_color_like + +from anndata import AnnData + +from scvelo import logging as logg +from scvelo import settings +from scvelo.preprocessing.neighbors import get_connectivities +from .docs import doc_params, doc_scatter +from .utils import * @doc_params(scatter=doc_scatter) @@ -34,6 +43,7 @@ def scatter( legend_fontsize=None, legend_fontweight=None, legend_fontoutline=None, + legend_align_text=None, xlabel=None, ylabel=None, title=None, @@ -48,11 +58,13 @@ def scatter( add_rug=None, add_text=None, add_text_pos=None, + add_margin=None, add_outline=None, outline_width=None, outline_color=None, n_convolve=None, smooth=None, + normalize_data=None, rescale_color=None, color_gradients=None, dpi=None, @@ -82,8 +94,9 @@ def scatter( Returns ------- - If `show==False` a `matplotlib.Axis` + If `show==False` a `matplotlib.Axis` """ + if adata is None and (x is not None and y is not None): adata = AnnData(np.stack([x, y]).T) @@ -97,7 +110,7 @@ def scatter( # keys for figures (fkeys) and multiple plots (mkeys) fkeys = ["adata", "show", "save", "groups", "ncols", "nrows", "wspace", "hspace"] - fkeys += ["ax", "kwargs"] + fkeys += ["add_margin", "ax", "kwargs"] mkeys = ["color", "layer", "basis", "components", "x", "y", "xlabel", "ylabel"] mkeys += ["title", "color_map", 
"add_text"] scatter_kwargs = {"show": False, "save": False} @@ -158,7 +171,7 @@ def scatter( nrows = int(np.ceil(len(multikey) / ncols)) else: ncols = int(np.ceil(len(multikey) / nrows)) - if not frameon: + if not frameon or frameon == "artist": lloc, llines = "legend_loc", "legend_loc_lines" if lloc in scatter_kwargs and scatter_kwargs[lloc] is None: scatter_kwargs[lloc] = "none" @@ -303,16 +316,12 @@ def scatter( basis = default_basis(adata) if linewidth is None: linewidth = 1 - if linecolor is None: - linecolor = "k" if frameon is None: frameon = True if not is_embedding else settings._frameon if isinstance(groups, str): groups = [groups] if use_raw is None and basis not in adata.var_names: use_raw = layer is None and adata.raw is not None - if projection == "3d": - from mpl_toolkits.mplot3d import Axes3D ax, show = get_ax(ax, show, figsize, dpi, projection) @@ -494,7 +503,7 @@ def scatter( try: c += rescale_color[0] - np.nanmin(c) c *= rescale_color[1] / np.nanmax(c) - except: + except Exception: logg.warn("Could not rescale colors. Pass a tuple, e.g. [0,1].") # set vmid to 0 if color values obtained from velocity expression @@ -521,12 +530,12 @@ def scatter( if len(x) != len(y): raise ValueError("x or y do not share the same dimension.") + if normalize_data: + x = (x - np.nanmin(x)) / (np.nanmax(x) - np.nanmin(x)) + y = (y - np.nanmin(x)) / (np.nanmax(y) - np.nanmin(y)) + if not isinstance(c, str): c = np.ravel(c) if len(np.ravel(c)) == len(x) else c - if len(c) != len(x): - c = "grey" - if not isinstance(color, str) or color != default_color(adata): - logg.warn("Invalid color key. Using grey instead.") # store original order of color values color_array, scatter_array = c, np.stack([x, y]).T @@ -576,6 +585,13 @@ def scatter( title = groups[0] else: # if nothing to be highlighted add_linfit, add_polyfit, add_density = None, None, None + else: + idx = None + + if not isinstance(c, str) and len(c) != len(x): + c = "grey" + if not isinstance(color, str) or color != default_color(adata): + logg.warn("Invalid color key. 
Using grey instead.") # check if higher value points should be plotted on top if not isinstance(c, str) and len(c) == len(x): @@ -583,7 +599,9 @@ def scatter( if sort_order and not is_categorical(adata, color): order = np.argsort(c) elif not sort_order and is_categorical(adata, color): - counts = get_value_counts(adata, color) + counts = get_value_counts( + adata[idx] if idx is not None else adata, color + ) np.random.seed(0) nums, p = np.arange(0, len(x)), counts / np.sum(counts) order = np.random.choice(nums, len(x), replace=False, p=p) @@ -592,8 +610,9 @@ def scatter( if isinstance(kwargs["s"], np.ndarray): # sort sizes if array-type kwargs["s"] = np.array(kwargs["s"])[order] + marker = kwargs.pop("marker", ".") smp = ax.scatter( - x, y, c=c, alpha=alpha, marker=".", zorder=zorder, **kwargs + x, y, c=c, alpha=alpha, marker=marker, zorder=zorder, **kwargs ) outline_dtypes = (list, tuple, np.ndarray, int, np.int_, str) @@ -655,6 +674,7 @@ def scatter( legend_fontweight, legend_fontsize, legend_fontoutline, + legend_align_text, groups, ) if add_density: @@ -710,7 +730,8 @@ def scatter( set_label(xlabel, ylabel, fontsize, basis, ax=ax) set_title(title, layer, color, fontsize, ax=ax) update_axes(ax, xlim, ylim, fontsize, is_embedding, frameon, figsize) - + if add_margin: + set_margin(ax, x, y, add_margin) if colorbar is not False: if not isinstance(c, str) and not is_categorical(adata, color): labelsize = fontsize * 0.75 if fontsize is not None else None @@ -744,6 +765,7 @@ def trimap(adata, **kwargs): ------- If `show==False` a :class:`~matplotlib.axes.Axes` or a list of it. """ + return scatter(adata, basis="trimap", **kwargs) @@ -760,6 +782,7 @@ def umap(adata, **kwargs): ------- If `show==False` a :class:`~matplotlib.axes.Axes` or a list of it. """ + return scatter(adata, basis="umap", **kwargs) @@ -776,6 +799,7 @@ def tsne(adata, **kwargs): ------- If `show==False` a :class:`~matplotlib.axes.Axes` or a list of it. """ + return scatter(adata, basis="tsne", **kwargs) @@ -792,6 +816,7 @@ def diffmap(adata, **kwargs): ------- If `show==False` a :class:`~matplotlib.axes.Axes` or a list of it. """ + return scatter(adata, basis="diffmap", **kwargs) @@ -808,6 +833,7 @@ def phate(adata, **kwargs): ------- If `show==False` a :class:`~matplotlib.axes.Axes` or a list of it. """ + return scatter(adata, basis="phate", **kwargs) @@ -824,6 +850,7 @@ def draw_graph(adata, layout=None, **kwargs): ------- If `show==False` a :class:`~matplotlib.axes.Axes` or a list of it. """ + if layout is None: layout = f"{adata.uns['draw_graph']['params']['layout']}" basis = f"draw_graph_{layout}" @@ -845,4 +872,5 @@ def pca(adata, **kwargs): ------- If `show==False` a :class:`~matplotlib.axes.Axes` or a list of it. 
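All of these embedding shortcuts are thin wrappers that fix `basis` and defer to `scatter`; for instance, the following two calls are equivalent (a sketch, assuming the bundled pancreas dataset):

```python
import scvelo as scv

adata = scv.datasets.pancreas()

# the wrapper ...
ax = scv.pl.umap(adata, color="clusters", show=False)
# ... is the same as calling scatter with a fixed basis
ax = scv.pl.scatter(adata, basis="umap", color="clusters", show=False)
```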
""" + return scatter(adata, basis="pca", **kwargs) diff --git a/scvelo/plotting/simulation.py b/scvelo/plotting/simulation.py index a76817a9..6f3ea5a9 100644 --- a/scvelo/plotting/simulation.py +++ b/scvelo/plotting/simulation.py @@ -1,10 +1,12 @@ -from ..tools.dynamical_model_utils import unspliced, mRNA, vectorize, tau_inv, get_vars -from .utils import make_dense - import numpy as np + import matplotlib.pyplot as pl from matplotlib import rcParams +from scvelo.core import SplicingDynamics +from scvelo.tools.dynamical_model_utils import get_vars, tau_inv, unspliced, vectorize +from .utils import make_dense + def get_dynamics(adata, key="fit", extrapolate=False, sorted=False, t=None): alpha, beta, gamma, scaling, t_ = get_vars(adata, key=key) @@ -18,8 +20,9 @@ def get_dynamics(adata, key="fit", extrapolate=False, sorted=False, t=None): t = adata.obs[f"{key}_t"].values if key == "true" else adata.layers[f"{key}_t"] tau, alpha, u0, s0 = vectorize(np.sort(t) if sorted else t, t_, alpha, beta, gamma) - ut, st = mRNA(tau, u0, s0, alpha, beta, gamma) - + ut, st = SplicingDynamics( + alpha=alpha, beta=beta, gamma=gamma, initial_state=[u0, s0] + ).get_solution(tau) return alpha, ut, st @@ -51,17 +54,26 @@ def compute_dynamics( tau, alpha, u0, s0 = vectorize(np.sort(t) if sort else t, t_, alpha, beta, gamma) - ut, st = mRNA(tau, u0, s0, alpha, beta, gamma) + ut, st = SplicingDynamics( + alpha=alpha, beta=beta, gamma=gamma, initial_state=[u0, s0] + ).get_solution(tau, stacked=False) ut, st = ut * scaling + u0_offset, st + s0_offset return alpha, ut, st def show_full_dynamics( - adata, basis, key="true", use_raw=False, linewidth=1, show_assignments=None, ax=None + adata, + basis, + key="true", + use_raw=False, + linewidth=1, + linecolor=None, + show_assignments=None, + ax=None, ): if ax is None: ax = pl.gca() - color = "grey" if key == "true" else "purple" + color = linecolor if linecolor else "grey" if key == "true" else "purple" linewidth = 0.5 * linewidth if key == "true" else linewidth label = "learned dynamics" if key == "fit" else "true dynamics" line = None @@ -115,7 +127,7 @@ def simulation( colors=None, **kwargs, ): - from ..tools.utils import make_dense + from scvelo.tools.utils import make_dense from .scatter import scatter if ykey is None: diff --git a/scvelo/plotting/summary.py b/scvelo/plotting/summary.py index 57c9ecfb..bf40914f 100644 --- a/scvelo/plotting/summary.py +++ b/scvelo/plotting/summary.py @@ -1,13 +1,12 @@ -from ..tools.dynamical_model import latent_time -from ..tools.velocity_pseudotime import velocity_pseudotime -from ..tools.rank_velocity_genes import rank_velocity_genes -from ..tools.score_genes_cell_cycle import score_genes_cell_cycle +import numpy as np +import pandas as pd -from .utils import make_unique_list +from scvelo.tools.dynamical_model import latent_time +from scvelo.tools.rank_velocity_genes import rank_velocity_genes +from scvelo.tools.score_genes_cell_cycle import score_genes_cell_cycle +from scvelo.tools.velocity_pseudotime import velocity_pseudotime from .gridspec import GridSpec - -import pandas as pd -import numpy as np +from .utils import make_unique_list def summary(adata, basis="umap", color="clusters", n_top_genes=12, var_names=None): diff --git a/scvelo/plotting/utils.py b/scvelo/plotting/utils.py index 5f30d3d1..7f6ec1b7 100644 --- a/scvelo/plotting/utils.py +++ b/scvelo/plotting/utils.py @@ -1,28 +1,27 @@ -from .. import settings -from .. import logging as logg -from .. 
import AnnData -from ..preprocessing.moments import get_connectivities -from ..tools.utils import strings_to_categoricals -from . import palettes - import os +from collections import abc + +from cycler import Cycler, cycler + import numpy as np import pandas as pd +from pandas import Index +from scipy import stats +from scipy.sparse import issparse + import matplotlib.pyplot as pl -from matplotlib.ticker import MaxNLocator -from mpl_toolkits.axes_grid1.inset_locator import inset_axes -from matplotlib.colors import is_color_like, ListedColormap, to_rgb, cnames +import matplotlib.transforms as tx +from matplotlib import patheffects, rcParams from matplotlib.collections import LineCollection +from matplotlib.colors import cnames, is_color_like, ListedColormap, to_rgb from matplotlib.gridspec import SubplotSpec -from matplotlib import patheffects -import matplotlib.transforms as tx -from matplotlib import rcParams -from pandas import unique, Index -from scipy.sparse import issparse -from scipy.stats import pearsonr -from cycler import Cycler, cycler -from collections import abc +from matplotlib.ticker import MaxNLocator +from mpl_toolkits.axes_grid1.inset_locator import inset_axes +from scvelo import logging as logg +from scvelo import settings +from scvelo.tools.utils import strings_to_categoricals +from . import palettes """helper functions""" @@ -48,7 +47,7 @@ def is_view(adata): def is_categorical(data, c=None): - from pandas.api.types import is_categorical as cat + from pandas.api.types import is_categorical_dtype as cat if c is None: return cat(data) # if data is categorical/array @@ -81,7 +80,9 @@ def is_list_of_str(key, max_len=None): def is_list_of_list(lst): - return lst is not None and any(isinstance(l, list) for l in lst) + return lst is not None and any( + isinstance(list_element, list) for list_element in lst + ) def is_list_of_int(lst): @@ -112,7 +113,9 @@ def get_ax(ax, show=None, figsize=None, dpi=None, projection=None): figsize, _ = get_figure_params(figsize) if ax is None: projection = "3d" if projection == "3d" else None - ax = pl.figure(None, figsize, dpi=dpi).gca(projection=projection) + _, ax = pl.subplots( + figsize=figsize, dpi=dpi, subplot_kw={"projection": projection} + ) elif isinstance(ax, SubplotSpec): geo = ax.get_geometry() if show is None: @@ -347,7 +350,7 @@ def default_color_map(adata, c): try: if np.min(c) in [-1, 0, False] and np.max(c) in [1, True]: cmap = "viridis_r" - except: + except Exception: cmap = None return cmap @@ -461,7 +464,7 @@ def set_artist_frame(ax, length=0.2, figsize=None): figsize = rcParams["figure.figsize"] if figsize is None else figsize aspect_ratio = figsize[0] / figsize[1] ax.xaxis.set_label_coords(length * 0.45, -0.035) - ax.yaxis.set_label_coords(-0.0175, length * aspect_ratio * 0.45) + ax.yaxis.set_label_coords(-0.025, length * aspect_ratio * 0.45) ax.xaxis.label.set_size(ax.xaxis.label.get_size() / 1.2) ax.yaxis.label.set_size(ax.yaxis.label.get_size() / 1.2) @@ -539,6 +542,7 @@ def set_legend( legend_fontweight, legend_fontsize, legend_fontoutline, + legend_align_text, groups, ): """ @@ -577,10 +581,14 @@ def set_legend( text = ax.text(x_pos, y_pos, label, path_effects=pe, **kwargs) texts.append(text) - # todo: adjust text positions to minimize overlaps, - # e.g. 
using https://github.com/Phlya/adjustText - # from adjustText import adjust_text - # adjust_text(texts, ax=ax) + if legend_align_text: + autoalign = "y" if legend_align_text is True else legend_align_text + try: + from adjustText import adjust_text as adj_text + + adj_text(texts, autoalign=autoalign, text_from_points=False, ax=ax) + except ImportError: + print("Please `pip install adjustText` for auto-aligning texts") else: for idx, label in enumerate(categories): @@ -599,6 +607,16 @@ def set_legend( ax.legend(loc=legend_loc, **kwargs) +def set_margin(ax, x, y, add_margin): + add_margin = 0.1 if add_margin is True else add_margin + xmin, xmax = np.min(x), np.max(x) + ymin, ymax = np.min(y), np.max(y) + xmargin = (xmax - xmin) * add_margin + ymargin = (ymax - ymin) * add_margin + ax.set_xlim(xmin - xmargin, xmax + xmargin) + ax.set_ylim(ymin - ymargin, ymax + ymargin) + + """get color values""" @@ -640,7 +658,7 @@ def interpret_colorkey(adata, c=None, layer=None, perc=None, use_raw=None): if is_categorical(adata, c): c = get_colors(adata, c) elif isinstance(c, str): - if is_color_like(c) and not c in adata.var_names: + if is_color_like(c) and c not in adata.var_names: pass elif c in adata.obs.keys(): # color by observation key c = adata.obs[c] @@ -649,15 +667,22 @@ def interpret_colorkey(adata, c=None, layer=None, perc=None, use_raw=None): ): # by gene if layer in adata.layers.keys(): if perc is None and any( - l in layer for l in ["spliced", "unspliced", "Ms", "Mu", "velocity"] + layer_name in layer + for layer_name in ["spliced", "unspliced", "Ms", "Mu", "velocity"] ): perc = [1, 99] # to ignore outliers in non-logarithmized layers c = adata.obs_vector(c, layer=layer) elif layer is not None and np.any( - [l in layer or "X" in layer for l in adata.layers.keys()] + [ + layer_name in layer or "X" in layer + for layer_name in adata.layers.keys() + ] ): l_array = np.hstack( - [adata.obs_vector(c, layer=l)[:, None] for l in adata.layers.keys()] + [ + adata.obs_vector(c, layer=layer)[:, None] + for layer in adata.layers.keys() + ] ) l_array = pd.DataFrame(l_array, columns=adata.layers.keys()) l_array.insert(0, "X", adata.obs_vector(c)) @@ -712,9 +737,10 @@ def set_colors_for_categorical_obs(adata, value_to_plot, palette=None): a sequence of colors (in a format that can be understood by matplotlib, eg. 
RGB, RGBS, hex, or a cycler object with key='color' """ - from .palettes import additional_colors from matplotlib.colors import to_hex + from .palettes import additional_colors + color_key = f"{value_to_plot}_colors" valid = True categories = adata.obs[value_to_plot].cat.categories @@ -724,19 +750,19 @@ def set_colors_for_categorical_obs(adata, value_to_plot, palette=None): palette = palettes.default_26 if length <= 28 else palettes.default_64 if isinstance(palette, str) and palette in adata.uns: palette = ( - adata.uns[palette].values() + [adata.uns[palette][c] for c in categories] if isinstance(adata.uns[palette], dict) else adata.uns[palette] ) - if palette is None and color_key in adata.uns: + color_keys = adata.uns[color_key] # Check if colors already exist in adata.uns and if they are a valid palette - _palette = [] - color_keys = ( - adata.uns[color_key].values() - if isinstance(adata.uns[color_key], dict) - else adata.uns[color_key] - ) + if isinstance(color_keys, np.ndarray) and isinstance(color_keys[0], dict): + adata.uns[color_key] = adata.uns[color_key][0] + # Flatten the dict to a list (mainly for anndata compatibilities) + if isinstance(adata.uns[color_key], dict): + adata.uns[color_key] = [adata.uns[color_key][c] for c in categories] + color_keys = adata.uns[color_key] for color in color_keys: if not is_color_like(color): # check if valid color translate to a hex color value @@ -760,6 +786,13 @@ def set_colors_for_categorical_obs(adata, value_to_plot, palette=None): colors_list = [to_hex(x) for x in cmap(np.linspace(0, 1, length))] else: + # check if palette is an array of length n_obs + if isinstance(palette, (list, np.ndarray)) or is_categorical(palette): + if len(adata.obs[value_to_plot]) == len(palette): + cats = pd.Categorical(adata.obs[value_to_plot]) + colors = pd.Categorical(palette) + if len(cats) == len(colors): + palette = dict(zip(cats, colors)) # check if palette is as dict and convert it to an ordered list if isinstance(palette, dict): palette = [palette[c] for c in categories] @@ -884,8 +917,9 @@ def rgb_custom_colormap(colors=None, alpha=None, N=256): Returns ------- - A ListedColormap + :class:`~matplotlib.colors.ListedColormap` """ + if colors is None: colors = ["royalblue", "white", "forestgreen"] c = [] @@ -906,9 +940,11 @@ def rgb_custom_colormap(colors=None, alpha=None, N=256): n = int(N / ints) for j in range(ints): + start = n * j + end = n * (j + 1) for i in range(3): - vals[n * j : n * (j + 1), i] = np.linspace(c[j][i], c[j + 1][i], n) - vals[n * j : n * (j + 1), -1] = np.linspace(alpha[j], alpha[j + 1], n) + vals[start:end, i] = np.linspace(c[j][i], c[j + 1][i], n) + vals[start:end, -1] = np.linspace(alpha[j], alpha[j + 1], n) return ListedColormap(vals) @@ -925,6 +961,8 @@ def savefig_or_show(writekey=None, show=None, dpi=None, ext=None, save=None): save = save.replace(try_ext, "") break # append it + if "/" in save: + writekey = None writekey = ( f"{writekey}_{save}" if writekey is not None and len(writekey) > 0 else save ) @@ -955,15 +993,17 @@ def savefig_or_show(writekey=None, show=None, dpi=None, ext=None, save=None): os.makedirs(settings.figdir) if ext is None: ext = settings.file_format_figs - filename = f"{settings.figdir}{settings.plot_prefix}{writekey}" + filepath = f"{settings.figdir}{settings.plot_prefix}{writekey}" + if "/" in writekey: + filepath = f"{writekey}" try: - filename += f"{settings.plot_suffix}.{ext}" + filename = filepath + f"{settings.plot_suffix}.{ext}" pl.savefig(filename, dpi=dpi, bbox_inches="tight") - except: # save 
as .png if .pdf is not feasible (e.g. specific streamplots) - logg.msg(f"figure cannot be saved as {ext}, using png instead.", v=1) - filename = f"{settings.figdir}{settings.plot_prefix}{writekey}" - filename += f"{settings.plot_suffix}.png" + except Exception: + # save as .png if .pdf is not feasible (e.g. specific streamplots) + filename = filepath + f"{settings.plot_suffix}.png" pl.savefig(filename, dpi=dpi, bbox_inches="tight") + logg.msg(f"figure cannot be saved as {ext}, using png instead.", v=1) logg.msg("saving figure to file", filename, v=1) if show: pl.show() @@ -1008,14 +1048,14 @@ def plot_linfit( if isinstance(add_linfit, str) else color if isinstance(color, str) - else "grey" + else "k" ) xnew = np.linspace(np.min(x), np.max(x) * 1.02) ax.plot(xnew, offset + xnew * slope, linewidth=linewidth, color=color) if add_legend: kwargs = dict(ha="left", va="top", fontsize=fontsize) bbox = dict(boxstyle="round", facecolor="wheat", alpha=0.2) - txt = r"$\rho = $" + f"{np.round(pearsonr(x, y)[0], 2)}" + txt = r"$\rho = $" + f"{np.round(stats.pearsonr(x, y)[0], 2)}" ax.text(0.05, 0.95, txt, transform=ax.transAxes, bbox=bbox, **kwargs) @@ -1053,7 +1093,7 @@ def plot_polyfit( if isinstance(add_polyfit, str) else color if isinstance(color, str) - else "grey" + else "k" ) xnew = np.linspace(np.min(x), np.max(x), num=100) @@ -1128,16 +1168,25 @@ def plot_velocity_fits( if "true_alpha" in adata.var.keys() and ( vkey is not None and "true_dynamics" in vkey ): - line, fit = show_full_dynamics(adata, basis, "true", use_raw, linewidth, ax=ax) + line, fit = show_full_dynamics( + adata, + basis, + key="true", + use_raw=use_raw, + linewidth=linewidth, + linecolor=linecolor, + ax=ax, + ) fits.append(fit) lines.append(line) if "fit_alpha" in adata.var.keys() and (vkey is None or "dynamics" in vkey): line, fit = show_full_dynamics( adata, basis, - "fit", - use_raw, - linewidth, + key="fit", + use_raw=use_raw, + linewidth=linewidth, + linecolor=linecolor, show_assignments=show_assignments, ax=ax, ) @@ -1359,7 +1408,7 @@ def hist( Returns ------- - If `show==False` a `matplotlib.Axis` + If `show==False` a `matplotlib.Axis` """ if ax is None: @@ -1418,11 +1467,11 @@ def hist( ax.fill_between(bins, 0, kde_bins, alpha=0.4, color=ci, label=li) ylim = np.min(kde_bins) if ylim is None else ylim if hist: - ci, li = colors[i], labels[i] if labels is not None else None + ci, li = colors[i], labels[i] if labels is not None and not kde else None kwargs.update({"color": ci, "label": li}) try: ax.hist(x_vals, bins=bins, alpha=alpha, density=normed, **kwargs) - except: + except Exception: ax.hist(x_vals, bins=bins, alpha=alpha, **kwargs) if xlabel is None: xlabel = "" @@ -1465,20 +1514,19 @@ def log_fmt(x, pos): pdf = [pdf] if isinstance(pdf, str) else pdf if pdf is not None: fits = [] - for i, pd in enumerate(pdf): - from scipy import stats - + for i, pdf_name in enumerate(pdf): xt = ax.get_xticks() xmin, xmax = min(xt), max(xt) lnspc = np.linspace(xmin, xmax, len(bins)) - if "(" in pd: # used passed parameters - args, pd = eval(pd[pd.rfind("(") :]), pd[: pd.rfind("(")] + if "(" in pdf_name: # used passed parameters + start = pdf_name.rfind("(") + args, pdf_name = eval(pdf_name[start:]), pdf_name[:start] else: # fit parameters - args = eval(f"stats.{pd}.fit(x_vals)") - pd_vals = eval(f"stats.{pd}.pdf(lnspc, *args)") - logg.info("Fitting", pd, np.round(args, 4), ".") - fit = ax.plot(lnspc, pd_vals, label=pd, color=colors[i]) + args = getattr(stats, pdf_name).fit(x_vals) + pd_vals = getattr(stats, pdf_name).pdf(lnspc, 
*args)
+            logg.info("Fitting", pdf_name, np.round(args, 4), ".")
+            fit = ax.plot(lnspc, pd_vals, label=pdf_name, color=colors[i])
             fits.extend(fit)
         ax.legend(handles=fits, labels=pdf, fontsize=legend_fontsize)
@@ -1505,7 +1553,7 @@ def plot(
     dpi=None,
     show=True,
 ):
-    ax = pl.figure(None, figsize, dpi=dpi) if ax is None else ax
+    ax, show = get_ax(ax, show, figsize, dpi)
     arrays = np.array(arrays)
     arrays = (
         arrays if isinstance(arrays, (list, tuple)) or arrays.ndim > 1 else [arrays]
     )
@@ -1517,21 +1565,19 @@ def plot(
     for i, array in enumerate(arrays):
         X = array[np.isfinite(array)]
         X = X / np.max(X) if normalize else X
-        pl.plot(X, color=colors[i], label=labels[i] if labels is not None else None)
+        ax.plot(X, color=colors[i], label=labels[i] if labels is not None else None)

-    pl.xlabel(xlabel if xlabel is not None else "")
-    pl.ylabel(ylabel if xlabel is not None else "")
+    ax.set_xlabel(xlabel if xlabel is not None else "")
+    ax.set_ylabel(ylabel if ylabel is not None else "")
     if labels is not None:
-        pl.legend()
+        ax.legend()
     if xscale is not None:
-        pl.xscale(xscale)
+        ax.set_xscale(xscale)
     if yscale is not None:
-        pl.yscale(yscale)
+        ax.set_yscale(yscale)

     if not show:
         return ax
-    else:
-        pl.show()


 def fraction_timeseries(
@@ -1627,10 +1673,10 @@ def make_unique_valid_list(adata, keys):


 def get_temporal_connectivities(adata, tkey, n_convolve=30):
-    from ..tools.velocity_graph import vals_to_csr
-    from ..tools.utils import normalize
+    from scvelo.tools.utils import normalize
+    from scvelo.tools.velocity_graph import vals_to_csr

-    # from ..tools.utils import get_indices
+    # from scvelo.tools.utils import get_indices
     # c_idx = get_indices(get_connectivities(adata, recurse_neighbors=True))[0]
     # lspace = np.linspace(0, len(c_idx) - 1, len(c_idx), dtype=int)
     # c_idx = np.hstack([c_idx, lspace[:, None]])
diff --git a/scvelo/plotting/velocity.py b/scvelo/plotting/velocity.py
index 27f67a5f..d68229a9 100644
--- a/scvelo/plotting/velocity.py
+++ b/scvelo/plotting/velocity.py
@@ -1,21 +1,21 @@
-from ..
import settings -from ..preprocessing.moments import second_order_moments -from ..tools.rank_velocity_genes import rank_velocity_genes +import numpy as np +import pandas as pd +from scipy.sparse import issparse + +import matplotlib.pyplot as pl +from matplotlib import rcParams + +from scvelo.preprocessing.moments import second_order_moments +from scvelo.tools.rank_velocity_genes import rank_velocity_genes from .scatter import scatter from .utils import ( - savefig_or_show, default_basis, default_size, get_basis, get_figure_params, + savefig_or_show, ) -import numpy as np -import pandas as pd -import matplotlib.pyplot as pl -from matplotlib import rcParams -from scipy.sparse import issparse - def velocity( adata, @@ -193,8 +193,8 @@ def velocity( ) # velocity and expression plots - for l, layer in enumerate(layers): - ax = pl.subplot(gs[v * nplts + l + 1]) + for layer_id, layer in enumerate(layers): + ax = pl.subplot(gs[v * nplts + layer_id + 1]) title = "expression" if layer in ["X", skey] else layer # _kwargs = {} if title == 'expression' else kwargs cmap = color_map diff --git a/scvelo/plotting/velocity_embedding.py b/scvelo/plotting/velocity_embedding.py index 26725633..b9d86ad4 100644 --- a/scvelo/plotting/velocity_embedding.py +++ b/scvelo/plotting/velocity_embedding.py @@ -1,13 +1,30 @@ -from ..tools.velocity_embedding import velocity_embedding as compute_velocity_embedding -from ..tools.utils import groups_to_bool -from .utils import * -from .scatter import scatter -from .docs import doc_scatter, doc_params +import numpy as np +import matplotlib.pyplot as pl from matplotlib import rcParams from matplotlib.colors import is_color_like -import matplotlib.pyplot as pl -import numpy as np + +from scvelo.tools.utils import groups_to_bool +from scvelo.tools.velocity_embedding import ( + velocity_embedding as compute_velocity_embedding, +) +from .docs import doc_params, doc_scatter +from .scatter import scatter +from .utils import ( + default_arrow, + default_basis, + default_color, + default_color_map, + default_size, + get_ax, + get_components, + get_figure_params, + interpret_colorkey, + make_unique_list, + make_unique_valid_list, + savefig_or_show, + velocity_embedding_changed, +) @doc_params(scatter=doc_scatter) @@ -70,8 +87,9 @@ def velocity_embedding( Returns ------- - `matplotlib.Axis` if `show==False` + `matplotlib.Axis` if `show==False` """ + if vkey == "all": lkeys = list(adata.layers.keys()) vkey = [key for key in lkeys if "velocity" in key and "_u" not in key] @@ -159,8 +177,6 @@ def velocity_embedding( return ax else: - if projection == "3d": - from mpl_toolkits.mplot3d import Axes3D ax, show = get_ax(ax, show, figsize, dpi, projection) color, layer, vkey, basis = colors[0], layers[0], vkeys[0], bases[0] diff --git a/scvelo/plotting/velocity_embedding_grid.py b/scvelo/plotting/velocity_embedding_grid.py index d81d77ce..c8bd86f9 100644 --- a/scvelo/plotting/velocity_embedding_grid.py +++ b/scvelo/plotting/velocity_embedding_grid.py @@ -1,14 +1,27 @@ -from ..tools.velocity_embedding import quiver_autoscale, velocity_embedding -from ..tools.utils import groups_to_bool -from .utils import * -from .scatter import scatter -from .docs import doc_scatter, doc_params - -from sklearn.neighbors import NearestNeighbors +import numpy as np from scipy.stats import norm as normal -from matplotlib import rcParams +from sklearn.neighbors import NearestNeighbors + import matplotlib.pyplot as pl -import numpy as np +from matplotlib import rcParams + +from scvelo.tools.utils import groups_to_bool 
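The `NearestNeighbors` and `norm` imports above serve the grid smoothing in `compute_velocity_on_grid`, defined just below: each grid point averages the velocity vectors of nearby cells with Gaussian weights. A simplified sketch of that idea (`velocities_on_grid` is a hypothetical helper, not part of the API; the real function additionally infers the kernel scale from the grid spacing and masks low-density grid points):

```python
import numpy as np
from scipy.stats import norm as normal
from sklearn.neighbors import NearestNeighbors


def velocities_on_grid(X_emb, V_emb, X_grid, n_neighbors=10, scale=0.5):
    # find the cells closest to each grid point
    nn = NearestNeighbors(n_neighbors=n_neighbors).fit(X_emb)
    dists, neighs = nn.kneighbors(X_grid)

    # Gaussian kernel weights: nearby cells dominate the average
    weights = normal.pdf(x=dists, scale=scale)
    p_mass = weights.sum(1)[:, None]

    # kernel-weighted mean of the neighboring cells' velocity vectors
    return (V_emb[neighs] * weights[:, :, None]).sum(1) / np.maximum(p_mass, 1e-12)
```

The arrow lengths are handled separately (see `quiver_autoscale` imported above), which the sketch deliberately leaves out.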
+from scvelo.tools.velocity_embedding import quiver_autoscale, velocity_embedding +from .docs import doc_params, doc_scatter +from .scatter import scatter +from .utils import ( + default_arrow, + default_basis, + default_color, + default_size, + get_ax, + get_basis, + get_components, + get_figure_params, + make_unique_list, + savefig_or_show, + velocity_embedding_changed, +) def compute_velocity_on_grid( @@ -165,8 +178,9 @@ def velocity_embedding_grid( Returns ------- - `matplotlib.Axis` if `show==False` + `matplotlib.Axis` if `show==False` """ + basis = default_basis(adata, **kwargs) if basis is None else get_basis(adata, basis) if vkey == "all": lkeys = list(adata.layers.keys()) diff --git a/scvelo/plotting/velocity_embedding_stream.py b/scvelo/plotting/velocity_embedding_stream.py index 05a6d2a1..5a7acd89 100644 --- a/scvelo/plotting/velocity_embedding_stream.py +++ b/scvelo/plotting/velocity_embedding_stream.py @@ -1,13 +1,25 @@ -from ..tools.velocity_embedding import velocity_embedding -from ..tools.utils import groups_to_bool -from .utils import * -from .velocity_embedding_grid import compute_velocity_on_grid -from .scatter import scatter -from .docs import doc_scatter, doc_params +import numpy as np -from matplotlib import rcParams import matplotlib.pyplot as pl -import numpy as np +from matplotlib import rcParams + +from scvelo.tools.utils import groups_to_bool +from scvelo.tools.velocity_embedding import velocity_embedding +from .docs import doc_params, doc_scatter +from .scatter import scatter +from .utils import ( + default_basis, + default_color, + default_size, + get_ax, + get_basis, + get_components, + get_figure_params, + make_unique_list, + savefig_or_show, + velocity_embedding_changed, +) +from .velocity_embedding_grid import compute_velocity_on_grid @doc_params(scatter=doc_scatter) @@ -15,11 +27,15 @@ def velocity_embedding_stream( adata, basis=None, vkey="velocity", - density=None, + density=2, smooth=None, min_mass=None, cutoff_perc=None, arrow_color=None, + arrow_size=1, + arrow_style="-|>", + max_length=4, + integration_direction="both", linewidth=None, n_neighbors=None, recompute=None, @@ -62,8 +78,11 @@ def velocity_embedding_stream( --------- adata: :class:`~anndata.AnnData` Annotated data matrix. - density: `float` (default: 1) - Amount of velocities to show - 0 none to 1 all + density: `float` (default: 2) + Controls the closeness of streamlines. When density = 2 (default), the domain + is divided into a 60x60 grid, whereas density linearly scales this grid. + Each cell in the grid can have, at most, one traversing streamline. + For different densities in each direction, use a tuple (density_x, density_y). smooth: `float` (default: 0.5) Multiplication factor for scale in Gaussian kernel around grid point. min_mass: `float` (default: 1) @@ -73,18 +92,30 @@ def velocity_embedding_stream( If set, mask small velocities below a percentile threshold (between 0 and 100). linewidth: `float` (default: 1) Line width for streamplot. + arrow_color: `str` or 2D array (default: 'k') + The streamline color. If given an array, it must have the same shape as u and v. + arrow_size: `float` (default: 1) + Scaling factor for the arrow size. + arrow_style: `str` (default: '-|>') + Arrow style specification, '-|>' or '->'. + max_length: `float` (default: 4) + Maximum length of streamline in axes coordinates. + integration_direction: `str` (default: 'both') + Integrate the streamline in 'forward', 'backward' or 'both' directions. 
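The added keywords are forwarded to matplotlib's `streamplot` (`arrow_size` as `arrowsize`, `max_length` as `maxlength`). A usage sketch, assuming velocities and the velocity graph have already been computed on an `adata` object:

```python
import scvelo as scv

# denser streamlines with larger arrows, integrated forward only
scv.pl.velocity_embedding_stream(
    adata,
    basis="umap",
    density=3,
    arrow_size=1.5,
    arrow_style="-|>",
    max_length=4,
    integration_direction="forward",
)
```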
n_neighbors: `int` (default: None) Number of neighbors to consider around grid point. X: `np.ndarray` (default: None) - Embedding grid point coordinates + Embedding coordinates. Using `adata.obsm['X_umap']` per default. V: `np.ndarray` (default: None) - Embedding grid velocity coordinates + Embedding velocity coordinates. Using `adata.obsm['velocity_umap']` per default. + {scatter} Returns ------- - `matplotlib.Axis` if `show==False` + `matplotlib.Axis` if `show==False` """ + basis = default_basis(adata, **kwargs) if basis is None else get_basis(adata, basis) if vkey == "all": lkeys = list(adata.layers.keys()) @@ -148,6 +179,17 @@ def velocity_embedding_stream( "save": False, } + stream_kwargs = { + "linewidth": linewidth, + "density": density or 2, + "zorder": 3, + "arrow_color": arrow_color or "k", + "arrowsize": arrow_size or 1, + "arrowstyle": arrow_style or "-|>", + "maxlength": max_length or 4, + "integration_direction": integration_direction or "both", + } + multikey = ( colors if len(colors) > 1 @@ -175,11 +217,9 @@ def velocity_embedding_stream( ax.append( velocity_embedding_stream( adata, - density=density, size=size, smooth=smooth, n_neighbors=n_neighbors, - linewidth=linewidth, ax=pl.subplot(gs), color=colors[i] if len(colors) > 1 else color, layer=layers[i] if len(layers) > 1 else layer, @@ -188,6 +228,7 @@ def velocity_embedding_stream( X_grid=None if len(vkeys) > 1 else X_grid, V_grid=None if len(vkeys) > 1 else V_grid, **scatter_kwargs, + **stream_kwargs, **kwargs, ) ) @@ -197,19 +238,14 @@ def velocity_embedding_stream( else: ax, show = get_ax(ax, show, figsize, dpi) - density = 1 if density is None else density - stream_kwargs = { - "linewidth": linewidth, - "density": 2 * density, - "zorder": 3, - "color": "k" if arrow_color is None else arrow_color, - } + for arg in list(kwargs): if arg in stream_kwargs: stream_kwargs.update({arg: kwargs[arg]}) else: scatter_kwargs.update({arg: kwargs[arg]}) + stream_kwargs["color"] = stream_kwargs.pop("arrow_color", "k") ax.streamplot(X_grid[0], X_grid[1], V_grid[0], V_grid[1], **stream_kwargs) size = 8 * default_size(adata) if size is None else size diff --git a/scvelo/plotting/velocity_graph.py b/scvelo/plotting/velocity_graph.py index 9d5cba0f..e0612156 100644 --- a/scvelo/plotting/velocity_graph.py +++ b/scvelo/plotting/velocity_graph.py @@ -1,14 +1,21 @@ -from .. 
import settings -from ..preprocessing.neighbors import get_neighs -from ..tools.transition_matrix import transition_matrix -from .utils import savefig_or_show, default_basis, get_components -from .utils import get_basis, groups_to_bool, default_size -from .scatter import scatter -from .docs import doc_scatter, doc_params - import warnings + import numpy as np -from scipy.sparse import issparse, csr_matrix +from scipy.sparse import csr_matrix, issparse + +from scvelo import settings +from scvelo.preprocessing.neighbors import get_neighs +from scvelo.tools.transition_matrix import transition_matrix +from .docs import doc_params, doc_scatter +from .scatter import scatter +from .utils import ( + default_basis, + default_size, + get_basis, + get_components, + groups_to_bool, + savefig_or_show, +) @doc_params(scatter=doc_scatter) @@ -59,8 +66,9 @@ def velocity_graph( Returns ------- - `matplotlib.Axis` if `show==False` + `matplotlib.Axis` if `show==False` """ + basis = default_basis(adata, **kwargs) if basis is None else get_basis(adata, basis) kwargs.update( { @@ -76,7 +84,7 @@ def velocity_graph( ) ax = scatter(adata, layer=layer, color=color, size=size, ax=ax, zorder=0, **kwargs) - from networkx import Graph, DiGraph + from networkx import DiGraph, Graph if which_graph in {"neighbors", "connectivities"}: T = get_neighs(adata, "connectivities").copy() @@ -165,11 +173,12 @@ def draw_networkx_edges( ): """Draw the edges of the graph G. Adjusted from networkx.""" try: + from numbers import Number + import matplotlib.pyplot as plt - from matplotlib.colors import colorConverter, Colormap, Normalize from matplotlib.collections import LineCollection + from matplotlib.colors import colorConverter, Colormap, Normalize from matplotlib.patches import FancyArrowPatch - from numbers import Number except ImportError: raise ImportError("Matplotlib required for draw()") except RuntimeError: diff --git a/scvelo/pp.py b/scvelo/pp.py index 4db20b2e..c0870c14 100644 --- a/scvelo/pp.py +++ b/scvelo/pp.py @@ -1 +1 @@ -from scvelo.preprocessing import * +from scvelo.preprocessing import * # noqa diff --git a/scvelo/preprocessing/__init__.py b/scvelo/preprocessing/__init__.py index 863f64a9..1ba31917 100644 --- a/scvelo/preprocessing/__init__.py +++ b/scvelo/preprocessing/__init__.py @@ -1,4 +1,27 @@ -from .utils import show_proportions, cleanup, filter_genes, filter_genes_dispersion -from .utils import normalize_per_cell, filter_and_normalize, log1p, recipe_velocity -from .neighbors import pca, neighbors, remove_duplicate_cells from .moments import moments +from .neighbors import neighbors, pca, remove_duplicate_cells +from .utils import ( + cleanup, + filter_and_normalize, + filter_genes, + filter_genes_dispersion, + log1p, + normalize_per_cell, + recipe_velocity, + show_proportions, +) + +__all__ = [ + "cleanup", + "filter_and_normalize", + "filter_genes", + "filter_genes_dispersion", + "log1p", + "moments", + "neighbors", + "normalize_per_cell", + "pca", + "recipe_velocity", + "remove_duplicate_cells", + "show_proportions", +] diff --git a/scvelo/preprocessing/moments.py b/scvelo/preprocessing/moments.py index 312ad9ed..494eef17 100644 --- a/scvelo/preprocessing/moments.py +++ b/scvelo/preprocessing/moments.py @@ -1,10 +1,10 @@ -from .. import settings -from .. 
import logging as logg -from .utils import not_yet_normalized, normalize_per_cell -from .neighbors import neighbors, get_connectivities, get_n_neighs, verify_neighbors - -from scipy.sparse import csr_matrix, issparse import numpy as np +from scipy.sparse import csr_matrix, issparse + +from scvelo import logging as logg +from scvelo import settings +from .neighbors import get_connectivities, get_n_neighs, neighbors, verify_neighbors +from .utils import normalize_per_cell, not_yet_normalized def moments( @@ -48,12 +48,12 @@ def moments( Returns ------- - Returns or updates `adata` with the attributes Ms: `.layers` dense matrix with first order moments of spliced counts. Mu: `.layers` dense matrix with first order moments of unspliced counts. """ + adata = data.copy() if copy else data layers = [layer for layer in {"spliced", "unspliced"} if layer in adata.layers] @@ -114,6 +114,7 @@ def second_order_moments(adata, adjusted=False): Mss: Second order moments for spliced abundances Mus: Second order moments for spliced with unspliced abundances """ + if "neighbors" not in adata.uns: raise ValueError( "You need to run `pp.neighbors` first to compute a neighborhood graph." @@ -143,6 +144,7 @@ def second_order_moments_u(adata): ------- Muu: Second order moments for unspliced abundances """ + if "neighbors" not in adata.uns: raise ValueError( "You need to run `pp.neighbors` first to compute a neighborhood graph." @@ -186,10 +188,12 @@ def get_moments( Whether to compute centered (=variance) or uncentered second order moments. mode: `'connectivities'` or `'distances'` (default: `'connectivities'`) Distance metric to use for moment computation. + Returns ------- Mx: first or second order moments """ + if "neighbors" not in adata.uns: raise ValueError( "You need to run `pp.neighbors` first to compute a neighborhood graph." diff --git a/scvelo/preprocessing/neighbors.py b/scvelo/preprocessing/neighbors.py index a0063e31..b7f9e850 100644 --- a/scvelo/preprocessing/neighbors.py +++ b/scvelo/preprocessing/neighbors.py @@ -1,13 +1,17 @@ import warnings +from collections import Counter + import numpy as np +import pandas as pd +from scipy.sparse import coo_matrix, issparse + from anndata import AnnData from scanpy import Neighbors from scanpy.preprocessing import pca -from scipy.sparse import issparse, coo_matrix -from .utils import get_initial_size -from .. import logging as logg -from .. import settings +from scvelo import logging as logg +from scvelo import settings +from scvelo.core import get_initial_size def neighbors( @@ -71,16 +75,16 @@ def neighbors( Number of threads to be used (for runtime). copy Return a copy instead of writing to adata. + Returns ------- - Depending on `copy`, updates or returns `adata` with the following: - connectivities : sparse matrix (`.uns['neighbors']`, dtype `float32`) - Weighted adjacency matrix of the neighborhood graph of data + connectivities : `.obsp` + Sparse weighted adjacency matrix of the neighborhood graph of data points. Weights should be interpreted as connectivities. - distances : sparse matrix (`.uns['neighbors']`, dtype `float32`) - Instead of decaying weights, this stores distances for each pair of - neighbors. + distances : `.obsp` + Sparse matrix of distances for each pair of neighbors. 
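In other words, after the preprocessing step the graph lives in `.obsp` (a sketch, assuming the bundled pancreas dataset):

```python
import scvelo as scv

adata = scv.datasets.pancreas()
scv.pp.neighbors(adata)

conn = adata.obsp["connectivities"]  # sparse weighted adjacency matrix
dist = adata.obsp["distances"]       # sparse neighbor-distance matrix
```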
""" + adata = adata.copy() if copy else adata if use_rep is None: @@ -179,7 +183,7 @@ def neighbors( adata.obsp["connectivities"] = neighbors.connectivities adata.uns["neighbors"]["connectivities_key"] = "connectivities" adata.uns["neighbors"]["distances_key"] = "distances" - except: + except Exception: adata.uns["neighbors"]["distances"] = neighbors.distances adata.uns["neighbors"]["connectivities"] = neighbors.connectivities @@ -427,15 +431,15 @@ def compute_connectivities_umap( def get_duplicate_cells(data): if isinstance(data, AnnData): X = data.X - l = list(np.sum(np.abs(data.obsm["X_pca"]), 1) + get_initial_size(data)) + lst = list(np.sum(np.abs(data.obsm["X_pca"]), 1) + get_initial_size(data)) else: X = data - l = list(np.sum(X, 1).A1 if issparse(X) else np.sum(X, 1)) + lst = list(np.sum(X, 1).A1 if issparse(X) else np.sum(X, 1)) - l_set = set(l) idx_dup = [] - if len(l_set) < len(l): - idx_dup = np.array([i for i, x in enumerate(l) if l.count(x) > 1]) + if len(set(lst)) < len(lst): + vals = [val for val, count in Counter(lst).items() if count > 1] + idx_dup = np.where(pd.Series(lst).isin(vals))[0] X_new = np.array(X[idx_dup].A if issparse(X) else X[idx_dup]) sorted_idx = np.lexsort(X_new.T) @@ -443,7 +447,7 @@ def get_duplicate_cells(data): row_mask = np.invert(np.append([True], np.any(np.diff(sorted_data, axis=0), 1))) idx = sorted_idx[row_mask] - idx_dup = idx_dup[idx] + idx_dup = np.array(idx_dup)[idx] return idx_dup @@ -456,4 +460,5 @@ def remove_duplicate_cells(adata): mask[idx_duplicates] = 0 logg.info("Removed", len(idx_duplicates), "duplicate cells.") adata._inplace_subset_obs(mask) - neighbors(adata) + if "neighbors" in adata.uns.keys(): + neighbors(adata) diff --git a/scvelo/preprocessing/utils.py b/scvelo/preprocessing/utils.py index ca6ff3e0..fca3db11 100644 --- a/scvelo/preprocessing/utils.py +++ b/scvelo/preprocessing/utils.py @@ -1,172 +1,124 @@ +import warnings + import numpy as np import pandas as pd from scipy.sparse import issparse from sklearn.utils import sparsefuncs + from anndata import AnnData -import warnings -from .. import logging as logg +from scvelo import logging as logg +from scvelo.core import cleanup as _cleanup +from scvelo.core import get_initial_size as _get_initial_size +from scvelo.core import get_size as _get_size +from scvelo.core import set_initial_size as _set_initial_size +from scvelo.core import show_proportions as _show_proportions +from scvelo.core import sum +from scvelo.core._anndata import verify_dtypes as _verify_dtypes def sum_obs(A): """summation over axis 0 (obs) equivalent to np.sum(A, 0)""" - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - return A.sum(0).A1 if issparse(A) else np.sum(A, axis=0) + + warnings.warn( + "`sum_obs` is deprecated since scVelo v0.2.4 and will be removed in a future " + "version. Please use `sum(A, axis=0)` from `scvelo/core/` instead.", + DeprecationWarning, + stacklevel=2, + ) + + return sum(A, axis=0) def sum_var(A): """summation over axis 1 (var) equivalent to np.sum(A, 1)""" - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - return A.sum(1).A1 if issparse(A) else np.sum(A, axis=1) + warnings.warn( + "`sum_var` is deprecated since scVelo v0.2.4 and will be removed in a future " + "version. 
Please use `sum(A, axis=1)` from `scvelo/core/` instead.", + DeprecationWarning, + stacklevel=2, + ) -def show_proportions(adata, layers=None, use_raw=True): - """Proportions of spliced/unspliced abundances + return sum(A, axis=1) - Arguments - --------- - adata: :class:`~anndata.AnnData` - Annotated data matrix. - Returns - ------- - Prints the fractions of abundances. - """ - if layers is None: - layers = ["spliced", "unspliced", "ambigious"] - layers_keys = [key for key in layers if key in adata.layers.keys()] - counts_layers = [sum_var(adata.layers[key]) for key in layers_keys] - if use_raw: - size_key, obs = "initial_size_", adata.obs - counts_layers = [ - obs[size_key + l] if size_key + l in obs.keys() else c - for l, c in zip(layers_keys, counts_layers) - ] - - counts_per_cell_sum = np.sum(counts_layers, 0) - counts_per_cell_sum += counts_per_cell_sum == 0 - - mean_abundances = [ - np.mean(counts_per_cell / counts_per_cell_sum) - for counts_per_cell in counts_layers - ] +def show_proportions(adata, layers=None, use_raw=True): + warnings.warn( + "`scvelo.preprocessing.show_proportions` is deprecated since scVelo v0.2.4 " + "and will be removed in a future version. Please use " + "`scvelo.core.show_proportions` instead.", + DeprecationWarning, + stacklevel=2, + ) - print(f"Abundance of {layers_keys}: {np.round(mean_abundances, 2)}") + _show_proportions(adata=adata, layers=layers, use_raw=use_raw) def verify_dtypes(adata): - try: - _ = adata[:, 0] - except: - uns = adata.uns - adata.uns = {} - try: - _ = adata[:, 0] - logg.warn( - "Safely deleted unstructured annotations (adata.uns), \n" - "as these do not comply with permissible anndata datatypes." - ) - except: - logg.warn( - "The data might be corrupted. Please verify all annotation datatypes." - ) - adata.uns = uns - - -def cleanup(data, clean="layers", keep=None, copy=False): - """Deletes attributes not needed. - - Arguments - --------- - data: :class:`~anndata.AnnData` - Annotated data matrix. - clean: `str` or list of `str` (default: `layers`) - Which attributes to consider for freeing memory. - keep: `str` or list of `str` (default: None) - Which attributes to keep. - copy: `bool` (default: `False`) - Return a copy instead of writing to adata. - - Returns - ------- - Returns or updates `adata` with selection of attributes kept. - """ - adata = data.copy() if copy else data - verify_dtypes(adata) - - keep = list([keep] if isinstance(keep, str) else {} if keep is None else keep) - keep.extend(["spliced", "unspliced", "Ms", "Mu", "clusters", "neighbors"]) + warnings.warn( + "`scvelo.preprocessing.utils.verify_dtypes` is deprecated since scVelo v0.2.4 " + "and will be removed in a future version. Please use " + "`scvelo.core._anndata.verify_dtypes` instead.", + DeprecationWarning, + stacklevel=2, + ) - ann_dict = { - "obs": adata.obs_keys(), - "var": adata.var_keys(), - "uns": adata.uns_keys(), - "layers": list(adata.layers.keys()), - } + return _verify_dtypes(adata=adata) - if "all" not in clean: - ann_dict = {ann: values for (ann, values) in ann_dict.items() if ann in clean} - for (ann, values) in ann_dict.items(): - for value in values: - if value not in keep: - del getattr(adata, ann)[value] +def cleanup(data, clean="layers", keep=None, copy=False): + warnings.warn( + "`scvelo.preprocessing.cleanup` is deprecated since scVelo v0.2.4 and will be " + "removed in a future version. 
Please use `scvelo.core.cleanup` instead.", + DeprecationWarning, + stacklevel=2, + ) - return adata if copy else None + return _cleanup(data=data, clean=clean, keep=keep, copy=copy) def get_size(adata, layer=None): - X = adata.X if layer is None else adata.layers[layer] - return sum_var(X) + warnings.warn( + "`scvelo.preprocessing.utils.get_size` is deprecated since scVelo v0.2.4 and " + "will be removed in a future version. Please use `scvelo.core.get_size` " + "instead.", + DeprecationWarning, + stacklevel=2, + ) + + return _get_size(adata=adata, layer=layer) def set_initial_size(adata, layers=None): - if layers is None: - layers = ["spliced", "unspliced"] - verify_dtypes(adata) - layers = [ - layer - for layer in layers - if layer in adata.layers.keys() - and f"initial_size_{layer}" not in adata.obs.keys() - ] - for layer in layers: - adata.obs[f"initial_size_{layer}"] = get_size(adata, layer) - if "initial_size" not in adata.obs.keys(): - adata.obs["initial_size"] = get_size(adata) + warnings.warn( + "`scvelo.preprocessing.utils.set_initial_size` is deprecated since scVelo " + "v0.2.4 and will be removed in a future version. Please use " + "`scvelo.core.set_initial_size` instead.", + DeprecationWarning, + stacklevel=2, + ) + + return _set_initial_size(adata=adata, layers=layers) def get_initial_size(adata, layer=None, by_total_size=None): - if by_total_size: - sizes = [ - adata.obs[f"initial_size_{layer}"] - for layer in {"spliced", "unspliced"} - if f"initial_size_{layer}" in adata.obs.keys() - ] - return np.sum(sizes, axis=0) - elif layer in adata.layers.keys(): - return ( - np.array(adata.obs[f"initial_size_{layer}"]) - if f"initial_size_{layer}" in adata.obs.keys() - else get_size(adata, layer) - ) - elif layer is None or layer == "X": - return ( - np.array(adata.obs["initial_size"]) - if "initial_size" in adata.obs.keys() - else get_size(adata) - ) - else: - return None + warnings.warn( + "`scvelo.preprocessing.get_initial_size` is deprecated since scVelo v0.2.4 and " + "will be removed in a future version. Please use " + "`scvelo.core.get_initial_size` instead.", + DeprecationWarning, + stacklevel=2, + ) + + return _get_initial_size(adata=adata, layer=layer, by_total_size=by_total_size) def _filter(X, min_counts=None, min_cells=None, max_counts=None, max_cells=None): counts = ( - sum_obs(X) + sum(X, axis=0) if (min_counts is not None or max_counts is not None) - else sum_obs(X > 0) + else sum(X > 0, axis=0) ) lb = ( min_counts @@ -241,10 +193,11 @@ def filter_genes( ------- Filters the object and adds `n_counts` to `adata.var`. """ + adata = data.copy() if copy else data # set initial cell sizes before filtering - set_initial_size(adata) + _set_initial_size(adata) layers = [ layer for layer in ["spliced", "unspliced"] if layer in adata.layers.keys() @@ -429,8 +382,9 @@ def filter_genes_dispersion( If an AnnData `adata` is passed, returns or updates `adata` depending on \ `copy`. 
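All of the shims above follow the same pattern: emit a `DeprecationWarning`, then delegate to `scvelo.core`. Migrating user code is mechanical (a sketch; the toy sparse matrix is only for illustration):

```python
import numpy as np
from scipy.sparse import random as sparse_random

from scvelo.core import sum as core_sum  # replaces sum_obs / sum_var

A = sparse_random(5, 3, density=0.5, format="csr")

counts_per_gene = core_sum(A, axis=0)  # was: sum_obs(A)
counts_per_cell = core_sum(A, axis=1)  # was: sum_var(A)

assert np.allclose(counts_per_gene, np.asarray(A.sum(axis=0)).ravel())
```

Aliasing the import avoids shadowing Python's built-in `sum`.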
It filters the `adata` and adds the annotations """ + adata = data.copy() if copy else data - set_initial_size(adata) + _set_initial_size(adata) mean, var = materialize_as_ndarray(get_mean_var(adata.X)) @@ -558,13 +512,13 @@ def csr_vcorrcoef(X, y): def counts_per_cell_quantile(X, max_proportion_per_cell=0.05, counts_per_cell=None): if counts_per_cell is None: - counts_per_cell = sum_var(X) + counts_per_cell = sum(X, axis=1) gene_subset = np.all( X <= counts_per_cell[:, None] * max_proportion_per_cell, axis=0 ) if issparse(X): gene_subset = gene_subset.A1 - return sum_var(X[:, gene_subset]) + return sum(X[:, gene_subset], axis=1) def not_yet_normalized(X): @@ -621,6 +575,7 @@ def normalize_per_cell( ------- Returns or updates `adata` with normalized counts. """ + adata = data.copy() if copy else data if layers is None: layers = ["spliced", "unspliced"] @@ -633,7 +588,7 @@ def normalize_per_cell( if isinstance(counts_per_cell, str): if counts_per_cell not in adata.obs.keys(): - set_initial_size(adata, layers) + _set_initial_size(adata, layers) counts_per_cell = ( adata.obs[counts_per_cell].values if counts_per_cell in adata.obs.keys() @@ -648,9 +603,9 @@ def normalize_per_cell( counts = ( counts_per_cell if counts_per_cell is not None - else get_initial_size(adata, layer) + else _get_initial_size(adata, layer) if use_initial_size - else get_size(adata, layer) + else _get_size(adata, layer) ) if max_proportion_per_cell is not None and ( 0 < max_proportion_per_cell < 1 @@ -681,7 +636,7 @@ def normalize_per_cell( adata.var["gene_count_corr"] = np.round( csr_vcorrcoef(X.T, np.ravel((X > 0).sum(1))), 4 ) - except: + except Exception: pass else: logg.warn( @@ -689,7 +644,7 @@ def normalize_per_cell( "To enforce normalization, set `enforce=True`." ) - adata.obs["n_counts" if key_n_counts is None else key_n_counts] = get_size(adata) + adata.obs["n_counts" if key_n_counts is None else key_n_counts] = _get_size(adata) if len(modified_layers) > 0: logg.info("Normalized count data:", f"{', '.join(modified_layers)}.") @@ -705,10 +660,12 @@ def log1p(data, copy=False): Annotated data matrix. copy: `bool` (default: `False`) Return a copy of `adata` instead of updating it. + Returns ------- Returns or updates `adata` depending on `copy`. """ + adata = data.copy() if copy else data X = ( (adata.X.data if issparse(adata.X) else adata.X) @@ -793,6 +750,7 @@ def filter_and_normalize( ------- Returns or updates `adata` depending on `copy`. """ + adata = data.copy() if copy else data if "spliced" not in adata.layers.keys() or "unspliced" not in adata.layers.keys(): diff --git a/scvelo/read_load.py b/scvelo/read_load.py index 53b615e1..a883321e 100644 --- a/scvelo/read_load.py +++ b/scvelo/read_load.py @@ -1,15 +1,16 @@ -from . 
import logging as logg -from .preprocessing.utils import set_initial_size +import os +import warnings +from pathlib import Path +from urllib.request import urlretrieve -import os, re import numpy as np import pandas as pd -from pandas.api.types import is_categorical_dtype -from urllib.request import urlretrieve -from pathlib import Path -from scipy.sparse import issparse -from anndata import AnnData -from scanpy import read, read_loom + +from scvelo.core import clean_obs_names as _clean_obs_names +from scvelo.core import get_df as _get_df +from scvelo.core import merge as _merge +from scvelo.core._anndata import obs_df as _obs_df +from scvelo.core._anndata import var_df as _var_df def load(filename, backup_url=None, header="infer", index_col="infer", **kwargs): @@ -50,7 +51,8 @@ def load(filename, backup_url=None, header="infer", index_col="infer", **kwargs) else: raise ValueError( f"'{filename}' does not end on a valid extension.\n" - f"Please, provide one of the available extensions.\n{numpy_ext | pandas_ext}\n" + "Please provide one of the available extensions.\n" + f"{numpy_ext | pandas_ext}\n" ) @@ -58,194 +60,50 @@ def load(filename, backup_url=None, header="infer", index_col="infer", **kwargs) def clean_obs_names(data, base="[AGTCBDHKMNRSVWY]", ID_length=12, copy=False): - """Clean up the obs_names. - - For example an obs_name 'sample1_AGTCdate' is changed to 'AGTC' of the sample - 'sample1_date'. The sample name is then saved in obs['sample_batch']. - The genetic codes are identified according to according to - https://www.neb.com/tools-and-resources/usage-guidelines/the-genetic-code. - - Arguments - --------- - adata: :class:`~anndata.AnnData` - Annotated data matrix. - base: `str` (default: `[AGTCBDHKMNRSVWY]`) - Genetic code letters to be identified. - ID_length: `int` (default: 12) - Length of the Genetic Codes in the samples. - copy: `bool` (default: `False`) - Return a copy instead of writing to adata.
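The docstring removed above documents the barcode extraction that now lives in `scvelo.core.clean_obs_names`. A minimal migration sketch on hypothetical toy data (names and shapes are illustrative only; no trimming occurs here since the barcodes already have the default `ID_length` of 12):

```python
import numpy as np
from anndata import AnnData

from scvelo.core import clean_obs_names  # new canonical location

adata = AnnData(np.zeros((2, 3)))
adata.obs_names = ["sample1_AGTCAGTCAGTC", "sample2_TTTCAGTCAGTC"]

clean_obs_names(adata)  # reduces obs_names to the 12-letter genetic codes
print(adata.obs_names.tolist())            # e.g. ['AGTCAGTCAGTC', 'TTTCAGTCAGTC']
print(adata.obs["sample_batch"].tolist())  # e.g. ['sample1_', 'sample2_']
```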
- - Returns - ------- - Returns or updates `adata` with the attributes - obs_names: list - updated names of the observations - sample_batch: `.obs` - names of the identified sample batches - """ - - def get_base_list(name, base): - base_list = base - while re.search(base_list + base, name) is not None: - base_list += base - if len(base_list) == 0: - raise ValueError("Encountered an invalid ID in obs_names: ", name) - return base_list - - adata = data.copy() if copy else data - - names = adata.obs_names - base_list = get_base_list(names[0], base) - - if len(np.unique([len(name) for name in adata.obs_names])) == 1: - start, end = re.search(base_list, names[0]).span() - newIDs = [name[start:end] for name in names] - start, end = 0, len(newIDs[0]) - for i in range(end - ID_length): - if np.any([ID[i] not in base for ID in newIDs]): - start += 1 - if np.any([ID[::-1][i] not in base for ID in newIDs]): - end -= 1 - - newIDs = [ID[start:end] for ID in newIDs] - prefixes = [names[i].replace(newIDs[i], "") for i in range(len(names))] - else: - prefixes, newIDs = [], [] - for name in names: - match = re.search(base_list, name) - newID = ( - re.search(get_base_list(name, base), name).group() - if match is None - else match.group() - ) - newIDs.append(newID) - prefixes.append(name.replace(newID, "")) - - adata.obs_names = newIDs - if len(prefixes[0]) > 0 and len(np.unique(prefixes)) > 1: - adata.obs["sample_batch"] = ( - pd.Categorical(prefixes) - if len(np.unique(prefixes)) < adata.n_obs - else prefixes - ) + warnings.warn( + "`scvelo.read_load.clean_obs_names` is deprecated since scVelo v0.2.4 and will " + "be removed in a future version. Please use `scvelo.core.clean_obs_names` " + "instead.", + DeprecationWarning, + stacklevel=2, + ) - adata.obs_names_make_unique() - return adata if copy else None + return _clean_obs_names(data=data, base=base, ID_length=ID_length, copy=copy) def merge(adata, ldata, copy=True): - """Merges two annotated data matrices. - - Arguments - --------- - adata: :class:`~anndata.AnnData` - Annotated data matrix (reference data set). - ldata: :class:`~anndata.AnnData` - Annotated data matrix (to be merged into adata). - - Returns - ------- - Returns a :class:`~anndata.AnnData` object - """ - adata.var_names_make_unique() - ldata.var_names_make_unique() - - if ( - "spliced" in ldata.layers.keys() - and "initial_size_spliced" not in ldata.obs.keys() - ): - set_initial_size(ldata) - elif ( - "spliced" in adata.layers.keys() - and "initial_size_spliced" not in adata.obs.keys() - ): - set_initial_size(adata) - - common_obs = pd.unique(adata.obs_names.intersection(ldata.obs_names)) - common_vars = pd.unique(adata.var_names.intersection(ldata.var_names)) - - if len(common_obs) == 0: - clean_obs_names(adata) - clean_obs_names(ldata) - common_obs = adata.obs_names.intersection(ldata.obs_names) - - if copy: - _adata = adata[common_obs].copy() - _ldata = ldata[common_obs].copy() - else: - adata._inplace_subset_obs(common_obs) - _adata, _ldata = adata, ldata[common_obs].copy() - - _adata.var_names_make_unique() - _ldata.var_names_make_unique() - - same_vars = len(_adata.var_names) == len(_ldata.var_names) and np.all( - _adata.var_names == _ldata.var_names + warnings.warn( + "`scvelo.read_load.merge` is deprecated since scVelo v0.2.4 and will be " + "removed in a future version. 
Please use `scvelo.core.merge` instead.", + DeprecationWarning, + stacklevel=2, ) - join_vars = len(common_vars) > 0 - - if join_vars and not same_vars: - _adata._inplace_subset_var(common_vars) - _ldata._inplace_subset_var(common_vars) - - for attr in _ldata.obs.keys(): - if attr not in _adata.obs.keys(): - _adata.obs[attr] = _ldata.obs[attr] - for attr in _ldata.obsm.keys(): - if attr not in _adata.obsm.keys(): - _adata.obsm[attr] = _ldata.obsm[attr] - for attr in _ldata.uns.keys(): - if attr not in _adata.uns.keys(): - _adata.uns[attr] = _ldata.uns[attr] - if join_vars: - for attr in _ldata.layers.keys(): - if attr not in _adata.layers.keys(): - _adata.layers[attr] = _ldata.layers[attr] - - if _adata.shape[1] == _ldata.shape[1]: - same_vars = len(_adata.var_names) == len(_ldata.var_names) and np.all( - _adata.var_names == _ldata.var_names - ) - if same_vars: - for attr in _ldata.var.keys(): - if attr not in _adata.var.keys(): - _adata.var[attr] = _ldata.var[attr] - for attr in _ldata.varm.keys(): - if attr not in _adata.varm.keys(): - _adata.varm[attr] = _ldata.varm[attr] - else: - raise ValueError("Variable names are not identical.") - return _adata if copy else None + return _merge(adata=adata, ldata=ldata, copy=copy) def obs_df(adata, keys, layer=None): - lookup_keys = [k for k in keys if k in adata.var_names] - if len(lookup_keys) < len(keys): - logg.warn( - f"Keys {[k for k in keys if k not in adata.var_names]} " - f"were not found in `adata.var_names`." - ) + warnings.warn( + "`scvelo.read_load.obs_df` is deprecated since scVelo v0.2.4 and will be " + "removed in a future version. Please use `scvelo.core._anndata.obs_df` " + "instead.", + DeprecationWarning, + stacklevel=2, + ) - df = pd.DataFrame(index=adata.obs_names) - for l in lookup_keys: - df[l] = adata.obs_vector(l, layer=layer) - return df + return _obs_df(adata=adata, keys=keys, layer=layer) def var_df(adata, keys, layer=None): - lookup_keys = [k for k in keys if k in adata.obs_names] - if len(lookup_keys) < len(keys): - logg.warn( - f"Keys {[k for k in keys if k not in adata.obs_names]} " - f"were not found in `adata.obs_names`." - ) + warnings.warn( + "`scvelo.read_load.var_df` is deprecated since scVelo v0.2.4 and will be " + "removed in a future version. Please use `scvelo.core._anndata.var_df` " + "instead.", + DeprecationWarning, + stacklevel=2, + ) - df = pd.DataFrame(index=adata.var_names) - for l in lookup_keys: - df[l] = adata.var_vector(l, layer=layer) - return df + return _var_df(adata=adata, keys=keys, layer=layer) def get_df( @@ -258,161 +116,23 @@ def get_df( dropna="all", precision=None, ): - """Get dataframe for a specified adata key. - - Return values for specified key - (in obs, var, obsm, varm, obsp, varp, uns, or layers) as a dataframe. - - Arguments - ------ - adata - AnnData object or a numpy array to get values from. - keys - Keys from `.var_names`, `.obs_names`, `.var`, `.obs`, - `.obsm`, `.varm`, `.obsp`, `.varp`, `.uns`, or `.layers`. - layer - Layer of `adata` to use as expression values. - index - List to set as index. - columns - List to set as columns names. - sort_values - Wether to sort values by first column (sort_values=True) or a specified column. - dropna - Drop columns/rows that contain NaNs in all ('all') or in any entry ('any'). - precision - Set precision for pandas dataframe. - - Returns - ------- - A dataframe.
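Like the other shims in this file, `get_df` now just forwards to `scvelo.core.get_df`, so the semantics described in the removed docstring are unchanged. A small usage sketch against the new location (toy data, illustrative only):

```python
import numpy as np
from anndata import AnnData

from scvelo.core import get_df  # replaces scvelo.read_load.get_df

adata = AnnData(np.random.rand(4, 2))
adata.obs["clusters"] = ["a", "a", "b", "b"]

# the key is looked up across obs/var/obsm/varm/uns/layers and returned as a DataFrame
df = get_df(adata, keys="clusters")
```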
- """ - if precision is not None: - pd.set_option("precision", precision) - - if isinstance(data, AnnData): - keys, keys_split = ( - keys.split("*") if isinstance(keys, str) and "*" in keys else (keys, None) - ) - keys, key_add = ( - keys.split("/") if isinstance(keys, str) and "/" in keys else (keys, None) - ) - keys = [keys] if isinstance(keys, str) else keys - key = keys[0] - - s_keys = ["obs", "var", "obsm", "varm", "uns", "layers"] - d_keys = [ - data.obs.keys(), - data.var.keys(), - data.obsm.keys(), - data.varm.keys(), - data.uns.keys(), - data.layers.keys(), - ] - - if hasattr(data, "obsp") and hasattr(data, "varp"): - s_keys.extend(["obsp", "varp"]) - d_keys.extend([data.obsp.keys(), data.varp.keys()]) - - if keys is None: - df = data.to_df() - elif key in data.var_names: - df = obs_df(data, keys, layer=layer) - elif key in data.obs_names: - df = var_df(data, keys, layer=layer) - else: - if keys_split is not None: - keys = [ - k - for k in list(data.obs.keys()) + list(data.var.keys()) - if key in k and keys_split in k - ] - key = keys[0] - s_key = [s for (s, d_key) in zip(s_keys, d_keys) if key in d_key] - if len(s_key) == 0: - raise ValueError(f"'{key}' not found in any of {', '.join(s_keys)}.") - if len(s_key) > 1: - logg.warn(f"'{key}' found multiple times in {', '.join(s_key)}.") - - s_key = s_key[-1] - df = getattr(data, s_key)[keys if len(keys) > 1 else key] - if key_add is not None: - df = df[key_add] - if index is None: - index = ( - data.var_names - if s_key == "varm" - else data.obs_names - if s_key in {"obsm", "layers"} - else None - ) - if index is None and s_key == "uns" and hasattr(df, "shape"): - key_cats = np.array( - [ - key - for key in data.obs.keys() - if is_categorical_dtype(data.obs[key]) - ] - ) - num_cats = [ - len(data.obs[key].cat.categories) == df.shape[0] - for key in key_cats - ] - if np.sum(num_cats) == 1: - index = data.obs[key_cats[num_cats][0]].cat.categories - if ( - columns is None - and len(df.shape) > 1 - and df.shape[0] == df.shape[1] - ): - columns = index - elif isinstance(index, str) and index in data.obs.keys(): - index = pd.Categorical(data.obs[index]).categories - if columns is None and s_key == "layers": - columns = data.var_names - elif isinstance(columns, str) and columns in data.obs.keys(): - columns = pd.Categorical(data.obs[columns]).categories - elif isinstance(data, pd.DataFrame): - if isinstance(keys, str) and "*" in keys: - keys, keys_split = keys.split("*") - keys = [k for k in data.columns if keys in k and keys_split in k] - df = data[keys] if keys is not None else data - else: - df = data - - if issparse(df): - df = np.array(df.A) - if columns is None and hasattr(df, "names"): - columns = df.names - - df = pd.DataFrame(df, index=index, columns=columns) - - if dropna: - df.replace("", np.nan, inplace=True) - how = dropna if isinstance(dropna, str) else "any" if dropna is True else "all" - df.dropna(how=how, axis=0, inplace=True) - df.dropna(how=how, axis=1, inplace=True) - - if sort_values: - sort_by = ( - sort_values - if isinstance(sort_values, str) and sort_values in df.columns - else df.columns[0] - ) - df = df.sort_values(by=sort_by, ascending=False) - - if hasattr(data, "var_names"): - if df.index[0] in data.var_names: - df.var_names = df.index - elif df.columns[0] in data.var_names: - df.var_names = df.columns - if hasattr(data, "obs_names"): - if df.index[0] in data.obs_names: - df.obs_names = df.index - elif df.columns[0] in data.obs_names: - df.obs_names = df.columns + warnings.warn( + "`scvelo.read_load.get_df` is 
deprecated since scVelo v0.2.4 and will be " + "removed in a future version. Please use `scvelo.core.get_df` instead.", + DeprecationWarning, + stacklevel=2, + ) - return df + return _get_df( + data=data, + keys=keys, + layer=layer, + index=index, + columns=columns, + sort_values=sort_values, + dropna=dropna, + precision=precision, + ) DataFrame = get_df @@ -435,6 +155,7 @@ def load_biomart(): df2.index = df2.pop("ensembl") df = pd.concat([df, df2]) + df = df.drop_duplicates() return df @@ -443,7 +164,7 @@ def convert_to_gene_names(ensembl_names=None): df = load_biomart() if ensembl_names is not None: if isinstance(ensembl_names, str): - ensembl_names = ensembl_names + ensembl_names = [ensembl_names] valid_names = [name for name in ensembl_names if name in df.index] if len(valid_names) > 0: df = df.loc[valid_names] diff --git a/scvelo/settings.py b/scvelo/settings.py index e56621f4..28097145 100644 --- a/scvelo/settings.py +++ b/scvelo/settings.py @@ -1,3 +1,11 @@ +import builtins +import warnings + +from cycler import cycler +from packaging.version import parse + +from matplotlib import cbook, cm, colors, rcParams + """Settings """ @@ -80,10 +88,6 @@ # Functions # -------------------------------------------------------------------------------- -from matplotlib import rcParams, cm, colors, cbook -from cycler import cycler -import warnings - warnings.filterwarnings("ignore", category=cbook.mplDeprecation) @@ -292,15 +296,8 @@ def set_figure_params( Only concerns the notebook/IPython environment; see `IPython.core.display.set_matplotlib_formats` for more details. """ - try: - import IPython - - if isinstance(ipython_format, str): - ipython_format = [ipython_format] - IPython.display.set_matplotlib_formats(*ipython_format) - except: - pass - + if ipython_format is not None: + _set_ipython(ipython_format) global _rcParams_style _rcParams_style = style global _vector_friendly @@ -332,6 +329,29 @@ def set_rcParams_defaults(): rcParams.update(rcParamsDefault) +def _set_ipython(ipython_format="png2x"): + if getattr(builtins, "__IPYTHON__", None): + try: + import IPython + + if isinstance(ipython_format, str): + ipython_format = [ipython_format] + if parse(IPython.__version__) < parse("7.23"): + IPython.display.set_matplotlib_formats(*ipython_format) + else: + from matplotlib_inline.backend_inline import set_matplotlib_formats + + set_matplotlib_formats(*ipython_format) + except ImportError: + pass + + +def _set_start_time(): + from time import time + + return time() + + # ------------------------------------------------------------------------------ # Private global variables & functions # ------------------------------------------------------------------------------ @@ -343,13 +363,6 @@ def set_rcParams_defaults(): _low_resolution_warning = True """Print warning when saving a figure with low resolution.""" - -def _set_start_time(): - from time import time - - return time() - - _start = _set_start_time() """Time when the settings module is first imported.""" diff --git a/scvelo/tl.py b/scvelo/tl.py index 1c69ddff..5d71ecb5 100644 --- a/scvelo/tl.py +++ b/scvelo/tl.py @@ -1 +1 @@ -from scvelo.tools import * +from scvelo.tools import * # noqa diff --git a/scvelo/tools/__init__.py b/scvelo/tools/__init__.py index 21d7091a..f46a862a 100644 --- a/scvelo/tools/__init__.py +++ b/scvelo/tools/__init__.py @@ -1,14 +1,51 @@ -from .velocity import velocity, velocity_genes -from .velocity_graph import velocity_graph +from scanpy.tools import diffmap, dpt, louvain, tsne, umap + +from .dynamical_model import ( 
+ align_dynamics, + differential_kinetic_test, + DynamicsRecovery, + latent_time, + rank_dynamical_genes, + recover_dynamics, + recover_latent_time, +) +from .paga import paga +from .rank_velocity_genes import rank_velocity_genes, velocity_clusters +from .score_genes_cell_cycle import score_genes_cell_cycle +from .terminal_states import eigs, terminal_states from .transition_matrix import transition_matrix -from .velocity_embedding import velocity_embedding +from .velocity import velocity, velocity_genes from .velocity_confidence import velocity_confidence, velocity_confidence_transition -from .terminal_states import eigs, terminal_states -from .rank_velocity_genes import velocity_clusters, rank_velocity_genes +from .velocity_embedding import velocity_embedding +from .velocity_graph import velocity_graph from .velocity_pseudotime import velocity_map, velocity_pseudotime -from .dynamical_model import DynamicsRecovery, recover_dynamics, align_dynamics -from .dynamical_model import recover_latent_time, latent_time -from .dynamical_model import differential_kinetic_test, rank_dynamical_genes -from scanpy.tools import tsne, umap, diffmap, dpt, louvain -from .score_genes_cell_cycle import score_genes_cell_cycle -from .paga import paga + +__all__ = [ + "align_dynamics", + "differential_kinetic_test", + "diffmap", + "dpt", + "DynamicsRecovery", + "eigs", + "latent_time", + "louvain", + "paga", + "rank_dynamical_genes", + "rank_velocity_genes", + "recover_dynamics", + "recover_latent_time", + "score_genes_cell_cycle", + "terminal_states", + "transition_matrix", + "tsne", + "umap", + "velocity", + "velocity_clusters", + "velocity_confidence", + "velocity_confidence_transition", + "velocity_embedding", + "velocity_genes", + "velocity_graph", + "velocity_map", + "velocity_pseudotime", +] diff --git a/scvelo/tools/dynamical_model.py b/scvelo/tools/dynamical_model.py index b44aa7b7..cf24fe06 100644 --- a/scvelo/tools/dynamical_model.py +++ b/scvelo/tools/dynamical_model.py @@ -1,22 +1,18 @@ -from .. import settings -from .. 
import logging as logg -from ..preprocessing.moments import get_connectivities -from .utils import make_unique_list, test_bimodality -from .dynamical_model_utils import BaseDynamics, linreg, convolve, tau_inv, unspliced - -from typing import Any, Union, Callable, Optional, Sequence -from threading import Thread -from multiprocessing import Manager - -from joblib import Parallel, delayed -from scipy.sparse import issparse, spmatrix - import os + import numpy as np import pandas as pd +from scipy.optimize import minimize + import matplotlib.pyplot as pl from matplotlib import rcParams -from scipy.optimize import minimize + +from scvelo import logging as logg +from scvelo import settings +from scvelo.core import get_n_jobs, parallelize +from scvelo.preprocessing.moments import get_connectivities +from .dynamical_model_utils import BaseDynamics, convolve, linreg, tau_inv, unspliced +from .utils import make_unique_list, test_bimodality class DynamicsRecovery(BaseDynamics): @@ -67,11 +63,11 @@ def initialize(self): # initialize switching from u quantiles and alpha from s quantiles try: - tstat_u, pval_u, means_u = test_bimodality(u_w, kde=True) - tstat_s, pval_s, means_s = test_bimodality(s_w, kde=True) - except: + _, pval_u, means_u = test_bimodality(u_w, kde=True) + _, pval_s, means_s = test_bimodality(s_w, kde=True) + except Exception: logg.warn("skipping bimodality check for", self.gene) - tstat_u, tstat_s, pval_u, pval_s = 0, 0, 1, 1 + _, _, pval_u, pval_s = 0, 0, 1, 1 means_u, means_s = [0, 0], [0, 0] self.pval_steady = max(pval_u, pval_s) @@ -376,8 +372,10 @@ def recover_dynamics( --------- data: :class:`~anndata.AnnData` Annotated data matrix. - var_names: `str`, list of `str` (default: `'velocity_genes`) - Names of variables/genes to use for the fitting. + var_names: `str`, list of `str` (default: `'velocity_genes'`) + Names of variables/genes to use for the fitting. If `var_names='velocity_genes'` + but there is no column `'velocity_genes'` in `adata.var`, velocity genes are + estimated using the steady state model. n_top_genes: `int` or `None` (default: `None`) Number of top velocity genes to use for the dynamical model. max_iter:`int` (default: `10`) @@ -415,12 +413,27 @@ def recover_dynamics( n_jobs: `int` or `None` (default: `None`) Number of parallel jobs. backend: `str` (default: "loky") - Backend used for multiprocessing. See :class:`joblib.Parallel` for valid options. + Backend used for multiprocessing. See :class:`joblib.Parallel` for valid + options. 
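With `get_n_jobs` and `parallelize` now coming from `scvelo.core`, the per-gene fits are distributed as documented above. A usage sketch following the standard scVelo workflow (dataset choice and parameter values are illustrative; the pancreas example dataset downloads on first use):

```python
import scvelo as scv

adata = scv.datasets.pancreas()
scv.pp.filter_and_normalize(adata, min_shared_counts=20, n_top_genes=2000)
scv.pp.moments(adata, n_pcs=30, n_neighbors=30)

# chunk the genes and fit them on 8 parallel jobs via joblib's loky backend
scv.tl.recover_dynamics(adata, n_jobs=8, backend="loky")
```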
Returns ------- - Returns or updates `adata` - """ + fit_alpha: `.var` + inferred transcription rates + fit_beta: `.var` + inferred splicing rates + fit_gamma: `.var` + inferred degradation rates + fit_t_: `.var` + inferred switching time points + fit_scaling: `.var` + internal variance scaling factor for un/spliced counts + fit_likelihood: `.var` + likelihood of model fit + fit_alignment_scaling: `.var` + scaling factor to align gene-wise latent times to a universal latent time + """ # noqa E501 + adata = data.copy() if copy else data n_jobs = get_n_jobs(n_jobs=n_jobs) @@ -485,7 +498,7 @@ def recover_dynamics( conn = get_connectivities(adata) if fit_connected_states else None - res = _parallelize( + res = parallelize( _fit_recovery, var_names, n_jobs, @@ -549,14 +562,17 @@ def recover_dynamics( if L: # is False if only one invalid / irrecoverable gene was given in var_names cur_len = adata.varm["loss"].shape[1] if "loss" in adata.varm.keys() else 2 - max_len = max(np.max([len(l) for l in L]), cur_len) if L else cur_len + max_len = max(np.max([len(loss) for loss in L]), cur_len) if L else cur_len loss = np.ones((adata.n_vars, max_len)) * np.nan if "loss" in adata.varm.keys(): loss[:, :cur_len] = adata.varm["loss"] loss[idx] = np.vstack( - [np.concatenate([l, np.ones(max_len - len(l)) * np.nan]) for l in L] + [ + np.concatenate([loss, np.ones(max_len - len(loss)) * np.nan]) + for loss in L + ] ) adata.varm["loss"] = loss @@ -617,9 +633,9 @@ def align_dynamics( Whether to remove outliers. copy: `bool` (default: `False`) Return a copy instead of writing to `adata`. - Returns + + Returns ------- - Returns or updates `adata` with the attributes alpha, beta, gamma, t_, alignment_scaling: `.var` aligned parameters fit_t, fit_tau, fit_tau_: `.layer` @@ -748,17 +764,18 @@ def latent_time( If not set, a overall transcriptional timescale of 20 hours is used as prior. copy: `bool` (default: `False`) Return a copy instead of writing to `adata`. - Returns + + Returns ------- - Returns or updates `adata` with the attributes latent_time: `.obs` latent time from learned dynamics for each cell - """ + """ # noqa E501 + adata = data.copy() if copy else data - from .utils import vcorrcoef, scale - from .dynamical_model_utils import root_time, compute_shared_time + from .dynamical_model_utils import compute_shared_time, root_time from .terminal_states import terminal_states + from .utils import scale, vcorrcoef from .velocity_graph import velocity_graph from .velocity_pseudotime import velocity_pseudotime @@ -793,7 +810,7 @@ def latent_time( idx_roots[pd.isnull(idx_roots)] = 0 if np.any([isinstance(ix, str) for ix in idx_roots]): idx_roots = np.array([isinstance(ix, str) for ix in idx_roots], dtype=int) - idx_roots = idx_roots.astype(np.float) > 1 - 1e-3 + idx_roots = idx_roots.astype(float) > 1 - 1e-3 if np.sum(idx_roots) > 0: roots = roots[idx_roots] else: @@ -811,7 +828,7 @@ def latent_time( idx_fates[pd.isnull(idx_fates)] = 0 if np.any([isinstance(ix, str) for ix in idx_fates]): idx_fates = np.array([isinstance(ix, str) for ix in idx_fates], dtype=int) - idx_fates = idx_fates.astype(np.float) > 1 - 1e-3 + idx_fates = idx_fates.astype(float) > 1 - 1e-3 if np.sum(idx_fates) > 0: fates = fates[idx_fates] else: @@ -919,8 +936,12 @@ def differential_kinetic_test( Returns ------- - Returns or updates `adata` - """ + fit_pvals_kinetics: `.varm` + P-values of competing kinetic for each group and gene + fit_diff_kinetics: `.var` + Groups that have differential kinetics for each gene. 
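The `Returns` entries above make the test's output discoverable in `adata.var`/`adata.varm`. Continuing the sketch from `recover_dynamics`, a typical call restricts the test to the best-fitted genes (the cutoff of 100 genes is illustrative):

```python
import scvelo as scv

# rank genes by the likelihood of the recovered dynamical model
top_genes = adata.var["fit_likelihood"].sort_values(ascending=False).index[:100]
scv.tl.differential_kinetic_test(adata, var_names=top_genes, groupby="clusters")

print(adata.var["fit_diff_kinetics"].head())  # groups with differential kinetics
```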
+ """ # noqa E501 + adata = data.copy() if copy else data if "Ms" not in adata.layers.keys() or "Mu" not in adata.layers.keys(): @@ -981,7 +1002,7 @@ def differential_kinetic_test( if "fit_diff_kinetics" in adata.var.keys(): diff_kinetics = np.array(adata.var["fit_diff_kinetics"]) else: - diff_kinetics = np.empty(adata.n_vars, dtype="|U16") + diff_kinetics = np.empty(adata.n_vars, dtype="object") idx = [] progress = logg.ProgressReporter(len(var_names)) @@ -1014,7 +1035,7 @@ def differential_kinetic_test( "added \n" f" '{add_key}_diff_kinetics', " f"clusters displaying differential kinetics (adata.var)\n" - f" '{add_key}_pval_kinetics', " + f" '{add_key}_pvals_kinetics', " f"p-values of differential kinetics (adata.var)" ) @@ -1043,11 +1064,11 @@ def rank_dynamical_genes(data, n_genes=100, groupby=None, copy=False): Returns ------- - Returns or updates `data` with the attributes rank_dynamical_genes : `.uns` Structured array to be indexed by group id storing the gene names. Ordered according to scores. """ + from .dynamical_model_utils import get_divergence adata = data.copy() if copy else data @@ -1152,166 +1173,5 @@ def _fit_recovery( return idx, dms -# -*- coding: utf-8 -*- -"""Module used to parallelize model fitting.""" - -_msg_shown = False - - -def get_n_jobs(n_jobs): - if n_jobs is None or (n_jobs < 0 and os.cpu_count() + 1 + n_jobs <= 0): - return 1 - elif n_jobs > os.cpu_count(): - return os.cpu_count() - elif n_jobs < 0: - return os.cpu_count() + 1 + n_jobs - else: - return n_jobs - - -def _parallelize( - callback: Callable[[Any], Any], - collection: Union[spmatrix, Sequence[Any]], - n_jobs: Optional[int] = None, - n_split: Optional[int] = None, - unit: str = "", - as_array: bool = True, - use_ixs: bool = False, - backend: str = "loky", - extractor: Optional[Callable[[Any], Any]] = None, - show_progress_bar: bool = True, -) -> Union[np.ndarray, Any]: - """ - Parallelize function call over a collection of elements. - - Parameters - ---------- - callback - Function to parallelize. - collection - Sequence of items which to chunkify. - n_jobs - Number of parallel jobs. - n_split - Split :paramref:`collection` into :paramref:`n_split` chunks. - If `None`, split into :paramref:`n_jobs` chunks. - unit - Unit of the progress bar. - as_array - Whether to convert the results not :class:`numpy.ndarray`. - use_ixs - Whether to pass indices to the callback. - backend - Which backend to use for multiprocessing. See :class:`joblib.Parallel` for valid options. - extractor - Function to apply to the result after all jobs have finished. - show_progress_bar - Whether to show a progress bar. - - Returns - ------- - :class:`numpy.ndarray` - Result depending on :paramref:`extractor` and :paramref:`as_array`. - """ - - if show_progress_bar: - try: - try: - from tqdm.notebook import tqdm - except ImportError: - from tqdm import tqdm_notebook as tqdm - import ipywidgets # noqa - except ImportError: - global _msg_shown - tqdm = None - - if not _msg_shown: - logg.warn( - "Unable to create progress bar. " - "Consider installing `tqdm` as `pip install tqdm` " - "and `ipywidgets` as `pip install ipywidgets`,\n" - "or disable the progress bar using `show_progress_bar=False`." 
- ) - _msg_shown = True - else: - tqdm = None - - def update(pbar, queue, n_total): - n_finished = 0 - while n_finished < n_total: - try: - res = queue.get() - except EOFError as e: - if not n_finished != n_total: - raise RuntimeError( - f"Finished only `{n_finished} out of `{n_total}` tasks.`" - ) from e - break - assert res in (None, (1, None), 1) # (None, 1) means only 1 job - if res == (1, None): - n_finished += 1 - if pbar is not None: - pbar.update() - elif res is None: - n_finished += 1 - elif pbar is not None: - pbar.update() - - if pbar is not None: - pbar.close() - - def wrapper(*args, **kwargs): - if pass_queue and show_progress_bar: - pbar = None if tqdm is None else tqdm(total=col_len, unit=unit) - queue = Manager().Queue() - thread = Thread(target=update, args=(pbar, queue, len(collections))) - thread.start() - else: - pbar, queue, thread = None, None, None - - res = Parallel(n_jobs=n_jobs, backend=backend)( - delayed(callback)( - *((i, cs) if use_ixs else (cs,)), - *args, - **kwargs, - queue=queue, - ) - for i, cs in enumerate(collections) - ) - - res = np.array(res) if as_array else res - if thread is not None: - thread.join() - - return res if extractor is None else extractor(res) - - col_len = collection.shape[0] if issparse(collection) else len(collection) - - if n_split is None: - n_split = get_n_jobs(n_jobs=n_jobs) - - if issparse(collection): - if n_split == collection.shape[0]: - collections = [collection[[ix], :] for ix in range(collection.shape[0])] - else: - step = collection.shape[0] // n_split - - ixs = [ - np.arange(i * step, min((i + 1) * step, collection.shape[0])) - for i in range(n_split) - ] - ixs[-1] = np.append( - ixs[-1], np.arange(ixs[-1][-1] + 1, collection.shape[0]) - ) - - collections = [collection[ix, :] for ix in filter(len, ixs)] - else: - collections = list(filter(len, np.array_split(collection, n_split))) - - pass_queue = not hasattr(callback, "py_func") # we'd be inside a numba function - - return wrapper - - def _flatten(iterable): return [i for it in iterable for i in it] diff --git a/scvelo/tools/dynamical_model_utils.py b/scvelo/tools/dynamical_model_utils.py index e1340a75..8d929d55 100644 --- a/scvelo/tools/dynamical_model_utils.py +++ b/scvelo/tools/dynamical_model_utils.py @@ -1,17 +1,19 @@ -from .. import logging as logg -from ..preprocessing.moments import get_connectivities -from .utils import make_dense, round +import warnings + +import numpy as np +import pandas as pd +from scipy.sparse import issparse +from scipy.stats.distributions import chi2, norm import matplotlib as mpl +import matplotlib.gridspec as gridspec import matplotlib.pyplot as pl from matplotlib import rcParams -import matplotlib.gridspec as gridspec -from scipy.stats.distributions import chi2, norm -from scipy.sparse import issparse -import warnings -import pandas as pd -import numpy as np +from scvelo import logging as logg +from scvelo.core import clipped_log, invert, SplicingDynamics +from scvelo.preprocessing.moments import get_connectivities +from .utils import make_dense, round exp = np.exp @@ -20,14 +22,28 @@ def log(x, eps=1e-6): # to avoid invalid values for log. - return np.log(np.clip(x, eps, 1 - eps)) + + warnings.warn( + "`log` is deprecated since scVelo v0.2.4 and will be removed in a " + "future version.
Please use `clipped_log(x, eps=1e-6)` from `scvelo/core/` " + "instead.", + DeprecationWarning, + stacklevel=2, + ) + + return clipped_log(x, lb=0, ub=1, eps=eps) def inv(x): - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - x_inv = 1 / x * (x != 0) - return x_inv + + warnings.warn( + "`inv` is deprecated since scVelo v0.2.4 and will be removed in a future " + "version. Please use `invert(x)` from `scvelo/core/` instead.", + DeprecationWarning, + stacklevel=2, + ) + + return invert(x) def normalize(X, axis=0, min_confidence=None): @@ -112,17 +128,23 @@ def unspliced(tau, u0, alpha, beta): def spliced(tau, s0, u0, alpha, beta, gamma): - c = (alpha - u0 * beta) * inv(gamma - beta) + c = (alpha - u0 * beta) * invert(gamma - beta) expu, exps = exp(-beta * tau), exp(-gamma * tau) return s0 * exps + alpha / gamma * (1 - exps) + c * (exps - expu) def mRNA(tau, u0, s0, alpha, beta, gamma): - expu, exps = exp(-beta * tau), exp(-gamma * tau) - expus = (alpha - u0 * beta) * inv(gamma - beta) * (exps - expu) - u = u0 * expu + alpha / beta * (1 - expu) - s = s0 * exps + alpha / gamma * (1 - exps) + expus - return u, s + + warnings.warn( + "`mRNA` is deprecated since scVelo v0.2.4 and will be removed in a future " + "version. Please use `SplicingDynamics` from `scvelo/core/` instead.", + DeprecationWarning, + stacklevel=2, + ) + + return SplicingDynamics( + alpha=alpha, beta=beta, gamma=gamma, initial_state=[u0, s0] + ).get_solution(tau, stacked=False) def adjust_increments(tau, tau_=None): @@ -166,13 +188,15 @@ def tau_inv(u, s=None, u0=None, s0=None, alpha=None, beta=None, gamma=None): any_invus = np.any(inv_us) and s is not None if any_invus: # tau_inv(u, s) - beta_ = beta * inv(gamma - beta) + beta_ = beta * invert(gamma - beta) xinf = alpha / gamma - beta_ * (alpha / beta) - tau = -1 / gamma * log((s - beta_ * u - xinf) / (s0 - beta_ * u0 - xinf)) + tau = ( + -1 / gamma * clipped_log((s - beta_ * u - xinf) / (s0 - beta_ * u0 - xinf)) + ) if any_invu: # tau_inv(u) uinf = alpha / beta - tau_u = -1 / beta * log((u - uinf) / (u0 - uinf)) + tau_u = -1 / beta * clipped_log((u - uinf) / (u0 - uinf)) tau = tau_u * inv_u + tau * inv_us if any_invus else tau_u return tau @@ -189,9 +213,10 @@ def assign_tau( num = np.clip(int(len(u) / 5), 200, 500) tpoints = np.linspace(0, t_, num=num) tpoints_ = np.linspace(0, t0, num=num)[1:] - - xt = np.vstack(mRNA(tpoints, 0, 0, alpha, beta, gamma)).T - xt_ = np.vstack(mRNA(tpoints_, u0_, s0_, 0, beta, gamma)).T + xt = SplicingDynamics(alpha=alpha, beta=beta, gamma=gamma).get_solution(tpoints) + xt_ = SplicingDynamics( + alpha=0, beta=beta, gamma=gamma, initial_state=[u0_, s0_] + ).get_solution(tpoints_) # assign time points (oth.
projection onto 'on' and 'off' curve) tau = tpoints[ @@ -253,7 +278,9 @@ def compute_divergence( """ # set tau, tau_ if u0_ is None or s0_ is None: - u0_, s0_ = mRNA(t_, 0, 0, alpha, beta, gamma) + u0_, s0_ = SplicingDynamics(alpha=alpha, beta=beta, gamma=gamma).get_solution( + t_, stacked=False + ) if tau is None or tau_ is None or t_ is None: tau, tau_, t_ = assign_tau( u, s, alpha, beta, gamma, t_, u0_, s0_, assignment_mode ) @@ -263,8 +290,12 @@ def compute_divergence( # adjust increments of tau, tau_ to avoid meaningless jumps if constraint_time_increments: - ut, st = mRNA(tau, 0, 0, alpha, beta, gamma) - ut_, st_ = mRNA(tau_, u0_, s0_, 0, beta, gamma) + ut, st = SplicingDynamics(alpha=alpha, beta=beta, gamma=gamma).get_solution( + tau, stacked=False + ) + ut_, st_ = SplicingDynamics( + alpha=0, beta=beta, gamma=gamma, initial_state=[u0_, s0_] + ).get_solution(tau_, stacked=False) distu, distu_ = (u - ut) / std_u, (u - ut_) / std_u dists, dists_ = (s - st) / std_s, (s - st_) / std_s @@ -288,8 +319,12 @@ def compute_divergence( tau_[off] = adjust_increments(tau_[off]) # compute induction/repression state distances - ut, st = mRNA(tau, 0, 0, alpha, beta, gamma) - ut_, st_ = mRNA(tau_, u0_, s0_, 0, beta, gamma) + ut, st = SplicingDynamics(alpha=alpha, beta=beta, gamma=gamma).get_solution( + tau, stacked=False + ) + ut_, st_ = SplicingDynamics( + alpha=0, beta=beta, gamma=gamma, initial_state=[u0_, s0_] + ).get_solution(tau_, stacked=False) if ut.ndim > 1 and ut.shape[1] == 1: ut = np.ravel(ut) @@ -542,7 +577,9 @@ def compute_divergence( o = (res < 2) * (t < t_) o_ = (res < 2) * (t >= t_) tau, alpha, u0, s0 = vectorize(t, t_, alpha, beta, gamma) - ut, st = mRNA(tau, u0, s0, alpha, beta, gamma) + ut, st = SplicingDynamics( + alpha=alpha, beta=beta, gamma=gamma, initial_state=[u0, s0] + ).get_solution(tau, stacked=False) ut_, st_ = ut, st ut = ut * o + ut_ * o_ @@ -578,7 +615,9 @@ def compute_divergence( o = (res < 2) * (t < t_) o_ = (res < 2) * (t >= t_) tau, alpha, u0, s0 = vectorize(t, t_, alpha, beta, gamma) - ut, st = mRNA(tau, u0, s0, alpha, beta, gamma) + ut, st = SplicingDynamics( + alpha=alpha, beta=beta, gamma=gamma, initial_state=[u0, s0] + ).get_solution(tau, stacked=False) ut_, st_ = ut, st alpha = alpha * o @@ -644,7 +683,9 @@ def curve_dists( num=None, ): if u0_ is None or s0_ is None: - u0_, s0_ = mRNA(t_, 0, 0, alpha, beta, gamma) + u0_, s0_ = SplicingDynamics(alpha=alpha, beta=beta, gamma=gamma).get_solution( + t_, stacked=False + ) x_obs = np.vstack([u, s]).T std_x = np.vstack([std_u / scaling, std_s]).T @@ -654,8 +695,12 @@ def curve_dists( tpoints = np.linspace(0, t_, num=num) tpoints_ = np.linspace(0, t0, num=num)[1:] - curve_t = np.vstack(mRNA(tpoints, 0, 0, alpha, beta, gamma)).T - curve_t_ = np.vstack(mRNA(tpoints_, u0_, s0_, 0, beta, gamma)).T + curve_t = SplicingDynamics(alpha=alpha, beta=beta, gamma=gamma).get_solution( + tpoints + ) + curve_t_ = SplicingDynamics( + alpha=0, beta=beta, gamma=gamma, initial_state=[u0_, s0_] + ).get_solution(tpoints_) # match each curve point to nearest observation dist, dist_ = np.zeros(len(curve_t)), np.zeros(len(curve_t_)) @@ -728,7 +773,7 @@ def __init__( self.recoverable = True try: self.initialize_weights() - except: + except Exception: self.recoverable = False logg.warn(f"Model for {self.gene} could not be instantiated.") @@ -796,7 +841,11 @@ def load_pars(self, adata, gene): self.pval_steady = adata.var["fit_pval_steady"][idx] self.alpha_ = 0 - self.u0_, self.s0_ = mRNA(self.t_, 0, 0, self.alpha, self.beta, self.gamma) + self.u0_, self.s0_ =
SplicingDynamics( + alpha=self.alpha, + beta=self.beta, + gamma=self.gamma, + ).get_solution(self.t_, stacked=False) self.pars = [self.alpha, self.beta, self.gamma, self.t_, self.scaling] self.pars = np.array(self.pars)[:, None] @@ -988,7 +1037,9 @@ def get_dists( ) tau, alpha, u0, s0 = vectorize(t, t_, alpha, beta, gamma) - ut, st = mRNA(tau, u0, s0, alpha, beta, gamma) + ut, st = SplicingDynamics( + alpha=alpha, beta=beta, gamma=gamma, initial_state=[u0, s0] + ).get_solution(tau, stacked=False) udiff = np.array(ut - u) / self.std_u * scaling sdiff = np.array(st - s) / self.std_s @@ -1072,9 +1123,13 @@ def get_curve_likelihood(self): kwargs = dict(std_u=self.std_u, std_s=self.std_s, scaling=scaling) dist, dist_ = curve_dists(u, s, alpha, beta, gamma, t_, **kwargs) - l = -0.5 / len(dist) * np.sum(dist) / varx - 0.5 * np.log(2 * np.pi * varx) - l_ = -0.5 / len(dist_) * np.sum(dist_) / varx - 0.5 * np.log(2 * np.pi * varx) - likelihood = np.exp(np.max([l, l_])) + log_likelihood = -0.5 / len(dist) * np.sum(dist) / varx - 0.5 * np.log( + 2 * np.pi * varx + ) + log_likelihood_ = -0.5 / len(dist_) * np.sum(dist_) / varx - 0.5 * np.log( + 2 * np.pi * varx + ) + likelihood = np.exp(np.max([log_likelihood, log_likelihood_])) return likelihood def get_variance(self, **kwargs): @@ -1113,7 +1168,7 @@ def plot_phase( show_assignments=None, **kwargs, ): - from ..plotting.scatter import scatter + from scvelo.plotting.scatter import scatter if np.all([x is None for x in [alpha, beta, gamma, scaling, t_]]): refit_time = False @@ -1164,7 +1219,7 @@ def plot_profile_contour( return_color_scale=False, **kwargs, ): - from ..plotting.utils import update_axes + from scvelo.plotting.utils import update_axes x_var = getattr(self, xkey) y_var = getattr(self, ykey) @@ -1175,14 +1230,13 @@ def plot_profile_contour( assignment_mode = self.assignment_mode self.assignment_mode = None - fp = lambda x, y: self.get_likelihood( - **{xkey: x, ykey: y}, refit_time=refit_time - ) - + # TODO: Check if list comprehension can be used zp = np.zeros((len(x), len(x))) for i, xi in enumerate(x): for j, yi in enumerate(y): - zp[i, j] = fp(xi, yi) + zp[i, j] = self.get_likelihood( + **{xkey: xi, ykey: yi}, refit_time=refit_time + ) log_zp = np.log1p(zp.T) if vmin is None: @@ -1238,7 +1292,7 @@ def plot_profile_hist( vmax=None, show=True, ): - from ..plotting.utils import update_axes + from scvelo.plotting.utils import update_axes x_var = getattr(self, xkey) x = np.linspace(-sight, sight, num=num) * x_var + x_var @@ -1246,10 +1300,10 @@ def plot_profile_hist( assignment_mode = self.assignment_mode self.assignment_mode = None - fp = lambda x: self.get_likelihood(**{xkey: x}, refit_time=True) + # TODO: Check if list comprehension can be used zp = np.zeros((len(x))) for i, xi in enumerate(x): - zp[i] = fp(xi) + zp[i] = self.get_likelihood(**{xkey: xi}, refit_time=True) log_zp = np.log1p(zp.T) if vmin is None: @@ -1373,8 +1427,7 @@ def plot_state_likelihoods( ax=None, **kwargs, ): - from ..plotting.utils import update_axes - from ..plotting.utils import rgb_custom_colormap + from scvelo.plotting.utils import rgb_custom_colormap, update_axes if color_map is None: color_map = rgb_custom_colormap( @@ -1589,6 +1642,7 @@ def get_pval_diff_kinetics(self, orth_beta=None, min_cells=10, **kwargs): ------- p-value """ + if ( "weights_cluster" in kwargs and np.sum(kwargs["weights_cluster"]) < min_cells diff --git a/scvelo/tools/dynamical_model_utils_deprecated.py b/scvelo/tools/dynamical_model_utils_deprecated.py deleted file mode 100644 index 
083d7354..00000000 --- a/scvelo/tools/dynamical_model_utils_deprecated.py +++ /dev/null @@ -1,1040 +0,0 @@ -# DEPRECATED - -from .. import settings -from .. import logging as logg -from ..preprocessing.moments import get_connectivities -from .utils import make_dense, make_unique_list, test_bimodality - -import warnings -import matplotlib.pyplot as pl -from matplotlib import rcParams - -import numpy as np - -exp = np.exp - - -def log(x, eps=1e-6): # to avoid invalid values for log. - return np.log(np.clip(x, eps, 1 - eps)) - - -def inv(x): - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - x_inv = 1 / x * (x != 0) - return x_inv - - -def unspliced(tau, u0, alpha, beta): - expu = exp(-beta * tau) - return u0 * expu + alpha / beta * (1 - expu) - - -def spliced(tau, s0, u0, alpha, beta, gamma): - c = (alpha - u0 * beta) * inv(gamma - beta) - expu, exps = exp(-beta * tau), exp(-gamma * tau) - return s0 * exps + alpha / gamma * (1 - exps) + c * (exps - expu) - - -def mRNA(tau, u0, s0, alpha, beta, gamma): - expu, exps = exp(-beta * tau), exp(-gamma * tau) - u = u0 * expu + alpha / beta * (1 - expu) - s = ( - s0 * exps - + alpha / gamma * (1 - exps) - + (alpha - u0 * beta) * inv(gamma - beta) * (exps - expu) - ) - return u, s - - -def vectorize(t, t_, alpha, beta, gamma=None, alpha_=0, u0=0, s0=0, sorted=False): - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - o = np.array(t < t_, dtype=int) - tau = t * o + (t - t_) * (1 - o) - - u0_ = unspliced(t_, u0, alpha, beta) - s0_ = spliced(t_, s0, u0, alpha, beta, gamma if gamma is not None else beta / 2) - - # vectorize u0, s0 and alpha - u0 = u0 * o + u0_ * (1 - o) - s0 = s0 * o + s0_ * (1 - o) - alpha = alpha * o + alpha_ * (1 - o) - - if sorted: - idx = np.argsort(t) - tau, alpha, u0, s0 = tau[idx], alpha[idx], u0[idx], s0[idx] - return tau, alpha, u0, s0 - - -def tau_inv(u, s=None, u0=None, s0=None, alpha=None, beta=None, gamma=None): - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - inv_u = (gamma >= beta) if gamma is not None else True - inv_us = np.invert(inv_u) - any_invu = np.any(inv_u) or s is None - any_invus = np.any(inv_us) and s is not None - - if any_invus: # tau_inv(u, s) - beta_ = beta * inv(gamma - beta) - xinf = alpha / gamma - beta_ * (alpha / beta) - tau = -1 / gamma * log((s - beta_ * u - xinf) / (s0 - beta_ * u0 - xinf)) - - if any_invu: # tau_inv(u) - uinf = alpha / beta - tau_u = -1 / beta * log((u - uinf) / (u0 - uinf)) - tau = tau_u * inv_u + tau * inv_us if any_invus else tau_u - return tau - - -def find_swichting_time(u, s, tau, o, alpha, beta, gamma, plot=False): - off, on = o == 0, o == 1 - t0_ = np.max(tau[on]) if on.sum() > 0 and np.max(tau[on]) > 0 else np.max(tau) - - if off.sum() > 0: - u_, s_, tau_ = u[off], s[off], tau[off] - - beta_ = beta * inv(gamma - beta) - ceta_ = alpha / gamma - beta_ * alpha / beta - - x = -ceta_ * exp(-gamma * tau_) - y = s_ - beta_ * u_ - - exp_t0_ = (y * x).sum() / (x ** 2).sum() - if -1 < exp_t0_ < 0: - t0_ = -1 / gamma * log(exp_t0_ + 1) - if plot: - pl.scatter(x, y) - return t0_ - - -def fit_alpha(u, s, tau, o, beta, gamma, fit_scaling=False): - off, on = o == 0, o == 1 - if on.sum() > 0 or off.sum() > 0 or tau[on].min() == 0 or tau[off].min() == 0: - alpha = None - else: - tau_on, tau_off = tau[on], tau[off] - - # 'on' state - expu, exps = exp(-beta * tau_on), exp(-gamma * tau_on) - - # 'off' state - t0_ = np.max(tau_on) - expu_, exps_ = exp(-beta * tau_off), exp(-gamma * tau_off) - expu0_, exps0_ = exp(-beta * t0_), 
exp(-gamma * t0_) - - # from unspliced dynamics - c_beta = 1 / beta * (1 - expu) - c_beta_ = 1 / beta * (1 - expu0_) * expu_ - - # from spliced dynamics - c_gamma = (1 - exps) / gamma + (exps - expu) * inv(gamma - beta) - c_gamma_ = ( - (1 - exps0_) / gamma + (exps0_ - expu0_) * inv(gamma - beta) - ) * exps_ - (1 - expu0_) * (exps_ - expu_) * inv(gamma - beta) - - # concatenating together - c = np.concatenate([c_beta, c_gamma, c_beta_, c_gamma_]).T - x = np.concatenate([u[on], s[on], u[off], s[off]]).T - - alpha = (c * x).sum() / (c ** 2).sum() - - if fit_scaling: # alternatively compute alpha and scaling simultaneously - c = np.concatenate([c_gamma, c_gamma_]).T - x = np.concatenate([s[on], s[off]]).T - alpha = (c * x).sum() / (c ** 2).sum() - - c = np.concatenate([c_beta, c_beta_]).T - x = np.concatenate([u[on], u[off]]).T - scaling = (c * x).sum() / (c ** 2).sum() / alpha # ~ alpha * z / alpha - return alpha, scaling - - return alpha - - -def fit_scaling(u, t, t_, alpha, beta): - tau, alpha, u0, _ = vectorize(t, t_, alpha, beta) - ut = unspliced(tau, u0, alpha, beta) - return (u * ut).sum() / (ut ** 2).sum() - - -def tau_s(s, s0, u0, alpha, beta, gamma, u=None, tau=None, eps=1e-2): - if tau is None: - tau = tau_inv(u, u0=u0, alpha=alpha, beta=beta) if u is not None else 1 - tau_prev, loss, n_iter, max_iter, mixed_states = 1e6, 1e6, 0, 10, np.any(alpha == 0) - b0 = (alpha - beta * u0) * inv(gamma - beta) - g0 = s0 - alpha / gamma + b0 - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - - while np.abs(tau - tau_prev).max() > eps and loss > eps and n_iter < max_iter: - tau_prev, n_iter = tau, n_iter + 1 - - expu, exps = b0 * exp(-beta * tau), g0 * exp(-gamma * tau) - f = exps - expu + alpha / gamma # >0 - ft = -gamma * exps + beta * expu # >0 if on else <0 - ftt = gamma ** 2 * exps - beta ** 2 * expu - - a, b, c = ftt / 2, ft, f - s - term = b ** 2 - 4 * a * c - update = (-b + np.sqrt(term)) / (2 * a) - if mixed_states: - update = np.nan_to_num(update) * (alpha > 0) + (-c / b) * (alpha <= 0) - tau = ( - np.nan_to_num(tau_prev + update) * (s != 0) - if np.any(term > 0) - else tau_prev / 10 - ) - loss = np.abs( - alpha / gamma + g0 * exp(-gamma * tau) - b0 * exp(-beta * tau) - s - ).max() - - return np.clip(tau, 0, None) - - -def assign_timepoints_projection( - u, s, alpha, beta, gamma, t0_=None, u0_=None, s0_=None, n_timepoints=300 -): - if t0_ is None: - t0_ = tau_inv(u=u0_, u0=0, alpha=alpha, beta=beta) - if u0_ is None or s0_ is None: - u0_, s0_ = ( - unspliced(t0_, 0, alpha, beta), - spliced(t0_, 0, 0, alpha, beta, gamma), - ) - - tpoints = np.linspace(0, t0_, num=n_timepoints) - tpoints_ = np.linspace( - 0, tau_inv(np.min(u[s > 0]), u0=u0_, alpha=0, beta=beta), num=n_timepoints - )[1:] - - xt = np.vstack( - [unspliced(tpoints, 0, alpha, beta), spliced(tpoints, 0, 0, alpha, beta, gamma)] - ).T - xt_ = np.vstack( - [unspliced(tpoints_, u0_, 0, beta), spliced(tpoints_, s0_, u0_, 0, beta, gamma)] - ).T - x_obs = np.vstack([u, s]).T - - # assign time points (oth. 
projection onto 'on' and 'off' curve) - tau, o, diff = np.zeros(len(u)), np.zeros(len(u), dtype=int), np.zeros(len(u)) - tau_alt, diff_alt = np.zeros(len(u)), np.zeros(len(u)) - for i, xi in enumerate(x_obs): - diffs, diffs_ = ( - np.linalg.norm((xt - xi), axis=1), - np.linalg.norm((xt_ - xi), axis=1), - ) - idx, idx_ = np.argmin(diffs), np.argmin(diffs_) - - o[i] = np.argmin([diffs_[idx_], diffs[idx]]) - tau[i] = [tpoints_[idx_], tpoints[idx]][o[i]] - diff[i] = [diffs_[idx_], diffs[idx]][o[i]] - - tau_alt[i] = [tpoints_[idx_], tpoints[idx]][1 - o[i]] - diff_alt[i] = [diffs_[idx_], diffs[idx]][1 - o[i]] - - t = tau * o + (t0_ + tau) * (1 - o) - - return t, tau, o - - -"""State-independent derivatives""" - - -def dtau(u, s, alpha, beta, gamma, u0, s0, du0=[0, 0, 0], ds0=[0, 0, 0, 0]): - a, b, g, gb, b0 = alpha, beta, gamma, gamma - beta, beta * inv(gamma - beta) - - cu = s - a / g - b0 * (u - a / b) - c0 = s0 - a / g - b0 * (u0 - a / b) - cu += cu == 0 - c0 += c0 == 0 - cu_, c0_ = 1 / cu, 1 / c0 - - dtau_a = b0 / g * (c0_ - cu_) + 1 / g * c0_ * (ds0[0] - b0 * du0[0]) - dtau_b = 1 / gb ** 2 * ((u - a / g) * cu_ - (u0 - a / g) * c0_) - - dtau_c = -a / g * (1 / g ** 2 - 1 / gb ** 2) * (cu_ - c0_) - b0 / g / gb * ( - u * cu_ - u0 * c0_ - ) # + 1/g**2 * np.log(cu/c0) - - return dtau_a, dtau_b, dtau_c - - -def du(tau, alpha, beta, u0=0, du0=[0, 0, 0], dtau=[0, 0, 0]): - # du0 is the derivative du0 / d(alpha, beta, tau) - expu, cb = exp(-beta * tau), alpha / beta - du_a = ( - du0[0] * expu + 1.0 / beta * (1 - expu) + (alpha - beta * u0) * dtau[0] * expu - ) - du_b = ( - du0[1] * expu - - cb / beta * (1 - expu) - + (cb - u0) * tau * expu - + (alpha - beta * u0) * dtau[1] * expu - ) - return du_a, du_b - - -def ds( - tau, alpha, beta, gamma, u0=0, s0=0, du0=[0, 0, 0], ds0=[0, 0, 0, 0], dtau=[0, 0, 0] -): - # ds0 is the derivative ds0 / d(alpha, beta, gamma, tau) - expu, exps = exp(-beta * tau), exp(-gamma * tau) - expus = exps - expu - - cbu = (alpha - beta * u0) * inv(gamma - beta) - ccu = (alpha - gamma * u0) * inv(gamma - beta) - ccs = alpha / gamma - s0 - cbu - - ds_a = ( - ds0[0] * exps - + 1.0 / gamma * (1 - exps) - + 1 * inv(gamma - beta) * (1 - beta * du0[0]) * expus - + (ccs * gamma * exps + cbu * beta * expu) * dtau[0] - ) - ds_b = ( - ds0[1] * exps - + cbu * tau * expu - + 1 * inv(gamma - beta) * (ccu - beta * du0[1]) * expus - + (ccs * gamma * exps + cbu * beta * expu) * dtau[1] - ) - ds_c = ( - ds0[2] * exps - + ccs * tau * exps - - alpha / gamma ** 2 * (1 - exps) - - cbu * inv(gamma - beta) * expus - + (ccs * gamma * exps + cbu * beta * expu) * dtau[2] - ) - - return ds_a, ds_b, ds_c - - -def derivatives( - u, s, t, t0_, alpha, beta, gamma, scaling=1, alpha_=0, u0=0, s0=0, weights=None -): - o = np.array(t < t0_, dtype=int) - - du0 = np.array(du(t0_, alpha, beta, u0))[:, None] * (1 - o)[None, :] - ds0 = np.array(ds(t0_, alpha, beta, gamma, u0, s0))[:, None] * (1 - o)[None, :] - - tau, alpha, u0, s0 = vectorize(t, t0_, alpha, beta, gamma, alpha_, u0, s0) - dt = np.array(dtau(u, s, alpha, beta, gamma, u0, s0, du0, ds0)) - - # state-dependent derivatives: - du_a, du_b = du(tau, alpha, beta, u0, du0, dt) - du_a, du_b = du_a * scaling, du_b * scaling - - ds_a, ds_b, ds_c = ds(tau, alpha, beta, gamma, u0, s0, du0, ds0, dt) - - # evaluate derivative of likelihood: - ut, st = mRNA(tau, u0, s0, alpha, beta, gamma) - - # udiff = np.array(ut * scaling - u) - udiff = np.array(ut - u / scaling) - sdiff = np.array(st - s) - - if weights is not None: - udiff = np.multiply(udiff, weights) - sdiff = 
np.multiply(sdiff, weights) - - dl_a = (du_a * (1 - o)).dot(udiff) + (ds_a * (1 - o)).dot(sdiff) - dl_a_ = (du_a * o).dot(udiff) + (ds_a * o).dot(sdiff) - - dl_b = du_b.dot(udiff) + ds_b.dot(sdiff) - dl_c = ds_c.dot(sdiff) - - dl_tau, dl_t0_ = None, None - return dl_a, dl_b, dl_c, dl_a_, dl_tau, dl_t0_ - - -class BaseDynamics: - def __init__(self, adata=None, u=None, s=None): - self.s, self.u = s, u - - zeros, zeros3 = np.zeros(adata.n_obs), np.zeros((3, 1)) - self.u0, self.s0, self.u0_, self.s0_, self.t_, self.scaling = ( - None, - None, - None, - None, - None, - None, - ) - self.t, self.tau, self.o, self.weights = zeros, zeros, zeros, zeros - - self.alpha, self.beta, self.gamma, self.alpha_, self.pars = ( - None, - None, - None, - None, - None, - ) - self.dpars, self.m_dpars, self.v_dpars, self.loss = zeros3, zeros3, zeros3, [] - - def uniform_weighting(self, n_regions=5, perc=95): # deprecated - from numpy import union1d as union - from numpy import intersect1d as intersect - - u, s = self.u, self.s - u_b = np.linspace(0, np.percentile(u, perc), n_regions) - s_b = np.linspace(0, np.percentile(s, perc), n_regions) - - regions, weights = {}, np.ones(len(u)) - for i in range(n_regions): - if i == 0: - region = intersect(np.where(u < u_b[i + 1]), np.where(s < s_b[i + 1])) - elif i < n_regions - 1: - lower_cut = union(np.where(u > u_b[i]), np.where(s > s_b[i])) - upper_cut = intersect( - np.where(u < u_b[i + 1]), np.where(s < s_b[i + 1]) - ) - region = intersect(lower_cut, upper_cut) - else: - region = union( - np.where(u > u_b[i]), np.where(s > s_b[i]) - ) # lower_cut for last region - regions[i] = region - if len(region) > 0: - weights[region] = n_regions / len(region) - # set weights accordingly such that each region has an equal overall contribution. 
- self.weights = weights * len(u) / np.sum(weights) - self.u_b, self.s_b = u_b, s_b - - def plot_regions(self): - u, s, ut, st = self.u, self.s, self.ut, self.st - u_b, s_b = self.u_b, self.s_b - - pl.figure(dpi=100) - pl.scatter(s, u, color="grey") - pl.xlim(0) - pl.ylim(0) - pl.xlabel("spliced") - pl.ylabel("unspliced") - - for i in range(len(s_b)): - pl.plot([s_b[i], s_b[i], 0], [0, u_b[i], u_b[i]]) - - def plot_derivatives(self): - u, s = self.u, self.s - alpha, beta, gamma = self.alpha, self.beta, self.gamma - t, tau, o, t_ = self.t, self.tau, self.o, self.t_ - - du0 = np.array(du(t_, alpha, beta))[:, None] * (1 - o)[None, :] - ds0 = np.array(ds(t_, alpha, beta, gamma))[:, None] * (1 - o)[None, :] - - tau, alpha, u0, s0 = vectorize(t, t_, alpha, beta, gamma) - dt = np.array(dtau(u, s, alpha, beta, gamma, u0, s0)) - - du_a, du_b = du(tau, alpha, beta, u0=u0, du0=du0, dtau=dt) - ds_a, ds_b, ds_c = ds( - tau, alpha, beta, gamma, u0=u0, s0=s0, du0=du0, ds0=ds0, dtau=dt - ) - - idx = np.argsort(t) - t = np.sort(t) - - pl.plot(t, du_a[idx], label=r"$\partial u / \partial\alpha$") - pl.plot(t, 0.2 * du_b[idx], label=r"$\partial u / \partial \beta$") - pl.plot(t, ds_a[idx], label=r"$\partial s / \partial \alpha$") - pl.plot(t, ds_b[idx], label=r"$\partial s / \partial \beta$") - pl.plot(t, 0.2 * ds_c[idx], label=r"$\partial s / \partial \gamma$") - - pl.legend() - pl.xlabel("t") - - -class DynamicsRecovery(BaseDynamics): - def __init__( - self, - adata=None, - gene=None, - u=None, - s=None, - use_raw=False, - load_pars=None, - fit_scaling=False, - fit_time=True, - fit_switching=True, - fit_steady_states=True, - fit_alpha=True, - fit_connected_states=True, - ): - super(DynamicsRecovery, self).__init__(adata.n_obs) - - _layers = adata[:, gene].layers - self.gene = gene - self.use_raw = use_raw = use_raw or "Ms" not in _layers.keys() - - # extract actual data - if u is None or s is None: - u = ( - make_dense(_layers["unspliced"]) - if use_raw - else make_dense(_layers["Mu"]) - ) - s = make_dense(_layers["spliced"]) if use_raw else make_dense(_layers["Ms"]) - self.s, self.u = s, u - - # set weights for fitting (exclude dropouts and extreme outliers) - nonzero = np.ravel(s > 0) & np.ravel(u > 0) - s_filter = np.ravel(s < np.percentile(s[nonzero], 98)) - u_filter = np.ravel(u < np.percentile(u[nonzero], 98)) - - self.weights = s_filter & u_filter & nonzero - self.fit_scaling = fit_scaling - self.fit_time = fit_time - self.fit_alpha = fit_alpha - self.fit_switching = fit_switching - self.fit_steady_states = fit_steady_states - self.connectivities = ( - get_connectivities(adata) - if fit_connected_states is True - else fit_connected_states - ) - - if load_pars and "fit_alpha" in adata.var.keys(): - self.load_pars(adata, gene) - else: - self.initialize() - - def initialize(self): - # set weights - u, s, w = self.u * 1.0, self.s * 1.0, self.weights - u_w, s_w, perc = u[w], s[w], 98 - - # initialize scaling - self.std_u, self.std_s = np.std(u_w), np.std(s_w) - scaling = ( - self.std_u / self.std_s - if isinstance(self.fit_scaling, bool) - else self.fit_scaling - ) - u, u_w = u / scaling, u_w / scaling - - # initialize beta and gamma from extreme quantiles of s - if True: - weights_s = s_w >= np.percentile(s_w, perc, axis=0) - else: - us_norm = s_w / np.clip(np.max(s_w, axis=0), 1e-3, None) + u_w / np.clip( - np.max(u_w, axis=0), 1e-3, None - ) - weights_s = us_norm >= np.percentile(us_norm, perc, axis=0) - - beta, gamma = 1, linreg(convolve(u_w, weights_s), convolve(s_w, weights_s)) - - u_inf, s_inf = 
u_w[weights_s].mean(), s_w[weights_s].mean() - u0_, s0_ = u_inf, s_inf - alpha = np.mean( - [s_inf * gamma, u_inf * beta] - ) # np.mean([s0_ * gamma, u0_ * beta]) - - # initialize switching from u quantiles and alpha from s quantiles - tstat_u, pval_u, means_u = test_bimodality(u_w, kde=True) - tstat_s, pval_s, means_s = test_bimodality(s_w, kde=True) - self.pval_steady = max(pval_u, pval_s) - self.u_steady = means_u[1] - self.s_steady = means_s[1] - - if self.pval_steady < 0.1: - u_inf = np.mean([u_inf, self.u_steady]) - s_inf = np.mean([s_inf, self.s_steady]) - alpha = s_inf * gamma - beta = alpha / u_inf - - weights_u = u_w >= np.percentile(u_w, perc, axis=0) - u0_, s0_ = u_w[weights_u].mean(), s_w[weights_u].mean() - - # alpha, beta, gamma = np.array([alpha, beta, gamma]) * scaling - t_ = tau_inv(u0_, s0_, 0, 0, alpha, beta, gamma) - - # update object with initialized vars - alpha_, u0, s0 = 0, 0, 0 - self.alpha, self.beta, self.gamma, self.alpha_, self.scaling = ( - alpha, - beta, - gamma, - alpha_, - scaling, - ) - self.u0, self.s0, self.u0_, self.s0_, self.t_ = u0, s0, u0_, s0_, t_ - self.pars = np.array([alpha, beta, gamma, self.t_, self.scaling])[:, None] - - # initialize time point assignment - self.t, self.tau, self.o = self.get_time_assignment() - self.loss = [self.get_loss()] - - self.update_scaling(sight=0.5) - self.update_scaling(sight=0.1) - - def load_pars(self, adata, gene): - idx = adata.var_names.get_loc(gene) if isinstance(gene, str) else gene - self.alpha = adata.var["fit_alpha"][idx] - self.beta = adata.var["fit_beta"][idx] - self.gamma = adata.var["fit_gamma"][idx] - self.scaling = adata.var["fit_scaling"][idx] - self.t_ = adata.var["fit_t_"][idx] - self.t = adata.layers["fit_t"][:, idx] - self.o = self.t < self.t_ - self.tau = self.t * self.o + (self.t - self.t_) * (1 - self.o) - self.pars = np.array( - [self.alpha, self.beta, self.gamma, self.t_, self.scaling] - )[:, None] - - self.u0, self.s0, self.alpha_ = 0, 0, 0 - self.u0_ = unspliced(self.t_, self.u0, self.alpha, self.beta) - self.s0_ = spliced(self.t_, self.u0, self.s0, self.alpha, self.beta, self.gamma) - - self.update_state_dependent() - - def fit( - self, - max_iter=100, - r=None, - method=None, - clip_loss=None, - assignment_mode=None, - min_loss=True, - ): - updated, idx_update = True, np.clip(int(max_iter / 10), 1, None) - - for i in range(max_iter): - self.update_vars(r=r, method=method, clip_loss=clip_loss) - if updated or (i % idx_update == 1) or i == max_iter - 1: - updated = self.update_state_dependent() - if i > 10 and (i % idx_update == 1): - loss_prev, loss = np.max(self.loss[-10:]), self.loss[-1] - if loss_prev - loss < loss_prev * 1e-3: - updated = self.shuffle_pars() - if not updated: - break - - if self.fit_switching: - self.update_switching() - if min_loss: - alpha, beta, gamma, t_, scaling = self.pars[:, np.argmin(self.loss)] - up = self.update_loss( - None, t_, alpha, beta, gamma, scaling, reassign_time=True - ) - self.t, self.tau, self.o = self.get_time_assignment( - assignment_mode=assignment_mode - ) - - def update_state_dependent(self): - updated = False - if self.fit_alpha: - updated = self.update_alpha() | updated - if self.fit_switching: - updated = self.update_switching() | updated - return updated - - def update_scaling(self, sight=0.5): # fit scaling and update if improved - z_vals = self.scaling + np.linspace(-1, 1, num=5) * self.scaling * sight - for z in z_vals: - u0_ = self.u0_ * self.scaling / z - beta = self.beta / self.scaling * z - self.update_loss( - scaling=z, 
beta=beta, u0_=u0_, s0_=self.s0_, reassign_time=True - ) - - def update_alpha(self): # fit alpha (generalized lin.reg), update if improved - updated = False - alpha = self.get_optimal_alpha() - gamma = self.gamma - - alpha_vals = alpha + np.linspace(-1, 1, num=5) * alpha / 30 - gamma_vals = gamma + np.linspace(-1, 1, num=4) * gamma / 30 - - for alpha in alpha_vals: - for gamma in gamma_vals: - updated = ( - self.update_loss(alpha=alpha, gamma=gamma, reassign_time=True) - | updated - ) - return updated - - def update_switching( - self, - ): # find optimal switching (generalized lin.reg) & assign timepoints/states (explicit) - updated = False - # t_ = self.t_ - t_ = self.get_optimal_switch() - t_vals = t_ + np.linspace(-1, 1, num=3) * t_ / 5 - for t_ in t_vals: - updated = self.update_loss(t_=t_, reassign_time=True) | updated - - if True: # self.pval_steady > .1: - z_vals = 1 + np.linspace(-1, 1, num=4) / 5 - for z in z_vals: - beta, gamma = self.beta * z, self.gamma * z - t, tau, o = self.get_time_assignment(beta=beta, gamma=gamma) - t_ = np.max(t * o) - if t_ > 0: - update = self.update_loss( - t_=np.max(t * o), beta=beta, gamma=gamma, reassign_time=True - ) - updated |= update - if update: - self.update_loss( - t_=self.get_optimal_switch(), reassign_time=True - ) - return updated - - def update_vars(self, r=None, method=None, clip_loss=None): - if r is None: - r = 1e-2 if method == "adam" else 1e-5 - if clip_loss is None: - clip_loss = method != "adam" - # if self.weights is None: - # self.uniform_weighting(n_regions=5, perc=95) - t, t_, alpha, beta, gamma, scaling = ( - self.t, - self.t_, - self.alpha, - self.beta, - self.gamma, - self.scaling, - ) - dalpha, dbeta, dgamma, dalpha_, dtau, dt_ = derivatives( - self.u, self.s, t, t_, alpha, beta, gamma, scaling - ) - - if method == "adam": - b1, b2, eps = 0.9, 0.999, 1e-8 - - # update 1st and 2nd order gradient moments - dpars = np.array([dalpha, dbeta, dgamma]) - m_dpars = b1 * self.m_dpars[:, -1] + (1 - b1) * dpars - v_dpars = b2 * self.v_dpars[:, -1] + (1 - b2) * dpars ** 2 - - self.dpars = np.c_[self.dpars, dpars] - self.m_dpars = np.c_[self.m_dpars, m_dpars] - self.v_dpars = np.c_[self.v_dpars, v_dpars] - - # correct for bias - t = len(self.m_dpars[0]) - m_dpars /= 1 - b1 ** t - v_dpars /= 1 - b2 ** t - - # Adam parameter update - # Parameters are restricted to be positive - n_alpha = alpha - r * m_dpars[0] / (np.sqrt(v_dpars[0]) + eps) - alpha = n_alpha if n_alpha > 0 else alpha - n_beta = beta - r * m_dpars[1] / (np.sqrt(v_dpars[1]) + eps) - beta = n_beta if n_beta > 0 else beta - n_gamma = gamma - r * m_dpars[2] / (np.sqrt(v_dpars[2]) + eps) - gamma = n_gamma if n_gamma > 0 else gamma - - else: - # Parameters are restricted to be positive - n_alpha = alpha - r * dalpha - alpha = n_alpha if n_alpha > 0 else alpha - n_beta = beta - r * dbeta - beta = n_beta if n_beta > 0 else beta - n_gamma = gamma - r * dgamma - gamma = n_gamma if n_gamma > 0 else gamma - - # tau -= r * dtau - # t_ -= r * dt_ - # t_ = np.max(self.tau * self.o) - # t = tau * self.o + (tau + t_) * (1 - self.o) - - updated_vars = self.update_loss( - alpha=alpha, - beta=beta, - gamma=gamma, - clip_loss=clip_loss, - reassign_time=False, - ) - - def update_loss( - self, - t=None, - t_=None, - alpha=None, - beta=None, - gamma=None, - scaling=None, - u0_=None, - s0_=None, - reassign_time=False, - clip_loss=True, - report=False, - ): - vals = [t_, alpha, beta, gamma, scaling] - vals_prev = [self.t_, self.alpha, self.beta, self.gamma, self.scaling] - vals_name = ["t_", 
"alpha", "beta", "gamma", "scaling"] - new_vals, new_vals_prev, new_vals_name = [], [], [] - loss_prev = self.loss[-1] if len(self.loss) > 0 else 1e6 - - for val, val_prev, val_name in zip(vals, vals_prev, vals_name): - if val is not None: - new_vals.append(val) - new_vals_prev.append(val_prev) - new_vals_name.append(val_name) - if t_ is None: - t_ = ( - tau_inv( - self.u0_ if u0_ is None else u0_, - self.s0_ if s0_ is None else s0_, - 0, - 0, - self.alpha if alpha is None else alpha, - self.beta if beta is None else beta, - self.gamma if gamma is None else gamma, - ) - if u0_ is not None - else self.t_ - ) - - t, t_, alpha, beta, gamma, scaling = self.get_vals( - t, t_, alpha, beta, gamma, scaling - ) - - if reassign_time: - # t_ = self.get_optimal_switch(alpha, beta, gamma) if t_ is None else t_ - t, tau, o = self.get_time_assignment(t_, alpha, beta, gamma, scaling) - - loss = self.get_loss(t, t_, alpha, beta, gamma, scaling) - perform_update = not clip_loss or loss < loss_prev - - if perform_update: - if ( - len(self.loss) > 0 and loss_prev - loss > loss_prev * 0.01 and report - ): # improvement by at least 1% - print( - "Update:", - " ".join(map(str, new_vals_name)), - " ".join(map(str, np.round(new_vals_prev, 2))), - "-->", - " ".join(map(str, np.round(new_vals, 2))), - ) - - print(" loss:", np.round(loss_prev, 2), "-->", np.round(loss, 2)) - - if "t_" in new_vals_name or reassign_time: - if reassign_time: - self.t = t - self.t_ = t_ - self.o = o = np.array(self.t < t_, dtype=bool) - self.tau = self.t * o + (self.t - t_) * (1 - o) - - if u0_ is not None: - self.u0_ = u0_ - self.s0_ = s0_ - - if "alpha" in new_vals_name: - self.alpha = alpha - if "beta" in new_vals_name: - self.beta = beta - if "gamma" in new_vals_name: - self.gamma = gamma - if "scaling" in new_vals_name: - self.scaling = scaling - # self.rescale_invariant() - - self.pars = np.c_[ - self.pars, - np.array([self.alpha, self.beta, self.gamma, self.t_, self.scaling])[ - :, None - ], - ] - self.loss.append(loss if perform_update else loss_prev) - - return perform_update - - def rescale_invariant(self, z=None): - z = self.scaling / self.std_u * self.std_s if z is None else z - self.alpha, self.beta, self.gamma = ( - np.array([self.alpha, self.beta, self.gamma]) * z - ) - self.t, self.tau, self.t_ = self.t / z, self.tau / z, self.t_ / z - - def shuffle_pars(self, alpha_sight=[-0.5, 0.5], gamma_sight=[-0.5, 0.5], num=5): - alpha_vals = ( - np.linspace(alpha_sight[0], alpha_sight[1], num=num) * self.alpha - + self.alpha - ) - gamma_vals = ( - np.linspace(gamma_sight[0], gamma_sight[1], num=num) * self.gamma - + self.gamma - ) - - x, y = alpha_vals, gamma_vals - f = lambda x, y: self.get_loss(alpha=x, gamma=y, reassign_time=self.fit_time) - z = np.zeros((len(x), len(x))) - - for i, xi in enumerate(x): - for j, yi in enumerate(y): - z[i, j] = f(xi, yi) - ix, iy = np.unravel_index(z.argmin(), z.shape) - return self.update_loss(alpha=x[ix], gamma=y[ix], reassign_time=self.fit_time) - - -def read_pars(adata, pars_names=["alpha", "beta", "gamma", "t_", "scaling"], key="fit"): - pars = [] - for name in pars_names: - pkey = f"{key}_{name}" - par = ( - adata.var[pkey].values - if pkey in adata.var.keys() - else np.zeros(adata.n_vars) * np.nan - ) - pars.append(par) - return pars - - -def write_pars( - adata, pars, pars_names=["alpha", "beta", "gamma", "t_", "scaling"], add_key="fit" -): - for i, name in enumerate(pars_names): - adata.var[f"{add_key}_{name}"] = pars[i] - - -def recover_dynamics_deprecated( - data, - 
var_names="velocity_genes", - max_iter=10, - learning_rate=None, - t_max=None, - use_raw=False, - fit_scaling=True, - fit_time=True, - fit_switching=True, - fit_steady_states=True, - fit_alpha=True, - fit_connected_states=True, - min_loss=True, - assignment_mode=None, - load_pars=None, - add_key="fit", - method="adam", - return_model=True, - plot_results=False, - copy=False, - **kwargs, -): - """Estimates velocities in a gene-specific manner - - Arguments - --------- - data: :class:`~anndata.AnnData` - Annotated data matrix. - - Returns - ------- - Returns or updates `adata` - """ - adata = data.copy() if copy else data - logg.info("recovering dynamics", r=True) - - if isinstance(var_names, str) and var_names not in adata.var_names: - var_names = ( - adata.var_names[adata.var[var_names] == True] - if "genes" in var_names and var_names in adata.var.keys() - else adata.var_names - if "all" in var_names or "genes" in var_names - else var_names - ) - var_names = make_unique_list(var_names, allow_array=True) - var_names = [name for name in var_names if name in adata.var_names] - if len(var_names) == 0: - raise ValueError("Variable name not found in var keys.") - - if fit_connected_states: - fit_connected_states = get_connectivities(adata) - - alpha, beta, gamma, t_, scaling = read_pars(adata) - idx = [] - L, P, T = ( - [], - [], - adata.layers["fit_t"] - if "fit_t" in adata.layers.keys() - else np.zeros(adata.shape) * np.nan, - ) - - progress = logg.ProgressReporter(len(var_names)) - for i, gene in enumerate(var_names): - dm = DynamicsRecovery( - adata, - gene, - use_raw=use_raw, - load_pars=load_pars, - fit_time=fit_time, - fit_alpha=fit_alpha, - fit_switching=fit_switching, - fit_scaling=fit_scaling, - fit_steady_states=fit_steady_states, - fit_connected_states=fit_connected_states, - ) - if max_iter > 1: - dm.fit( - max_iter, - learning_rate, - assignment_mode=assignment_mode, - min_loss=min_loss, - method=method, - **kwargs, - ) - - ix = adata.var_names.get_loc(gene) - idx.append(ix) - - alpha[ix], beta[ix], gamma[ix], t_[ix], scaling[ix] = dm.pars[:, -1] - T[:, ix] = dm.t - L.append(dm.loss) - if plot_results and i < 4: - P.append(dm.pars) - - progress.update() - progress.finish() - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - T_max = np.nanpercentile(T, 95, axis=0) - np.nanpercentile(T, 5, axis=0) - m = t_max / T_max if t_max is not None else np.ones(adata.n_vars) - alpha, beta, gamma, T, t_ = alpha / m, beta / m, gamma / m, T * m, t_ * m - - write_pars(adata, [alpha, beta, gamma, t_, scaling]) - adata.layers["fit_t"] = T - - cur_len = adata.varm["loss"].shape[1] if "loss" in adata.varm.keys() else 2 - max_len = max(np.max([len(l) for l in L]), cur_len) - loss = np.ones((adata.n_vars, max_len)) * np.nan - - if "loss" in adata.varm.keys(): - loss[:, :cur_len] = adata.varm["loss"] - - loss[idx] = np.vstack( - [np.concatenate([l, np.ones(max_len - len(l)) * np.nan]) for l in L] - ) - adata.varm["loss"] = loss - - logg.info(" finished", time=True, end=" " if settings.verbosity > 2 else "\n") - logg.hint( - "added \n" - f" '{add_key}_pars', fitted parameters for splicing dynamics (adata.var)" - ) - - if plot_results: # Plot Parameter Stats - n_rows, n_cols = len(var_names[:4]), 6 - figsize = [2 * n_cols, 1.5 * n_rows] # rcParams['figure.figsize'] - fontsize = rcParams["font.size"] - fig, axes = pl.subplots(nrows=n_rows, ncols=6, figsize=figsize) - pl.subplots_adjust(wspace=0.7, hspace=0.5) - for i, gene in enumerate(var_names[:4]): - P[i] *= np.array( - [1 / m[idx[i]], 
1 / m[idx[i]], 1 / m[idx[i]], m[idx[i]], 1] - )[:, None] - ax = axes[i] if n_rows > 1 else axes - for j, pij in enumerate(P[i]): - ax[j].plot(pij) - ax[len(P[i])].plot(L[i]) - if i == 0: - for j, name in enumerate( - ["alpha", "beta", "gamma", "t_", "scaling", "loss"] - ): - ax[j].set_title(name, fontsize=fontsize) - - return dm if return_model else adata if copy else None diff --git a/scvelo/tools/optimization.py b/scvelo/tools/optimization.py index b178fb97..86bcf72e 100644 --- a/scvelo/tools/optimization.py +++ b/scvelo/tools/optimization.py @@ -1,8 +1,11 @@ -from .utils import sum_obs, prod_sum_obs, make_dense +import warnings + +import numpy as np from scipy.optimize import minimize from scipy.sparse import csr_matrix, issparse -import numpy as np -import warnings + +from scvelo.core import prod_sum, sum +from .utils import make_dense def get_weight(x, y=None, perc=95): @@ -22,6 +25,14 @@ def get_weight(x, y=None, perc=95): def leastsq_NxN(x, y, fit_offset=False, perc=None, constraint_positive_offset=True): """Solves least squares X*b=Y for b.""" + + warnings.warn( + "`leastsq_NxN` is deprecated since scVelo v0.2.4 and will be removed in a " + "future version. Please use `LinearRegression` from `scvelo/core/` instead.", + DeprecationWarning, + stacklevel=2, + ) + if perc is not None: if not fit_offset and isinstance(perc, (list, tuple)): perc = perc[1] @@ -32,13 +43,13 @@ def leastsq_NxN(x, y, fit_offset=False, perc=None, constraint_positive_offset=Tr with warnings.catch_warnings(): warnings.simplefilter("ignore") - xx_ = prod_sum_obs(x, x) - xy_ = prod_sum_obs(x, y) + xx_ = prod_sum(x, x, axis=0) + xy_ = prod_sum(x, y, axis=0) if fit_offset: - n_obs = x.shape[0] if weights is None else sum_obs(weights) - x_ = sum_obs(x) / n_obs - y_ = sum_obs(y) / n_obs + n_obs = x.shape[0] if weights is None else sum(weights, axis=0) + x_ = sum(x, axis=0) / n_obs + y_ = sum(y, axis=0) / n_obs gamma = (xy_ / n_obs - x_ * y_) / (xx_ / n_obs - x_ ** 2) offset = y_ - gamma * x_ diff --git a/scvelo/tools/paga.py b/scvelo/tools/paga.py index 81e52530..dbb7edb0 100644 --- a/scvelo/tools/paga.py +++ b/scvelo/tools/paga.py @@ -1,16 +1,17 @@ -# This is adapted from https://github.com/theislab/paga -from .. import settings -from .. import logging as logg -from .utils import strings_to_categoricals -from .velocity_graph import vals_to_csr -from .velocity_pseudotime import velocity_pseudotime -from .rank_velocity_genes import velocity_clusters import numpy as np import pandas as pd from scipy.sparse import csr_matrix -from pandas.api.types import is_categorical + from scanpy.tools._paga import PAGA +# This is adapted from https://github.com/theislab/paga +from scvelo import logging as logg +from scvelo import settings +from .rank_velocity_genes import velocity_clusters +from .utils import strings_to_categoricals +from .velocity_graph import vals_to_csr +from .velocity_pseudotime import velocity_pseudotime + def get_igraph_from_adjacency(adjacency, directed=None): """Get igraph graph from adjacency matrix.""" @@ -25,7 +26,7 @@ def get_igraph_from_adjacency(adjacency, directed=None): g.add_edges(list(zip(sources, targets))) try: g.es["weight"] = weights - except: + except Exception: pass if g.vcount() != adjacency.shape[0]: logg.warn( @@ -55,7 +56,9 @@ def get_sparse_from_igraph(graph, weight_attr=None): def set_row_csr(csr, rows, value=0): """Set all nonzero elements to the given value. 
Useful to set to 0 mostly.""" for row in rows: - csr.data[csr.indptr[row] : csr.indptr[row + 1]] = value + start = csr.indptr[row] + end = csr.indptr[row + 1] + csr.data[start:end] = value if value == 0: csr.eliminate_zeros() @@ -218,27 +221,23 @@ def paga( copy : `bool`, optional (default: `False`) Copy `adata` before computation and return a copy. Otherwise, perform computation inplace and return `None`. + Returns ------- - **connectivities** : (adata.uns['connectivities']) + connectivities: `.uns` The full adjacency matrix of the abstracted graph, weights correspond to confidence in the connectivities of partitions. - **connectivities_tree** : (adata.uns['connectivities_tree']) + connectivities_tree: `.uns` The adjacency matrix of the tree-like subgraph that best explains the topology. - **transitions_confidence** : (adata.uns['transitions_confidence']) + transitions_confidence: `.uns` The adjacency matrix of the abstracted directed graph, weights correspond to confidence in the transitions between partitions. """ + if "neighbors" not in adata.uns: raise ValueError( "You need to run `pp.neighbors` first to compute a neighborhood graph." ) - try: - import igraph - except ImportError: - raise ImportError( - "To run paga, you need to install `pip install python-igraph`" - ) adata = adata.copy() if copy else adata strings_to_categoricals(adata) @@ -260,7 +259,7 @@ def paga( priors = [p for p in [use_time_prior, root_key, end_key] if p in adata.obs.keys()] logg.info( - f"running PAGA", + "running PAGA", f"using priors: {priors}" if len(priors) > 0 else "", r=True, ) diff --git a/scvelo/tools/rank_velocity_genes.py b/scvelo/tools/rank_velocity_genes.py index 26741e86..324fd9cd 100644 --- a/scvelo/tools/rank_velocity_genes.py +++ b/scvelo/tools/rank_velocity_genes.py @@ -1,11 +1,11 @@ -from .. import settings -from .. import logging as logg +import numpy as np +from scipy.sparse import issparse + +from scvelo import logging as logg +from scvelo import settings from .utils import strings_to_categoricals, vcorrcoef from .velocity_pseudotime import velocity_pseudotime -from scipy.sparse import issparse -import numpy as np - def get_mean_var(X, ignore_zeros=False, perc=None): data = X.data if issparse(X) else X @@ -102,10 +102,10 @@ def velocity_clusters( Returns ------- - Returns or updates `data` with the attributes velocity_clusters : `.obs` Clusters obtained from applying louvain modularity on velocity expression. - """ + """ # noqa E501 + adata = data.copy() if copy else data logg.info("computing velocity clusters", r=True) @@ -133,7 +133,7 @@ def velocity_clusters( if "fit_likelihood" in adata.var.keys() and min_likelihood is not None: tmp_filter &= adata.var["fit_likelihood"] > min_likelihood - from .. import AnnData + from anndata import AnnData vdata = AnnData(adata.layers[vkey][:, tmp_filter]) vdata.obs = adata.obs.copy() @@ -217,7 +217,9 @@ def rank_velocity_genes( .. code:: python scv.tl.rank_velocity_genes(adata, groupby='clusters') - scv.pl.scatter(adata, basis=adata.uns['rank_velocity_genes']['names']['Beta'][:3]) + scv.pl.scatter( + adata, basis=adata.uns['rank_velocity_genes']['names']['Beta'][:3] + ) pd.DataFrame(adata.uns['rank_velocity_genes']['names']).head() .. image:: https://user-images.githubusercontent.com/31883718/69626017-11c47980-1048-11ea-89f4-df3769df5ad5.png @@ -255,13 +257,13 @@ def rank_velocity_genes( Returns ------- - Returns or updates `data` with the attributes rank_velocity_genes : `.uns` Structured array to be indexed by group id storing the gene names. 
Ordered according to scores. velocity_score : `.var` Storing the score for each gene for each group. Ordered according to scores. - """ + """ # noqa E501 + adata = data.copy() if copy else data if groupby is None or groupby == "velocity_clusters": @@ -314,9 +316,9 @@ def rank_velocity_genes( tmp_filter &= dispersions > min_dispersion if "fit_likelihood" in adata.var.keys(): - l = adata.var["fit_likelihood"] + fit_likelihood = adata.var["fit_likelihood"] min_likelihood = 0.1 if min_likelihood is None else min_likelihood - tmp_filter &= l > min_likelihood + tmp_filter &= fit_likelihood > min_likelihood X = adata[:, tmp_filter].layers[vkey] groups, groups_masks = select_groups(adata, key=groupby) @@ -354,7 +356,7 @@ def rank_velocity_genes( all_names = rankings_gene_names.T.flatten() all_scores = rankings_gene_scores.T.flatten() - vscore = np.zeros(adata.n_vars, dtype=np.int) + vscore = np.zeros(adata.n_vars, dtype=int) for i, name in enumerate(adata.var_names): if name in all_names: vscore[i] = all_scores[np.where(name == all_names)[0][0]] diff --git a/scvelo/tools/run.py b/scvelo/tools/run.py index d4b3d8d0..b807b047 100644 --- a/scvelo/tools/run.py +++ b/scvelo/tools/run.py @@ -1,5 +1,5 @@ -from ..preprocessing import filter_and_normalize, moments -from . import velocity, velocity_graph, velocity_embedding +from scvelo.preprocessing import filter_and_normalize, moments +from . import velocity, velocity_embedding, velocity_graph def run_all( @@ -38,7 +38,8 @@ def run_all( def convert_to_adata(vlm, basis=None): from collections import OrderedDict - from .. import AnnData + + from anndata import AnnData X = ( vlm.S_norm.T @@ -90,10 +91,11 @@ def convert_to_adata(vlm, basis=None): def convert_to_loom(adata, basis=None): - from scipy.sparse import issparse - import numpy as np import velocyto + import numpy as np + from scipy.sparse import issparse + class VelocytoLoom(velocyto.VelocytoLoom): def __init__(self, adata, basis=None): kwargs = {"dtype": np.float64, "order": "C"} @@ -118,7 +120,7 @@ def __init__(self, adata, basis=None): self.initial_cell_size = self.S.sum(0) self.initial_Ucell_size = self.U.sum(0) - from ..preprocessing.utils import not_yet_normalized + from scvelo.preprocessing.utils import not_yet_normalized if not not_yet_normalized(adata.layers["spliced"]): self.S_sz = self.S @@ -308,9 +310,9 @@ def run_all( def test(): - from ..datasets import simulation + from scvelo.datasets import simulation + from scvelo.logging import print_version from .velocity_graph import velocity_graph - from ..logging import print_version print_version() adata = simulation(n_obs=300, n_vars=30) diff --git a/scvelo/tools/score_genes_cell_cycle.py b/scvelo/tools/score_genes_cell_cycle.py index 3a26d5ce..9f5b0933 100644 --- a/scvelo/tools/score_genes_cell_cycle.py +++ b/scvelo/tools/score_genes_cell_cycle.py @@ -1,8 +1,7 @@ -from .. import logging as logg - -import pandas as pd import numpy as np +import pandas as pd +from scvelo import logging as logg # fmt: off s_genes_list = \ @@ -32,10 +31,12 @@ def get_phase_marker_genes(adata): adata The annotated data matrix. phase + Returns ------- List of S and/or G2M phase marker genes. """ + name, gene_names = adata.var_names[0], adata.var_names up, low = name.isupper(), name.islower() s_genes_list_ = [ @@ -70,16 +71,17 @@ def score_genes_cell_cycle(adata, s_genes=None, g2m_genes=None, copy=False, **kw **kwargs Are passed to :func:`~scanpy.tl.score_genes`. `ctrl_size` is not possible, as it's set as `min(len(s_genes), len(g2m_genes))`. 
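A usage sketch for the call above (assumes `adata` carries the cell cycle marker genes; `scv` is the usual scVelo import alias):

.. code:: python

    import scvelo as scv

    scv.tl.score_genes_cell_cycle(adata)
    # S_score, G2M_score and phase are written to adata.obs
    print(adata.obs[["S_score", "G2M_score", "phase"]].head())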
+ Returns ------- - Depending on `copy`, returns or updates `adata` with the following fields. - **S_score** : `adata.obs`, dtype `object` + S_score: `adata.obs`, dtype `object` The score for S phase for each cell. - **G2M_score** : `adata.obs`, dtype `object` + G2M_score: `adata.obs`, dtype `object` The score for G2M phase for each cell. - **phase** : `adata.obs`, dtype `object` + phase: `adata.obs`, dtype `object` The cell cycle phase (`S`, `G2M` or `G1`) for each cell. """ + logg.info("calculating cell cycle phase") from scanpy.tools._score_genes import score_genes diff --git a/scvelo/tools/terminal_states.py b/scvelo/tools/terminal_states.py index 99a9fc23..5f740dcf 100644 --- a/scvelo/tools/terminal_states.py +++ b/scvelo/tools/terminal_states.py @@ -1,13 +1,13 @@ -from .. import settings -from .. import logging as logg -from ..preprocessing.moments import get_connectivities -from ..preprocessing.neighbors import verify_neighbors -from .velocity_graph import VelocityGraph -from .transition_matrix import transition_matrix -from .utils import scale, groups_to_bool, strings_to_categoricals, get_plasticity_score - -from scipy.sparse import linalg, csr_matrix, issparse import numpy as np +from scipy.sparse import csr_matrix, issparse, linalg + +from scvelo import logging as logg +from scvelo import settings +from scvelo.preprocessing.moments import get_connectivities +from scvelo.preprocessing.neighbors import verify_neighbors +from .transition_matrix import transition_matrix +from .utils import get_plasticity_score, groups_to_bool, scale, strings_to_categoricals +from .velocity_graph import VelocityGraph def cell_fate( @@ -37,12 +37,12 @@ def cell_fate( Returns ------- - Returns or updates `adata` with the attributes cell_fate: `.obs` most likely cell fate for each individual cell cell_fate_confidence: `.obs` confidence of transitioning to the assigned fate """ + adata = data.copy() if copy else data logg.info("computing cell fates", r=True) @@ -56,8 +56,7 @@ def cell_fate( _adata.uns["velocity_graph_neg"] = vgraph.graph_neg T = transition_matrix(_adata, self_transitions=self_transitions) - I = np.eye(_adata.n_obs) - fate = np.linalg.inv(I - T) + fate = np.linalg.inv(np.eye(_adata.n_obs) - T) if issparse(T): fate = fate.A cell_fates = np.array(_adata.obs[groupby][fate.argmax(1)]) @@ -105,12 +104,12 @@ def cell_origin( Returns ------- - Returns or updates `adata` with the attributes cell_origin: `.obs` most likely cell origin for each individual cell cell_origin_confidence: `.obs` confidence of coming from assigned origin """ + adata = data.copy() if copy else data logg.info("computing cell fates", r=True) @@ -124,8 +123,7 @@ def cell_origin( _adata.uns["velocity_graph_neg"] = vgraph.graph_neg T = transition_matrix(_adata, self_transitions=self_transitions, backward=True) - I = np.eye(_adata.n_obs) - fate = np.linalg.inv(I - T) + fate = np.linalg.inv(np.eye(_adata.n_obs) - T) if issparse(T): fate = fate.A cell_fates = np.array(_adata.obs[groupby][fate.argmax(1)]) @@ -167,7 +165,7 @@ def eigs(T, k=10, eps=1e-3, perc=None, random_state=None, v0=None): eigvecs = np.clip(eigvecs, 0, ubs) eigvecs /= eigvecs.max(0) - except: + except Exception: eigvals, eigvecs = np.empty(0), np.zeros(shape=(T.shape[0], 0)) return eigvals, eigvecs @@ -222,7 +220,7 @@ def terminal_states( which is given by left eigenvectors corresponding to an eigenvalue of 1, i.e. .. 
math:: - μ^{\\textrm{end}}=μ^{\\textrm{end}} \\pi, \quad + μ^{\\textrm{end}}=μ^{\\textrm{end}} \\pi, \\quad μ^{\\textrm{root}}=μ^{\\textrm{root}} \\pi^{\\small \\textrm{T}}. .. code:: python @@ -263,12 +261,12 @@ def terminal_states( Returns ------- - Returns or updates `data` with the attributes root_cells: `.obs` sparse matrix with transition probabilities. end_points: `.obs` sparse matrix with transition probabilities. - """ + """ # noqa E501 + adata = data.copy() if copy else data verify_neighbors(adata) diff --git a/scvelo/tools/transition_matrix.py b/scvelo/tools/transition_matrix.py index 2803ad7d..5279c86d 100644 --- a/scvelo/tools/transition_matrix.py +++ b/scvelo/tools/transition_matrix.py @@ -1,12 +1,12 @@ -from ..preprocessing.neighbors import get_connectivities, get_neighs -from .utils import normalize +import warnings import numpy as np import pandas as pd -from scipy.spatial.distance import pdist, squareform from scipy.sparse import csr_matrix, SparseEfficiencyWarning +from scipy.spatial.distance import pdist, squareform -import warnings +from scvelo.preprocessing.neighbors import get_connectivities, get_neighs +from .utils import normalize warnings.simplefilter("ignore", SparseEfficiencyWarning) @@ -36,7 +36,8 @@ def transition_matrix( from the velocity graph :math:`\\pi_{ij}`, with row-normalization :math:`z_i` and kernel width :math:`\\sigma` (scale parameter :math:`\\lambda = \\sigma^{-1}`). - Alternatively, use :func:`cellrank.tl.transition_matrix` to account for uncertainty in the velocity estimates. + Alternatively, use :func:`cellrank.tl.transition_matrix` to account for uncertainty + in the velocity estimates. Arguments --------- @@ -72,6 +73,7 @@ def transition_matrix( ------- Returns sparse matrix with transition probabilities. """ + if f"{vkey}_graph" not in adata.uns: raise ValueError( "You need to run `tl.velocity_graph` first to compute cosine correlations." @@ -185,11 +187,13 @@ def get_cell_transitions( Set to `int` for reproducibility, otherwise `None` for a random seed. **kwargs: To be passed to tl.transition_matrix. + Returns ------- Returns embedding coordinates (if basis is specified), otherwise return indices of simulated cell transitions. 
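A usage sketch for `get_cell_transitions` (only parameters visible in this patch are used; assumes a velocity graph has been computed):

.. code:: python

    from scvelo.tools.transition_matrix import get_cell_transitions

    # random walk through the transition matrix, starting at the first cell;
    # with `basis` given, embedding coordinates are returned instead of indices
    coords = get_cell_transitions(
        adata, starting_cell=0, basis="umap", random_state=0
    )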
""" + np.random.seed(random_state) if isinstance(starting_cell, str) and starting_cell in adata.obs_names: starting_cell = adata.obs_names.get_loc(starting_cell) @@ -207,6 +211,8 @@ def get_cell_transitions( if n_neighbors is not None and n_neighbors < len(p): idx = np.argsort(t.data)[::-1][:n_neighbors] indices, p = indices[idx], p[idx] + if len(p) == 0: + indices, p = [X[-1]], [1] p /= np.sum(p) ix = np.random.choice(indices, p=p) X.append(ix) diff --git a/scvelo/tools/utils.py b/scvelo/tools/utils.py index d06993d5..2465a102 100644 --- a/scvelo/tools/utils.py +++ b/scvelo/tools/utils.py @@ -1,8 +1,12 @@ +import warnings + +import numpy as np +import pandas as pd from scipy.sparse import csr_matrix, issparse + import matplotlib.pyplot as pl -import pandas as pd -import numpy as np -import warnings + +from scvelo.core import l2_norm, prod_sum, sum warnings.simplefilter("ignore") @@ -12,9 +16,9 @@ def round(k, dec=2, as_str=None): return [round(ki, dec) for ki in k] if "e" in f"{k}": k_str = f"{k}".split("e") - result = f"{np.round(np.float(k_str[0]), dec)}1e{k_str[1]}" - return f"{result}" if as_str else np.float(result) - result = np.round(np.float(k), dec) + result = f"{np.round(float(k_str[0]), dec)}1e{k_str[1]}" + return f"{result}" if as_str else float(result) + result = np.round(float(k), dec) return f"{result}" if as_str else result @@ -31,72 +35,109 @@ def make_dense(X): def sum_obs(A): """summation over axis 0 (obs) equivalent to np.sum(A, 0)""" - if issparse(A): - return A.sum(0).A1 - else: - return np.einsum("ij -> j", A) if A.ndim > 1 else np.sum(A) + + warnings.warn( + "`sum_obs` is deprecated since scVelo v0.2.4 and will be removed in a future " + "version. Please use `sum(A, axis=0)` from `scvelo/core/` instead.", + DeprecationWarning, + stacklevel=2, + ) + + return sum(A, axis=0) def sum_var(A): """summation over axis 1 (var) equivalent to np.sum(A, 1)""" - if issparse(A): - return A.sum(1).A1 - else: - return np.sum(A, axis=1) if A.ndim > 1 else np.sum(A) + + warnings.warn( + "`sum_var` is deprecated since scVelo v0.2.4 and will be removed in a future " + "version. Please use `sum(A, axis=1)` from `scvelo/core/` instead.", + DeprecationWarning, + stacklevel=2, + ) + + return sum(A, axis=1) def prod_sum_obs(A, B): """dot product and sum over axis 0 (obs) equivalent to np.sum(A * B, 0)""" - if issparse(A): - return A.multiply(B).sum(0).A1 - else: - return np.einsum("ij, ij -> j", A, B) if A.ndim > 1 else (A * B).sum() + + warnings.warn( + "`prod_sum_obs` is deprecated since scVelo v0.2.4 and will be removed in a " + "future version. Please use `prod_sum(A, B, axis=0)` from `scvelo/core/` " + "instead.", + DeprecationWarning, + stacklevel=2, + ) + + return prod_sum(A, B, axis=0) def prod_sum_var(A, B): """dot product and sum over axis 1 (var) equivalent to np.sum(A * B, 1)""" - if issparse(A): - return A.multiply(B).sum(1).A1 - else: - return np.einsum("ij, ij -> i", A, B) if A.ndim > 1 else (A * B).sum() + + warnings.warn( + "`prod_sum_var` is deprecated since scVelo v0.2.4 and will be removed in a " + "future version. Please use `prod_sum(A, B, axis=1)` from `scvelo/core/` " + "instead.", + DeprecationWarning, + stacklevel=2, + ) + + return prod_sum(A, B, axis=1) def norm(A): """computes the L2-norm along axis 1 (e.g. 
genes or embedding dimensions) equivalent to np.linalg.norm(A, axis=1) """ - if issparse(A): - return np.sqrt(A.multiply(A).sum(1).A1) - else: - return np.sqrt(np.einsum("ij, ij -> i", A, A) if A.ndim > 1 else np.sum(A * A)) + + warnings.warn( + "`norm` is deprecated since scVelo v0.2.4 and will be removed in a future " + "version. Please use `l2_norm(A, axis=1)` from `scvelo/core/` instead.", + DeprecationWarning, + stacklevel=2, + ) + + return l2_norm(A, axis=1) def vector_norm(x): """computes the L2-norm along axis 1 (e.g. genes or embedding dimensions) equivalent to np.linalg.norm(A, axis=1) """ - return np.sqrt(np.einsum("i, i -> ", x, x)) + + warnings.warn( + "`vector_norm` is deprecated since scVelo v0.2.4 and will be removed in a " + "future version. Please use `l2_norm(x)` from `scvelo/core/` instead.", + DeprecationWarning, + stacklevel=2, + ) + + return l2_norm(x) def R_squared(residual, total): with warnings.catch_warnings(): warnings.simplefilter("ignore") - r2 = np.ones(residual.shape[1]) - prod_sum_obs( - residual, residual - ) / prod_sum_obs(total, total) + r2 = np.ones(residual.shape[1]) - prod_sum( + residual, residual, axis=0 + ) / prod_sum(total, total, axis=0) r2[np.isnan(r2)] = 0 return r2 def cosine_correlation(dX, Vi): dx = dX - dX.mean(-1)[:, None] - Vi_norm = vector_norm(Vi) + Vi_norm = l2_norm(Vi, axis=0) with warnings.catch_warnings(): warnings.simplefilter("ignore") if Vi_norm == 0: result = np.zeros(dx.shape[0]) else: - result = np.einsum("ij, j", dx, Vi) / (norm(dx) * Vi_norm)[None, :] + result = ( + np.einsum("ij, j", dx, Vi) / (l2_norm(dx, axis=1) * Vi_norm)[None, :] + ) return result @@ -119,12 +160,12 @@ def scale(X, min=0, max=1): def get_indices(dist, n_neighbors=None, mode_neighbors="distances"): - from ..preprocessing.neighbors import compute_connectivities_umap + from scvelo.preprocessing.neighbors import compute_connectivities_umap D = dist.copy() D.data += 1e-6 - n_counts = sum_var(D > 0) + n_counts = sum(D > 0, axis=1) n_neighbors = ( n_counts.min() if n_neighbors is None else min(n_counts.min(), n_neighbors) ) @@ -221,8 +262,8 @@ def randomized_velocity(adata, vkey="velocity", add_key="velocity_random"): ) adata.layers[add_key] = V_rnd - from .velocity_graph import velocity_graph from .velocity_embedding import velocity_embedding + from .velocity_graph import velocity_graph velocity_graph(adata, vkey=add_key) velocity_embedding(adata, vkey=add_key, autoscale=False) @@ -248,8 +289,8 @@ def str_to_int(item): def strings_to_categoricals(adata): """Transform string annotations to categoricals.""" - from pandas.api.types import is_string_dtype, is_integer_dtype, is_bool_dtype from pandas import Categorical + from pandas.api.types import is_bool_dtype, is_integer_dtype, is_string_dtype def is_valid_dtype(values): return ( @@ -341,15 +382,15 @@ def cutoff_small_velocities( adata.layers[key_added] = csr_matrix(W).multiply(adata.layers[vkey]).tocsr() - from .velocity_graph import velocity_graph from .velocity_embedding import velocity_embedding + from .velocity_graph import velocity_graph velocity_graph(adata, vkey=key_added, approx=True) velocity_embedding(adata, vkey=key_added) def make_unique_list(key, allow_array=False): - from pandas import unique, Index + from pandas import Index, unique if isinstance(key, Index): key = key.tolist() @@ -375,7 +416,8 @@ def test_bimodality(x, bins=30, kde=True, plot=False): ) idx = int(bins / 2) - 2 - idx += np.argmin(kde_grid[idx : idx + 4]) + end = idx + 4 + idx += np.argmin(kde_grid[idx:end]) peak_0 = 
kde_grid[:idx].argmax() peak_1 = kde_grid[idx:].argmax() @@ -477,7 +519,7 @@ def indices_to_bool(indices, n): def convolve(adata, x): - from ..preprocessing.neighbors import get_connectivities + from scvelo.preprocessing.neighbors import get_connectivities conn = get_connectivities(adata) if isinstance(x, str) and x in adata.layers.keys(): diff --git a/scvelo/tools/velocity.py b/scvelo/tools/velocity.py index 3e26826e..eadc53fb 100644 --- a/scvelo/tools/velocity.py +++ b/scvelo/tools/velocity.py @@ -1,11 +1,13 @@ -from .. import settings -from .. import logging as logg -from ..preprocessing.moments import moments, second_order_moments, get_connectivities -from .optimization import leastsq_NxN, leastsq_generalized, maximum_likelihood -from .utils import R_squared, groups_to_bool, make_dense, strings_to_categoricals +import warnings import numpy as np -import warnings + +from scvelo import logging as logg +from scvelo import settings +from scvelo.core import LinearRegression +from scvelo.preprocessing.moments import moments, second_order_moments +from .optimization import leastsq_generalized, maximum_likelihood +from .utils import groups_to_bool, make_dense, R_squared, strings_to_categoricals warnings.simplefilter(action="ignore", category=FutureWarning) @@ -57,7 +59,11 @@ def compute_deterministic(self, fit_offset=False, perc=None): subset = self._groups_for_fit Ms = self._Ms if subset is None else self._Ms[subset] Mu = self._Mu if subset is None else self._Mu[subset] - self._offset, self._gamma = leastsq_NxN(Ms, Mu, fit_offset, perc) + + lr = LinearRegression(fit_intercept=fit_offset, percentile=perc) + lr.fit(Ms, Mu) + self._offset = lr.intercept_ + self._gamma = lr.coef_ if self._constrain_ratio is not None: if np.size(self._constrain_ratio) < 2: @@ -72,7 +78,11 @@ def compute_deterministic(self, fit_offset=False, perc=None): # velocity genes if self._r2_adjusted: - _offset, _gamma = leastsq_NxN(Ms, Mu, fit_offset) + lr = LinearRegression(fit_intercept=fit_offset) + lr.fit(Ms, Mu) + _offset = lr.intercept_ + _gamma = lr.coef_ + _residual = self._Mu - _gamma * self._Ms if fit_offset: _residual -= _offset @@ -121,7 +131,10 @@ def compute_stochastic( var_ss = 2 * _Mss - _Ms cov_us = 2 * _Mus + _Mu - _offset2, _gamma2 = leastsq_NxN(var_ss, cov_us, fit_offset2) + lr = LinearRegression(fit_intercept=fit_offset2) + lr.fit(var_ss, cov_us) + _offset2 = lr.intercept_ + _gamma2 = lr.coef_ # initialize covariance matrix res_std = _residual.std(0) @@ -286,14 +299,12 @@ def velocity( Returns ------- - Returns or updates `adata` with the attributes velocity: `.layers` velocity vectors for each individual cell - variance_velocity: `.layers` - velocity vectors for the cell variances - velocity_offset, velocity_beta, velocity_gamma, velocity_r2: `.var` + velocity_genes, velocity_beta, velocity_gamma, velocity_r2: `.var` parameters - """ + """ # noqa E501 + adata = data.copy() if copy else data if not use_raw and "Ms" not in adata.layers.keys(): moments(adata) @@ -310,7 +321,7 @@ def velocity( ) if mode in {"dynamical", "dynamical_residuals"}: - from .dynamical_model_utils import get_reads, get_vars, get_divergence + from .dynamical_model_utils import get_divergence, get_reads, get_vars gene_subset = ~np.isnan(adata.var["fit_alpha"].values) vdata = adata[:, gene_subset] @@ -492,6 +503,7 @@ def velocity_genes( velocity_genes: `.var` genes to be used for further velocity analysis (velocity graph and embedding) """ + adata = data.copy() if copy else data if f"{vkey}_genes" not in adata.var.keys(): 
velocity(adata, vkey) @@ -512,8 +524,8 @@ def velocity_genes( if np.sum(vgenes) < 2: logg.warn( - f"You seem to have very low signal in splicing dynamics.\n" - f"Consider reducing the thresholds and be cautious with interpretations.\n" + "You seem to have very low signal in splicing dynamics.\n" + "Consider reducing the thresholds and be cautious with interpretations.\n" ) adata.var[f"{vkey}_genes"] = vgenes diff --git a/scvelo/tools/velocity_confidence.py b/scvelo/tools/velocity_confidence.py index 0e60d54c..85db17a6 100644 --- a/scvelo/tools/velocity_confidence.py +++ b/scvelo/tools/velocity_confidence.py @@ -1,10 +1,11 @@ -from .. import logging as logg -from ..preprocessing.neighbors import get_neighs -from .utils import prod_sum_var, norm, get_indices, random_subsample -from .transition_matrix import transition_matrix - import numpy as np +from scvelo import logging as logg +from scvelo.core import l2_norm, prod_sum +from scvelo.preprocessing.neighbors import get_neighs +from .transition_matrix import transition_matrix +from .utils import get_indices, random_subsample + def velocity_confidence(data, vkey="velocity", copy=False): """Computes confidences of velocities. @@ -29,12 +30,12 @@ def velocity_confidence(data, vkey="velocity", copy=False): Returns ------- - Returns or updates `adata` with the attributes velocity_length: `.obs` Length of the velocity vectors for each individual cell velocity_confidence: `.obs` Confidence for each cell - """ + """ # noqa E501 + adata = data.copy() if copy else data if vkey not in adata.layers.keys(): raise ValueError("You need to run `tl.velocity` first.") @@ -50,7 +51,7 @@ def velocity_confidence(data, vkey="velocity", copy=False): V = V[:, tmp_filter] V -= V.mean(1)[:, None] - V_norm = norm(V) + V_norm = l2_norm(V, axis=1) R = np.zeros(adata.n_obs) indices = get_indices(dist=get_neighs(adata, "distances"))[0] @@ -58,7 +59,8 @@ def velocity_confidence(data, vkey="velocity", copy=False): Vi_neighs = V[indices[i]] Vi_neighs -= Vi_neighs.mean(1)[:, None] R[i] = np.mean( - np.einsum("ij, j", Vi_neighs, V[i]) / (norm(Vi_neighs) * V_norm[i])[None, :] + np.einsum("ij, j", Vi_neighs, V[i]) + / (l2_norm(Vi_neighs, axis=1) * V_norm[i])[None, :] ) adata.obs[f"{vkey}_length"] = V_norm.round(2) @@ -89,10 +91,10 @@ def velocity_confidence_transition(data, vkey="velocity", scale=10, copy=False): Returns ------- - Returns or updates `adata` with the attributes velocity_confidence_transition: `.obs` Confidence of transition for each cell """ + adata = data.copy() if copy else data if vkey not in adata.layers.keys(): raise ValueError("You need to run `tl.velocity` first.") @@ -114,10 +116,10 @@ def velocity_confidence_transition(data, vkey="velocity", scale=10, copy=False): dX -= dX.mean(1)[:, None] V -= V.mean(1)[:, None] - norms = norm(dX) * norm(V) + norms = l2_norm(dX, axis=1) * l2_norm(V, axis=1) norms += norms == 0 - adata.obs[f"{vkey}_confidence_transition"] = prod_sum_var(dX, V) / norms + adata.obs[f"{vkey}_confidence_transition"] = prod_sum(dX, V, axis=1) / norms logg.hint(f"added '{vkey}_confidence_transition' (adata.obs)") @@ -130,8 +132,8 @@ def score_robustness( adata = data.copy() if copy else data if adata_subset is None: - from ..preprocessing.moments import moments - from ..preprocessing.neighbors import neighbors + from scvelo.preprocessing.moments import moments + from scvelo.preprocessing.neighbors import neighbors from .velocity import velocity logg.switch_verbosity("off") @@ -147,8 +149,10 @@ def score_robustness( V = 
adata[subset].layers[vkey] V_subset = adata_subset.layers[vkey] - score = np.nan * (subset == False) - score[subset] = prod_sum_var(V, V_subset) / (norm(V) * norm(V_subset)) + score = np.full(len(subset), np.nan) + score[subset] = prod_sum(V, V_subset, axis=1) / ( + l2_norm(V, axis=1) * l2_norm(V_subset, axis=1) + ) adata.obs[f"{vkey}_score_robustness"] = score return adata_subset if copy else None diff --git a/scvelo/tools/velocity_embedding.py b/scvelo/tools/velocity_embedding.py index a5521fc4..d62c46ef 100644 --- a/scvelo/tools/velocity_embedding.py +++ b/scvelo/tools/velocity_embedding.py @@ -1,11 +1,12 @@ -from .. import settings -from .. import logging as logg -from .utils import norm -from .transition_matrix import transition_matrix +import warnings -from scipy.sparse import issparse import numpy as np -import warnings +from scipy.sparse import issparse + +from scvelo import logging as logg +from scvelo import settings +from scvelo.core import l2_norm +from .transition_matrix import transition_matrix def quiver_autoscale(X_emb, V_emb): @@ -45,13 +46,14 @@ def velocity_embedding( """Projects the single cell velocities into any embedding. Given normalized difference of the embedding positions - :math:`\\tilde \\delta_{ij} = \\frac{x_j-x_i}{\\left\\lVert x_j-x_i \\right\\rVert}`. + :math:`\\tilde\\delta_{ij} = \\frac{x_j-x_i}{\\lVert x_j-x_i \\rVert}`. the projections are obtained as expected displacements with respect to the transition matrix :math:`\\tilde \\pi_{ij}` as .. math:: \\tilde \\nu_i = E_{\\tilde \\pi_{i\\cdot}} [\\tilde \\delta_{i \\cdot}] - = \\sum_{j \\neq i} \left( \\tilde \\pi_{ij} - \\frac1n \\right) \\tilde \\delta_{ij}. + = \\sum_{j \\neq i} \\left( \\tilde \\pi_{ij} - \\frac1n \\right) + \\tilde \\delta_{ij}.
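A dense NumPy sketch of this expectation (toy shapes only; the implementation below iterates over the rows of the sparse transition matrix instead):

.. code:: python

    import numpy as np

    n = 4
    T = np.random.rand(n, n)
    T /= T.sum(1, keepdims=True)  # row-normalized transition matrix pi_ij
    X_emb = np.random.rand(n, 2)  # embedding coordinates x_i

    dX = X_emb[None, :, :] - X_emb[:, None, :]  # delta_ij = x_j - x_i
    dX /= np.linalg.norm(dX, axis=-1, keepdims=True) + 1e-12  # normalized differences
    V_emb = np.einsum("ij,ijd->id", T - 1 / n, dX)  # sum_j (pi_ij - 1/n) delta_ij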
Arguments @@ -87,10 +90,10 @@ def velocity_embedding( Returns ------- - Returns or updates `adata` with the attributes - velocity_basis: `.obsm` - coordinates of velocity projection on embedding + velocity_umap: `.obsm` + coordinates of velocity projection on embedding (e.g., basis='umap') """ + adata = data.copy() if copy else data if basis is None: @@ -107,9 +110,12 @@ def velocity_embedding( if direct_pca_projection and "pca" in basis: logg.warn( - "Directly projecting velocities into PCA space is for exploratory analysis on principal components.\n" - " It does not reflect the actual velocity field from high dimensional gene expression space.\n" - " To visualize velocities, consider applying `direct_pca_projection=False`.\n" + "Directly projecting velocities into PCA space is for exploratory analysis " + "on principal components.\n" + " It does not reflect the actual velocity field from high " + "dimensional gene expression space.\n" + " To visualize velocities, consider applying " + "`direct_pca_projection=False`.\n" ) logg.info("computing velocity embedding", r=True) @@ -157,7 +163,7 @@ def velocity_embedding( indices = T[i].indices dX = X_emb[indices] - X_emb[i, None] # shape (n_neighbors, 2) if not retain_scale: - dX /= norm(dX)[:, None] + dX /= l2_norm(dX)[:, None] dX[np.isnan(dX)] = 0 # zero diff in a steady-state probs = TA[i, indices] if densify else T[i].data V_emb[i] = probs.dot(dX) - probs.mean() * dX.sum(0) @@ -171,7 +177,7 @@ def velocity_embedding( delta = T.dot(X[:, vgenes]) - X[:, vgenes] if issparse(delta): delta = delta.A - cos_proj = (V * delta).sum(1) / norm(delta) + cos_proj = (V * delta).sum(1) / l2_norm(delta) V_emb *= np.clip(cos_proj[:, None] * 10, 0, 1) if autoscale: diff --git a/scvelo/tools/velocity_graph.py b/scvelo/tools/velocity_graph.py index 73ce76f6..a1a5be4b 100644 --- a/scvelo/tools/velocity_graph.py +++ b/scvelo/tools/velocity_graph.py @@ -1,13 +1,21 @@ -from .. import settings -from .. import logging as logg -from ..preprocessing.neighbors import pca, neighbors, verify_neighbors -from ..preprocessing.neighbors import get_neighs, get_n_neighs -from ..preprocessing.moments import get_moments -from .utils import cosine_correlation, get_indices, get_iterative_indices, norm -from .velocity import velocity +import os -from scipy.sparse import coo_matrix, issparse import numpy as np +from scipy.sparse import coo_matrix, issparse + +from scvelo import logging as logg +from scvelo import settings +from scvelo.core import get_n_jobs, l2_norm, parallelize +from scvelo.preprocessing.moments import get_moments +from scvelo.preprocessing.neighbors import ( + get_n_neighs, + get_neighs, + neighbors, + pca, + verify_neighbors, +) +from .utils import cosine_correlation, get_indices, get_iterative_indices +from .velocity import velocity def vals_to_csr(vals, rows, cols, shape, split_negative=False): @@ -83,8 +91,10 @@ def __init__( self.V_raw = np.array(self.V) self.sqrt_transform = sqrt_transform - if self.sqrt_transform is None and f"{vkey}_params" in adata.uns.keys(): - self.sqrt_transform = adata.uns[f"{vkey}_params"]["mode"] == "stochastic" + uns_key = f"{vkey}_params" + if self.sqrt_transform is None: + if uns_key in adata.uns.keys() and "mode" in adata.uns[uns_key]: + self.sqrt_transform = adata.uns[uns_key]["mode"] == "stochastic" if self.sqrt_transform: self.V = np.sqrt(np.abs(self.V)) * np.sign(self.V) self.V -= np.nanmean(self.V, axis=1)[:, None] @@ -125,7 +135,7 @@ def __init__( mode="connectivity" ).indices.reshape((-1, n_neighbors + 1)) else: - from .. 
import Neighbors + from scvelo import Neighbors neighs = Neighbors(adata) neighs.compute_neighbors( @@ -156,21 +166,51 @@ def __init__( self.report = report self.adata = adata - def compute_cosines(self): - vals, rows, cols, uncertainties, n_obs = [], [], [], [], self.X.shape[0] - progress = logg.ProgressReporter(n_obs) + def compute_cosines(self, n_jobs=None, backend="loky"): + n_jobs = get_n_jobs(n_jobs=n_jobs) + + n_obs = self.X.shape[0] + + # TODO: Use batches and vectorize calculation of dX in self._calculate_cosines + res = parallelize( + self._compute_cosines, + range(self.X.shape[0]), + n_jobs=n_jobs, + unit="cells", + backend=backend, + )() + uncertainties, vals, rows, cols = map(_flatten, zip(*res)) + + vals = np.hstack(vals) + vals[np.isnan(vals)] = 0 + + self.graph, self.graph_neg = vals_to_csr( + vals, rows, cols, shape=(n_obs, n_obs), split_negative=True + ) + if self.compute_uncertainties: + uncertainties = np.hstack(uncertainties) + uncertainties[np.isnan(uncertainties)] = 0 + self.uncertainties = vals_to_csr( + uncertainties, rows, cols, shape=(n_obs, n_obs), split_negative=False + ) + self.uncertainties.eliminate_zeros() + + confidence = self.graph.max(1).A.flatten() + self.self_prob = np.clip(np.percentile(confidence, 98) - confidence, 0, 1) + def _compute_cosines(self, obs_idx, queue): + vals, rows, cols, uncertainties = [], [], [], [] if self.compute_uncertainties: - m = get_moments(self.adata, np.sign(self.V_raw), second_order=True) + moments = get_moments(self.adata, np.sign(self.V_raw), second_order=True) - for i in range(n_obs): - if self.V[i].max() != 0 or self.V[i].min() != 0: + for obs_id in obs_idx: + if self.V[obs_id].max() != 0 or self.V[obs_id].min() != 0: neighs_idx = get_iterative_indices( - self.indices, i, self.n_recurse_neighbors, self.max_neighs + self.indices, obs_id, self.n_recurse_neighbors, self.max_neighs ) if self.t0 is not None: - t0, t1 = self.t0[i], self.t1[i] + t0, t1 = self.t0[obs_id], self.t1[obs_id] if t0 >= 0 and t1 > 0: t1_idx = np.where(self.t0 == t1)[0] if len(t1_idx) > len(neighs_idx): @@ -180,39 +220,32 @@ def compute_cosines(self): if len(t1_idx) > 0: neighs_idx = np.unique(np.concatenate([neighs_idx, t1_idx])) - dX = self.X[neighs_idx] - self.X[i, None] # 60% of runtime + dX = self.X[neighs_idx] - self.X[obs_id, None] # 60% of runtime if self.sqrt_transform: dX = np.sqrt(np.abs(dX)) * np.sign(dX) - val = cosine_correlation(dX, self.V[i]) # 40% of runtime + val = cosine_correlation(dX, self.V[obs_id]) # 40% of runtime if self.compute_uncertainties: - dX /= norm(dX)[:, None] - uncertainties.extend(np.nansum(dX ** 2 * m[i][None, :], 1)) + dX /= l2_norm(dX)[:, None] + uncertainties.extend( + np.nansum(dX ** 2 * moments[obs_id][None, :], 1) + ) vals.extend(val) - rows.extend(np.ones(len(neighs_idx)) * i) + rows.extend(np.ones(len(neighs_idx)) * obs_id) cols.extend(neighs_idx) - if self.report: - progress.update() - if self.report: - progress.finish() - vals = np.hstack(vals) - vals[np.isnan(vals)] = 0 + if queue is not None: + queue.put(1) - self.graph, self.graph_neg = vals_to_csr( - vals, rows, cols, shape=(n_obs, n_obs), split_negative=True - ) - if self.compute_uncertainties: - uncertainties = np.hstack(uncertainties) - uncertainties[np.isnan(uncertainties)] = 0 - self.uncertainties = vals_to_csr( - uncertainties, rows, cols, shape=(n_obs, n_obs), split_negative=False - ) - self.uncertainties.eliminate_zeros() + if queue is not None: + queue.put(None) - confidence = self.graph.max(1).A.flatten() - self.self_prob = 
np.clip(np.percentile(confidence, 98) - confidence, 0, 1) + return uncertainties, vals, rows, cols + + +def _flatten(iterable): + return [i for it in iterable for i in it] def velocity_graph( @@ -231,6 +264,8 @@ def velocity_graph( approx=None, mode_neighbors="distances", copy=False, + n_jobs=None, + backend="loky", ): """Computes velocity graph based on cosine similarities. @@ -241,7 +276,8 @@ def velocity_graph( .. math:: \\pi_{ij} = \\cos\\angle(\\delta_{ij}, \\nu_i) - = \\frac{\\delta_{ij}^T \\nu_i}{\\left\\lVert\\delta_{ij}\\right\\rVert \\left\\lVert \\nu_i \\right\\rVert}. + = \\frac{\\delta_{ij}^T \\nu_i}{\\left\\lVert\\delta_{ij}\\right\\rVert + \\left\\lVert \\nu_i \\right\\rVert}. Arguments --------- @@ -277,13 +313,18 @@ def velocity_graph( 'connectivities'. The latter yields a symmetric graph. copy: `bool` (default: `False`) Return a copy instead of writing to adata. + n_jobs: `int` or `None` (default: `None`) + Number of parallel jobs. + backend: `str` (default: "loky") + Backend used for multiprocessing. See :class:`joblib.Parallel` for valid + options. Returns ------- - Returns or updates `adata` with the attributes velocity_graph: `.uns` - sparse matrix with transition probabilities + sparse matrix with correlations of cell state transitions with velocities """ + adata = data.copy() if copy else data verify_neighbors(adata) if vkey not in adata.layers.keys(): @@ -315,8 +356,11 @@ def velocity_graph( f" on full expression space by not specifying basis.\n" ) - logg.info("computing velocity graph", r=True) - vgraph.compute_cosines() + n_jobs = get_n_jobs(n_jobs=n_jobs) + logg.info( + f"computing velocity graph (using {n_jobs}/{os.cpu_count()} cores)", r=True + ) + vgraph.compute_cosines(n_jobs=n_jobs, backend=backend) adata.uns[f"{vkey}_graph"] = vgraph.graph adata.uns[f"{vkey}_graph_neg"] = vgraph.graph_neg diff --git a/scvelo/tools/velocity_pseudotime.py b/scvelo/tools/velocity_pseudotime.py index 9b10d8f8..f92e9b05 100644 --- a/scvelo/tools/velocity_pseudotime.py +++ b/scvelo/tools/velocity_pseudotime.py @@ -1,12 +1,13 @@ -from .. import logging as logg -from .terminal_states import terminal_states -from .utils import groups_to_bool, scale, strings_to_categoricals -from ..preprocessing.moments import get_connectivities - import numpy as np -from scipy.sparse import issparse, spdiags, linalg +from scipy.sparse import issparse, linalg, spdiags + from scanpy.tools._dpt import DPT +from scvelo import logging as logg +from scvelo.preprocessing.moments import get_connectivities +from .terminal_states import terminal_states +from .utils import groups_to_bool, scale, strings_to_categoricals + def principal_curve(data, basis="pca", n_comps=4, clusters_list=None, copy=False): """Computes the principal curve @@ -20,12 +21,13 @@ def principal_curve(data, basis="pca", n_comps=4, clusters_list=None, copy=False Number of pricipal components to be used. copy: `bool`, (default: `False`) Return a copy instead of writing to adata. + Returns ------- - Returns or updates `adata` with the attributes principal_curve: `.uns` dictionary containing `projections`, `ixsort` and `arclength` """ + adata = data.copy() if copy else data import rpy2.robjects as robjects from rpy2.robjects.packages import importr @@ -182,12 +184,12 @@ def velocity_pseudotime( **kwargs: Further arguments to pass to VPT (e.g. min_group_size, allow_kendall_tau_shift). - Returns + Returns ------- - Updates `adata` with the attributes velocity_pseudotime: `.obs` Velocity pseudotime obtained from velocity graph. 
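A usage sketch combining the new parallelization options with the pseudotime computation (assumes a preprocessed `adata` with moments computed):

.. code:: python

    import scvelo as scv

    scv.tl.velocity(adata)
    scv.tl.velocity_graph(adata, n_jobs=2)  # new n_jobs/backend arguments
    scv.tl.velocity_pseudotime(adata)
    scv.pl.scatter(adata, color="velocity_pseudotime", color_map="gnuplot")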
- """ + """ # noqa E501 + strings_to_categoricals(adata) if root_key is None and "root_cells" in adata.obs.keys(): root0 = adata.obs["root_cells"][0] diff --git a/scvelo/utils.py b/scvelo/utils.py index c9d5eb79..c265ff31 100644 --- a/scvelo/utils.py +++ b/scvelo/utils.py @@ -1,24 +1,62 @@ -from .preprocessing.utils import show_proportions, cleanup -from .preprocessing.utils import set_initial_size, get_initial_size +from scvelo.core import ( + clean_obs_names, + cleanup, + get_initial_size, + merge, + set_initial_size, + show_proportions, +) +from scvelo.plotting.simulation import compute_dynamics +from scvelo.plotting.utils import ( + clip, + interpret_colorkey, + is_categorical, + rgb_custom_colormap, +) +from scvelo.plotting.velocity_embedding_grid import compute_velocity_on_grid +from scvelo.preprocessing.moments import get_moments +from scvelo.preprocessing.neighbors import get_connectivities +from scvelo.read_load import ( + convert_to_ensembl, + convert_to_gene_names, + gene_info, + load_biomart, +) +from scvelo.tools.optimization import get_weight, leastsq +from scvelo.tools.rank_velocity_genes import get_mean_var +from scvelo.tools.run import convert_to_adata, convert_to_loom +from scvelo.tools.score_genes_cell_cycle import get_phase_marker_genes +from scvelo.tools.transition_matrix import get_cell_transitions +from scvelo.tools.transition_matrix import transition_matrix as get_transition_matrix +from scvelo.tools.utils import * # noqa +from scvelo.tools.velocity_graph import vals_to_csr -from .preprocessing.neighbors import get_connectivities -from .preprocessing.moments import get_moments - -from .tools.utils import * -from .tools.rank_velocity_genes import get_mean_var -from .tools.run import convert_to_adata, convert_to_loom -from .tools.optimization import leastsq, get_weight -from .tools.velocity_graph import vals_to_csr -from .tools.score_genes_cell_cycle import get_phase_marker_genes - -from .tools.transition_matrix import transition_matrix as get_transition_matrix -from .tools.transition_matrix import get_cell_transitions - -from .plotting.utils import is_categorical, clip -from .plotting.utils import interpret_colorkey, rgb_custom_colormap - -from .plotting.velocity_embedding_grid import compute_velocity_on_grid -from .plotting.simulation import compute_dynamics - -from .read_load import clean_obs_names, merge, gene_info -from .read_load import convert_to_gene_names, convert_to_ensembl, load_biomart +__all__ = [ + "cleanup", + "clean_obs_names", + "clip", + "compute_dynamics", + "compute_velocity_on_grid", + "convert_to_adata", + "convert_to_ensembl", + "convert_to_gene_names", + "convert_to_loom", + "gene_info", + "get_cell_transitions", + "get_connectivities", + "get_initial_size", + "get_mean_var", + "get_moments", + "get_phase_marker_genes", + "get_transition_matrix", + "get_weight", + "interpret_colorkey", + "is_categorical", + "leastsq", + "load_biomart", + "merge", + "rgb_custom_colormap", + "set_initial_size", + "show_proportions", + "vals_to_csr", +] diff --git a/setup.py b/setup.py index e41d5a02..d75d6280 100644 --- a/setup.py +++ b/setup.py @@ -1,19 +1,25 @@ -from setuptools import setup, find_packages from pathlib import Path +from setuptools import find_packages, setup + + +def read_requirements(req_path): + """Read abstract requirements.""" + requirements = Path(req_path).read_text("utf-8").splitlines() + return [r.strip() for r in requirements if not r.startswith("-")] + + setup( name="scvelo", use_scm_version=True, setup_requires=["setuptools_scm"], 
python_requires=">=3.6", - install_requires=[ - l.strip() for l in Path("requirements.txt").read_text("utf-8").splitlines() - ], + install_requires=read_requirements("requirements.txt"), extras_require=dict( louvain=["python-igraph", "louvain"], hnswlib=["pybind11", "hnswlib"], - dev=["black==20.8b1", "pre-commit==2.5.1"], - docs=[r for r in Path("docs/requirements.txt").read_text("utf-8").splitlines()], + dev=read_requirements("requirements-dev.txt"), + docs=read_requirements("docs/requirements.txt"), ), packages=find_packages(), author="Volker Bergen", @@ -40,6 +46,7 @@ "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", "Topic :: Scientific/Engineering :: Bio-Informatics", "Topic :: Scientific/Engineering :: Visualization", ], diff --git a/tests/test_basic.py b/tests/test_basic.py index dd950d92..bf837f97 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -1,9 +1,10 @@ -import scvelo as scv import numpy as np +import scvelo as scv + def test_einsum(): - from scvelo.tools.utils import prod_sum_obs, prod_sum_var, norm + from scvelo.tools.utils import norm, prod_sum_obs, prod_sum_var Ms, Mu = np.random.rand(5, 4), np.random.rand(5, 4) assert np.allclose(prod_sum_obs(Ms, Mu), np.sum(Ms * Mu, 0)) @@ -58,7 +59,7 @@ def test_pipeline(): Ms, Mu = adata.layers["Ms"][0], adata.layers["Mu"][0] Vs, Vd = adata.layers["velocity"][0], adata.layers["dynamical_velocity"][0] - Vpca, Vgraph = adata.obsm["velocity_pca"][0], adata.uns["velocity_graph"].data[:5] + Vgraph = adata.uns["velocity_graph"].data[:5] pars = adata[:, 0].var[["fit_alpha", "fit_gamma"]].values assert np.allclose(Ms, [0.8269, 1.0772, 0.9396, 1.0846, 1.0398], rtol=1e-2)
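As a compact reference for the deprecations introduced in this patch, a migration sketch with toy arrays (the `scvelo.core` calls mirror the replacements named in the deprecation warnings):

.. code:: python

    import numpy as np
    from scvelo.core import LinearRegression, l2_norm, prod_sum, sum

    A, B = np.random.rand(5, 4), np.random.rand(5, 4)

    sum(A, axis=0)          # replaces sum_obs(A)
    prod_sum(A, B, axis=1)  # replaces prod_sum_var(A, B)
    l2_norm(A, axis=1)      # replaces norm(A)

    # replaces leastsq_NxN(A, B, fit_offset=True)
    lr = LinearRegression(fit_intercept=True)
    lr.fit(A, B)
    offset, gamma = lr.intercept_, lr.coef_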