diff --git a/.flake8 b/.flake8
new file mode 100644
index 00000000..cb0de790
--- /dev/null
+++ b/.flake8
@@ -0,0 +1,11 @@
+[flake8]
+max-line-length = 88
+filename = *.py
+exclude =
+ .git,
+ __pycache__,
+ docs/*
+# F403: 'from module import *' used; unable to detect undefined names
+# F405: may be undefined, or defined from star imports
+# W503: line break before binary operator
+ignore = F403, F405, W503
diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md
index 586911b1..f7018e62 100644
--- a/.github/ISSUE_TEMPLATE/bug_report.md
+++ b/.github/ISSUE_TEMPLATE/bug_report.md
@@ -6,22 +6,31 @@ labels: bug
assignees: ''
---
-
+
...
-
+
+
```python
-...
+# paste your code here, if applicable
```
-
+
+
- Error
+
+ Error output
+
```pytb
-...
+# paste the error output here, if applicable
```
-#### Versions:
-
->
+
+
+ Versions
+
+```pytb
+# paste the output of scv.logging.print_versions() here
+```
+
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 72ea68ca..47a36de4 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -2,9 +2,9 @@ name: CI
on:
push:
- branches: [master, develop]
+ branches: [master, develop, 'release/**']
pull_request:
- branches: [master, develop]
+ branches: [master, develop, 'release/**']
jobs:
# Skip CI if commit message contains `[ci skip]` in the subject
@@ -19,7 +19,7 @@ jobs:
- id: ci-skip-step
uses: mstachniuk/ci-skip@master
- # Check if code agrees with `black` and if README.rst can be converted to HTML
+ # Check if pre-commit hooks pass and if README.rst can be converted to HTML
linting:
needs: init
if: ${{ needs.init.outputs.skip == 'false' }}
@@ -33,9 +33,9 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
- pip install black>=20.8b1 docutils
- - name: Check code style
- run: black --check --diff --color .
+ pip install pre-commit docutils
+ - name: Check pre-commit compatibility
+ run: pre-commit run --all-files --show-diff-on-failure
- name: Run rst2html.py
run: rst2html.py --halt=2 README.rst >/dev/null
@@ -55,6 +55,6 @@ jobs:
- name: Install dependencies
run: |
pip install -e .
- pip install pytest pytest-cov
+ pip install hypothesis pytest pytest-cov
- name: Unit tests
run: python -m pytest --cov=scvelo
diff --git a/.gitignore b/.gitignore
index 166a74fa..56a467b5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,6 +7,7 @@ scripts/
write/
scanpy*/
benchmarking/
+htmlcov/
scvelo.egg-info/
@@ -37,4 +38,7 @@ docs/source/scvelo*
/dist/
.coverage
-.eggs
\ No newline at end of file
+.eggs
+
+# Files generated by unit tests
+.hypothesis/
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 14aa0ee4..3c5e1dd8 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -3,3 +3,18 @@ repos:
rev: 20.8b1
hooks:
- id: black
+- repo: https://gitlab.com/pycqa/flake8
+ rev: 3.8.4
+ hooks:
+ - id: flake8
+- repo: https://github.com/pycqa/isort
+ rev: 5.7.0
+ hooks:
+ - id: isort
+ name: isort (python)
+ - id: isort
+ name: isort (cython)
+ types: [cython]
+ - id: isort
+ name: isort (pyi)
+ types: [pyi]
diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst
new file mode 100644
index 00000000..06b0d014
--- /dev/null
+++ b/CONTRIBUTING.rst
@@ -0,0 +1,78 @@
+Contributing guide
+==================
+
+
+Getting started
+^^^^^^^^^^^^^^^
+
+Contributing to scVelo requires a developer installation. As a first step, we suggest creating a new environment
+
+.. code:: bash
+
+ conda create -n ENV_NAME python=PYTHON_VERSION && conda activate ENV_NAME
+
+
+Next, fork the scVelo repo on GitHub `here `.
+If you are unsure how to do so, please check out the corresponding
+`GitHub docs `.
+You can now clone your fork of scVelo and install it in development mode
+
+.. code:: bash
+
+ git clone https://github.com/YOUR-USER-NAME/scvelo.git
+ cd scvelo
+ git checkout --track origin/develop
+ pip install -e '.[dev]'
+
+The last line can, alternatively, be replaced by
+
+.. code:: bash
+
+ pip install -r requirements-dev.txt
+
+
+Finally, to make sure your code follows our code style guideline, install pre-commit:
+
+.. code:: bash
+
+ pre-commit install
+
+
+Coding style
+^^^^^^^^^^^^
+
+Our code follows the `black` and `flake8` coding styles. Code formatting (`black`, `isort`) is automated through pre-commit hooks. In addition, we require that
+
+- functions are fully type-annotated.
+- variables referred to in an error/warning message or docstring are enclosed in \`\`.
+
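+For example, a hypothetical function complying with these conventions could look as follows:
+
+.. code:: python
+
+    def clip_value(value: float, upper_bound: float) -> float:
+        """Clip `value` to the interval [0, `upper_bound`]."""
+        if upper_bound < 0:
+            raise ValueError("`upper_bound` must be non-negative.")
+        return min(max(value, 0.0), upper_bound)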
+
+Testing
+^^^^^^^
+
+To run the implemented unit tests locally, simply run
+
+.. code:: bash
+
+ python -m pytest
+
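+To also collect test coverage, as done in our CI setup, run (this requires `pytest-cov`)
+
+.. code:: bash
+
+    python -m pytest --cov=scvelo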
+
+Documentation
+^^^^^^^^^^^^^
+
+The docstrings of scVelo largely follow the `numpy`-style. New docstrings should
+
+- include neither type hints nor return types.
+- reference an argument within the same docstring using \`\`.
+
+
+Submitting pull requests
+^^^^^^^^^^^^^^^^^^^^^^^^
+
+New features and bug fixes are added to the code base through a pull request (PR). To implement a feature or bug fix, create a branch from `develop`. For hotfixes use `master` as base. The existence of bugs suggests insufficient test coverage. As such, bug fixes should, ideally, include a unit test or extend an existing one. Please ensure that
+
+- branch names have the prefix `feat/`, `bug/` or `hotfix/`.
+- your code follows the project conventions.
+- newly added functions are unit tested.
+- all tests pass locally.
+- if there is no issue solved by the PR, create one outlining what you are trying to add/solve and reference it in the PR description.
diff --git a/README.rst b/README.rst
index 46be45ab..e317e25b 100644
--- a/README.rst
+++ b/README.rst
@@ -34,6 +34,7 @@ patients and dynamic processes in human lung regeneration. Find out more in this
Latest news
^^^^^^^^^^^
+- Aug/2021: `Perspectives paper out in MSB `_
- Feb/2021: scVelo goes multi-core
- Dec/2020: Cover of `Nature Biotechnology `_
- Nov/2020: Talk at `Single Cell Biology `_
@@ -42,11 +43,15 @@ Latest news
- Sep/2020: Talk at `Single Cell Omics `_
- Aug/2020: `scVelo out in Nature Biotech `_
-Reference
-^^^^^^^^^
+References
+^^^^^^^^^^
+La Manno *et al.* (2018), RNA velocity of single cells, `Nature `_.
+
Bergen *et al.* (2020), Generalizing RNA velocity to transient cell states through dynamical modeling,
`Nature Biotech `_.
-|dim|
+
+Bergen *et al.* (2021), RNA velocity - current challenges and future perspectives,
+`Molecular Systems Biology `_.
Support
^^^^^^^
diff --git a/docs/requirements.txt b/docs/requirements.txt
index 3f3b35b5..99078be6 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -9,5 +9,5 @@ sphinx_autodoc_typehints<=1.6
# converting notebooks to html
ipykernel
-sphinx>=1.7
-nbsphinx>=0.7
\ No newline at end of file
+sphinx>=1.7,<4.0
+nbsphinx>=0.7,<0.8.7
\ No newline at end of file
diff --git a/docs/source/_ext/edit_on_github.py b/docs/source/_ext/edit_on_github.py
index 720b69f1..6bd7852a 100644
--- a/docs/source/_ext/edit_on_github.py
+++ b/docs/source/_ext/edit_on_github.py
@@ -5,7 +5,6 @@
import os
import warnings
-
__licence__ = "BSD (3 clause)"
diff --git a/docs/source/api.rst b/docs/source/api.rst
index e3656a37..e4215c57 100644
--- a/docs/source/api.rst
+++ b/docs/source/api.rst
@@ -153,6 +153,12 @@ Datasets
datasets.pancreas
datasets.dentategyrus
datasets.forebrain
+ datasets.dentategyrus_lamanno
+ datasets.gastrulation
+ datasets.gastrulation_e75
+ datasets.gastrulation_erythroid
+ datasets.bonemarrow
+ datasets.pbmc68k
datasets.simulation
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 2ee05f04..64e7750b 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -1,15 +1,31 @@
-import sys
-import os
import inspect
import logging
-from pathlib import Path
+import os
+import sys
from datetime import datetime
-from typing import Optional, Union, Mapping
+from pathlib import Path, PurePosixPath
+from typing import Dict, List, Mapping, Optional, Tuple, Union
+from urllib.request import urlretrieve
+import sphinx_autodoc_typehints
+from docutils import nodes
+from jinja2.defaults import DEFAULT_FILTERS
+from sphinx import addnodes
from sphinx.application import Sphinx
+from sphinx.domains.python import PyObject, PyTypedField
+from sphinx.environment import BuildEnvironment
from sphinx.ext import autosummary
+import matplotlib # noqa
+
+HERE = Path(__file__).parent
+sys.path.insert(0, str(HERE.parent.parent))
+sys.path.insert(0, os.path.abspath("_ext"))
+
+import scvelo # isort:skip
+
# remove PyCharm’s old six module
+
if "six" in sys.modules:
print(*sys.path, sep="\n")
for pypath in list(sys.path):
@@ -17,34 +33,45 @@
sys.path.remove(pypath)
del sys.modules["six"]
-import matplotlib # noqa
-
matplotlib.use("agg")
-HERE = Path(__file__).parent
-sys.path.insert(0, f"{HERE.parent.parent}")
-sys.path.insert(0, os.path.abspath("_ext"))
-import scvelo
-
logger = logging.getLogger(__name__)
-# -- Retrieve notebooks ------------------------------------------------
-
-from urllib.request import urlretrieve
+# -- Basic notebooks and those stored under /vignettes and /perspectives --
notebooks_url = "https://github.com/theislab/scvelo_notebooks/raw/master/"
-notebooks = [
+notebooks = []
+notebook = [
"VelocityBasics.ipynb",
"DynamicalModeling.ipynb",
"DifferentialKinetics.ipynb",
+]
+notebooks.extend(notebook)
+
+notebook = [
"Pancreas.ipynb",
"DentateGyrus.ipynb",
+ "NatureBiotechCover.ipynb",
+ "Fig1_concept.ipynb",
+ "Fig2_dentategyrus.ipynb",
+ "Fig3_pancreas.ipynb",
+ "FigS9_runtime.ipynb",
+ "FigSuppl.ipynb",
]
+notebooks.extend([f"vignettes/{nb}" for nb in notebook])
+
+notebook = ["Perspectives.ipynb", "Perspectives_parameters.ipynb"]
+notebooks.extend([f"perspectives/{nb}" for nb in notebook])
+
+# -- Retrieve all notebooks --
+
for nb in notebooks:
+ url = notebooks_url + nb
try:
- urlretrieve(notebooks_url + nb, nb)
- except:
+ urlretrieve(url, nb)
+ except Exception as e:
+ logger.error(f"Unable to retrieve notebook: `{url}`. Reason: `{e}`")
pass
@@ -218,7 +245,6 @@ def get_linenos(obj):
github_url_read_loom = "https://github.com/theislab/anndata/tree/master/anndata"
github_url_read = "https://github.com/theislab/scanpy/tree/master"
github_url_scanpy = "https://github.com/theislab/scanpy/tree/master/scanpy"
-from pathlib import PurePosixPath
def modurl(qualname):
@@ -253,14 +279,11 @@ def api_image(qualname: str) -> Optional[str]:
# modify the default filters
-from jinja2.defaults import DEFAULT_FILTERS
DEFAULT_FILTERS.update(modurl=modurl, api_image=api_image)
# -- Override some classnames in autodoc --------------------------------------------
-import sphinx_autodoc_typehints
-
qualname_overrides = {
"anndata.base.AnnData": "anndata.AnnData",
"scvelo.pl.scatter": "scvelo.plotting.scatter",
@@ -292,12 +315,6 @@ def format_annotation(annotation):
# -- Prettier Param docs --------------------------------------------
-from typing import Dict, List, Tuple
-from docutils import nodes
-from sphinx import addnodes
-from sphinx.domains.python import PyTypedField, PyObject
-from sphinx.environment import BuildEnvironment
-
class PrettyTypedField(PyTypedField):
list_type = nodes.definition_list
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 8dc7eddf..372bf016 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -32,6 +32,7 @@ patients and dynamic processes in human lung regeneration. Find out more in this
Latest news
^^^^^^^^^^^
+- Aug/2021: `Perspectives paper out in MSB `_
- Feb/2021: scVelo goes multi-core
- Dec/2020: Cover of `Nature Biotechnology `_
- Nov/2020: Talk at `Single Cell Biology `_
@@ -40,11 +41,15 @@ Latest news
- Sep/2020: Talk at `Single Cell Omics `_
- Aug/2020: `scVelo out in Nature Biotech `_
-Reference
-^^^^^^^^^
+References
+^^^^^^^^^^
+La Manno *et al.* (2018), RNA velocity of single cells, `Nature `_.
+
Bergen *et al.* (2020), Generalizing RNA velocity to transient cell states through dynamical modeling,
`Nature Biotech `_.
-|dim|
+
+Bergen *et al.* (2021), RNA velocity - current challenges and future perspectives,
+`Molecular Systems Biology `_.
Support
^^^^^^^
@@ -77,14 +82,15 @@ For further information visit `scvelo.org `_.
VelocityBasics
DynamicalModeling
DifferentialKinetics
+ vignettes/index
.. toctree::
- :caption: Example Datasets
+ :caption: Perspectives
:maxdepth: 1
:hidden:
- Pancreas
- DentateGyrus
+ perspectives/index
+
.. |PyPI| image:: https://img.shields.io/pypi/v/scvelo.svg
:target: https://pypi.org/project/scvelo
diff --git a/docs/source/installation.rst b/docs/source/installation.rst
index e1901192..d2245a8b 100644
--- a/docs/source/installation.rst
+++ b/docs/source/installation.rst
@@ -19,12 +19,13 @@ Development Version
To work with the latest development version, install from GitHub_ using::
- pip install git+https://github.com/theislab/scvelo
+ pip install git+https://github.com/theislab/scvelo@develop
or::
- git clone https://github.com/theislab/scvelo
- pip install -e scvelo
+ git clone https://github.com/theislab/scvelo && cd scvelo
+ git checkout --track origin/develop
+ pip install -e .
``-e`` is short for ``--editable`` and links the package to the original cloned
location such that pulled changes are also reflected in the environment.
diff --git a/docs/source/perspectives/index.rst b/docs/source/perspectives/index.rst
new file mode 100644
index 00000000..4460ec37
--- /dev/null
+++ b/docs/source/perspectives/index.rst
@@ -0,0 +1,34 @@
+Challenges and Perspectives
+---------------------------
+
+This page complements our manuscript
+`Bergen et al. (MSB, 2021) RNA velocity - Current challenges and future perspectives `_.
+
+We provide several examples to discuss potential pitfalls of current RNA velocity
+modeling approaches, and give guidance on how the ensuing challenges may be addressed.
+Our aspiration is to suggest promising future directions and to stimulate a community effort on further model extensions.
+
+In the following, you will find two vignettes with several use cases, as well as an in-depth analysis of time-dependent kinetic rate parameters.
+
+Potential pitfalls
+^^^^^^^^^^^^^^^^^^
+.. image:: https://user-images.githubusercontent.com/31883718/115840357-e354f480-a41b-11eb-8a95-12f7564fd9b0.png
+ :width: 300px
+ :align: left
+
+This notebook reproduces Fig. 2 with several use cases, including multiple kinetics in Dentate Gyrus,
+transcriptional boost in erythroid lineage, and misleading arrow projections in mature PBMCs.
+
+Notebook: `Perspectives `_
+
+|
+Kinetic parameter analysis
+^^^^^^^^^^^^^^^^^^^^^^^^^^
+.. image:: https://user-images.githubusercontent.com/31883718/130656606-00bd44be-9071-4008-be1b-244fa9c2d244.png
+ :width: 300px
+ :align: left
+
+This notebook reproduces Fig. 3, where we demonstrate how time-variable kinetic rates
+shape the curvature patterns of gene activation.
+
+Notebook: `Kinetic parameter analysis `_
diff --git a/docs/source/release_notes.rst b/docs/source/release_notes.rst
index a62bda49..5e812e86 100644
--- a/docs/source/release_notes.rst
+++ b/docs/source/release_notes.rst
@@ -4,6 +4,30 @@
Release Notes
=============
+Version 0.2.4 :small:`Aug 26, 2021`
+-----------------------------------
+
+Perspectives:
+
+- Landing page and two notebooks accompanying the perspectives manuscript at MSB.
+- New datasets: Gastrulation, bone marrow, and PBMCs.
+
+New capabilities:
+
+- Added vignettes accompanying the NBT manuscript.
+- Kinetic simulations with time-dependent rates.
+- New arguments for `tl.velocity_embedding_stream` (`PR 492 `_).
+- Introduced automated code linting and formatting with `flake8` and `isort` (`PR 360 `_, `PR 374 `_).
+- `tl.velocity_graph` parallelized (`PR 392 `_).
+- `legend_align_text` parameter in `pl.scatter` for smart placing of labels without overlapping.
+- Save option for `pl.proportions`.
+
+Bugfixes:
+
+- Pinned `sphinx<4.0` and `nbsphinx<0.8.7`.
+- Fixed IPython import at CLI.
+
+
Version 0.2.3 :small:`Feb 13, 2021`
-----------------------------------
@@ -140,4 +164,4 @@ Version 0.1.5 :small:`Sep 4, 2018`
Version 0.1.2 :small:`Aug 21, 2018`
-----------------------------------
-First alpha release of scvelo.
\ No newline at end of file
+First alpha release of scvelo.
diff --git a/docs/source/vignettes/index.rst b/docs/source/vignettes/index.rst
new file mode 100644
index 00000000..87f6735e
--- /dev/null
+++ b/docs/source/vignettes/index.rst
@@ -0,0 +1,19 @@
+Other Vignettes
+---------------
+
+Example Datasets
+^^^^^^^^^^^^^^^^
+- `Dentate Gyrus `_
+- `Pancreas `_
+
+Nature Biotech Figures
+^^^^^^^^^^^^^^^^^^^^^^
+
+- NBT Cover | `cover `__ | `notebook `__
+- Fig.1 Concept | `figure `__ | `notebook `__
+- Fig.2 Dentate Gyrus | `figure `__ | `notebook `__
+- Fig.3 Pancreas | `figure `__ | `notebook `__
+- Suppl. Figures | `figure `__ | `notebook `__ | `runtime `__
+
+All notebooks are deposited at `GitHub `_.
+Found a bug? Feel free to submit an `issue `_.
\ No newline at end of file
diff --git a/pypi.rst b/pypi.rst
index b2833154..9170fed4 100644
--- a/pypi.rst
+++ b/pypi.rst
@@ -32,6 +32,7 @@ patients and dynamic processes in human lung regeneration. Find out more in this
Latest news
^^^^^^^^^^^
+- Aug/2021: `Perspectives paper out in MSB `_
- Feb/2021: scVelo goes multi-core
- Dec/2020: Cover of `Nature Biotechnology `_
- Nov/2020: Talk at `Single Cell Biology `_
@@ -40,11 +41,16 @@ Latest news
- Sep/2020: Talk at `Single Cell Omics `_
- Aug/2020: `scVelo out in Nature Biotech `_
-Reference
-^^^^^^^^^
+References
+^^^^^^^^^^
+La Manno *et al.* (2018), RNA velocity of single cells, `Nature `_.
+
Bergen *et al.* (2020), Generalizing RNA velocity to transient cell states through dynamical modeling,
`Nature Biotech `_.
+Bergen *et al.* (2021), RNA velocity - current challenges and future perspectives,
+`Molecular Systems Biology `_.
+
Support
^^^^^^^
Found a bug or would like to see a feature implemented? Feel free to submit an
diff --git a/pyproject.toml b/pyproject.toml
index 381b4762..cbf07ea1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,3 +17,17 @@ exclude = '''
| dist
)/
'''
+
+[tool.isort]
+profile = "black"
+use_parentheses = true
+known_num = "networkx,numpy,pandas,scipy,sklearn,statsmodels"
+known_plot = "matplotlib,mpl_toolkits,seaborn"
+known_bio = "anndata,scanpy"
+sections = "FUTURE,STDLIB,THIRDPARTY,NUM,PLOT,BIO,FIRSTPARTY,LOCALFOLDER"
+no_lines_before = "LOCALFOLDER"
+balanced_wrapping = true
+length_sort = "0"
+indent = " "
+float_to_top = true
+order_by_type = false
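With this configuration, isort arranges imports into the custom sections. A sketch of the resulting grouping (mirroring the import order used in `docs/source/conf.py` and `scvelo/core/_anndata.py`):

```python
import os  # STDLIB

from joblib import Parallel  # THIRDPARTY

import numpy as np  # NUM: numerical libraries
import pandas as pd

import matplotlib.pyplot as plt  # PLOT: plotting libraries

from anndata import AnnData  # BIO: single-cell libraries

from scvelo.core import get_df  # FIRSTPARTY
```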
diff --git a/pytest.ini b/pytest.ini
index a12918d6..67a1c148 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -1,4 +1,6 @@
[pytest]
-python_files = *.py
-testpaths = tests
+python_files = test_*.py
+testpaths =
+ tests
+ scvelo
xfail_strict = true
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 1bd89b6c..af20ddc2 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -3,4 +3,12 @@
-e .
black==20.8b1
-pre-commit==2.5.1
+hnswlib
+hypothesis
+flake8==3.8.4
+isort==5.7.0
+louvain
+pre-commit>=2.9.0
+pybind11
+pytest>=6.2.2
+python-igraph
diff --git a/requirements.txt b/requirements.txt
index 5831c94e..c1ad4bc3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,5 @@
+typing_extensions
+
# main requirements for single-cell analysis
anndata>=0.7.5 # compatible with h5py v3.0 (v0.7.5)
scanpy>=1.5 # adapt to new anndata attributes (v1.5)
@@ -6,11 +8,12 @@ scanpy>=1.5 # adapt to new anndata attributes (v1.5)
loompy>=2.0.12 # introduced sparsity support (v2.0.12)
# for computing neighbor graph connectivities
-umap-learn>=0.3.10, <0.5 # removed numba warnings (v0.3.10)
+umap-learn>=0.3.10 # removed numba warnings (v0.3.10)
# standard requirements for data analysis
+numba>=0.41.0
numpy>=1.17 # extension/speedup in .nan_to_num, .exp (v1.17)
scipy>=1.4.1 # introduced PCA sparsity support (v1.4)
pandas>=0.23 # merging/sorting extensions (v0.23)
scikit-learn>=0.21.2 # bugfix in .utils.sparsefuncs (v0.21.2)
-matplotlib>=3.1.2 # several bugfixes (v3.1.2)
\ No newline at end of file
+matplotlib>=3.3.0 # normalize in pie (v3.3.0)
\ No newline at end of file
diff --git a/scvelo/__init__.py b/scvelo/__init__.py
index 6df0f0ac..76cc70d1 100644
--- a/scvelo/__init__.py
+++ b/scvelo/__init__.py
@@ -1,4 +1,17 @@
"""scvelo - RNA velocity generalized through dynamical modeling"""
+from anndata import AnnData
+from scanpy import read, read_loom
+
+from scvelo import datasets, logging, pl, pp, settings, tl, utils
+from scvelo.core import get_df
+from scvelo.plotting.gridspec import GridSpec
+from scvelo.preprocessing.neighbors import Neighbors
+from scvelo.read_load import DataFrame, load, read_csv
+from scvelo.settings import set_figure_params
+from scvelo.tools.run import run_all, test
+from scvelo.tools.utils import round
+from scvelo.tools.velocity import Velocity
+from scvelo.tools.velocity_graph import VelocityGraph
try:
from setuptools_scm import get_version
@@ -8,24 +21,33 @@
except (LookupError, ImportError):
try:
from importlib_metadata import version # Python < 3.8
- except:
+ except Exception:
from importlib.metadata import version # Python = 3.8
__version__ = version(__name__)
del version
-from .read_load import AnnData, read, read_loom, load, read_csv, get_df, DataFrame
-from .preprocessing.neighbors import Neighbors
-from .tools.run import run_all, test
-from .tools.utils import round
-from .tools.velocity import Velocity
-from .tools.velocity_graph import VelocityGraph
-from .plotting.gridspec import GridSpec
-from .settings import set_figure_params
-from . import pp
-from . import tl
-from . import pl
-from . import utils
-from . import datasets
-from . import logging
-from . import settings
+__all__ = [
+ "AnnData",
+ "DataFrame",
+ "datasets",
+ "get_df",
+ "GridSpec",
+ "load",
+ "logging",
+ "Neighbors",
+ "pl",
+ "pp",
+ "read",
+ "read_csv",
+ "read_loom",
+ "round",
+ "run_all",
+ "set_figure_params",
+ "settings",
+ "test",
+ "tl",
+ "utils",
+ "Velocity",
+ "VelocityGraph",
+]
diff --git a/scvelo/core/__init__.py b/scvelo/core/__init__.py
new file mode 100644
index 00000000..156eeb6e
--- /dev/null
+++ b/scvelo/core/__init__.py
@@ -0,0 +1,43 @@
+from ._anndata import (
+ clean_obs_names,
+ cleanup,
+ get_df,
+ get_initial_size,
+ get_modality,
+ get_size,
+ make_dense,
+ make_sparse,
+ merge,
+ set_initial_size,
+ set_modality,
+ show_proportions,
+)
+from ._arithmetic import clipped_log, invert, prod_sum, sum
+from ._linear_models import LinearRegression
+from ._metrics import l2_norm
+from ._models import SplicingDynamics
+from ._parallelize import get_n_jobs, parallelize
+
+__all__ = [
+ "clean_obs_names",
+ "cleanup",
+ "clipped_log",
+ "get_df",
+ "get_initial_size",
+ "get_modality",
+ "get_n_jobs",
+ "get_size",
+ "invert",
+ "l2_norm",
+ "LinearRegression",
+ "make_dense",
+ "make_sparse",
+ "merge",
+ "parallelize",
+ "prod_sum",
+ "set_initial_size",
+ "set_modality",
+ "show_proportions",
+ "SplicingDynamics",
+ "sum",
+]
diff --git a/scvelo/core/_anndata.py b/scvelo/core/_anndata.py
new file mode 100644
index 00000000..d77bbb7d
--- /dev/null
+++ b/scvelo/core/_anndata.py
@@ -0,0 +1,790 @@
+import re
+from typing import List, Optional, Union
+
+from typing_extensions import Literal
+
+import numpy as np
+import pandas as pd
+from numpy import ndarray
+from pandas import DataFrame
+from pandas.api.types import is_categorical_dtype
+from scipy.sparse import csr_matrix, issparse, spmatrix
+
+from anndata import AnnData
+
+from scvelo import logging as logg
+from ._arithmetic import sum
+
+
+def clean_obs_names(
+ data: AnnData,
+ base: str = "[AGTCBDHKMNRSVWY]",
+ ID_length: int = 12,
+ copy: bool = False,
+) -> Optional[AnnData]:
+ """Clean up the obs_names.
+
+    For example, an obs_name 'sample1_AGTCdate' is changed to 'AGTC' of the sample
+    'sample1_date'. The sample name is then saved in obs['sample_batch'].
+    The genetic codes are identified according to
+    https://www.neb.com/tools-and-resources/usage-guidelines/the-genetic-code.
+
+ Arguments
+ ---------
+ data
+ Annotated data matrix.
+ base
+ Genetic code letters to be identified.
+ ID_length
+        Length of the genetic codes in the samples.
+ copy
+ Return a copy instead of writing to adata.
+
+ Returns
+ -------
+ Optional[AnnData]
+ Returns or updates `adata` with the attributes
+ obs_names: list
+ updated names of the observations
+ sample_batch: `.obs`
+ names of the identified sample batches
+ """
+
+ def get_base_list(name, base):
+ base_list = base
+ while re.search(base_list + base, name) is not None:
+ base_list += base
+ if len(base_list) == 0:
+            raise ValueError(f"Encountered an invalid ID in obs_names: {name}")
+ return base_list
+
+ adata = data.copy() if copy else data
+
+ names = adata.obs_names
+ base_list = get_base_list(names[0], base)
+
+ if len(np.unique([len(name) for name in adata.obs_names])) == 1:
+ start, end = re.search(base_list, names[0]).span()
+ newIDs = [name[start:end] for name in names]
+ start, end = 0, len(newIDs[0])
+ for i in range(end - ID_length):
+ if np.any([ID[i] not in base for ID in newIDs]):
+ start += 1
+ if np.any([ID[::-1][i] not in base for ID in newIDs]):
+ end -= 1
+
+ newIDs = [ID[start:end] for ID in newIDs]
+ prefixes = [names[i].replace(newIDs[i], "") for i in range(len(names))]
+ else:
+ prefixes, newIDs = [], []
+ for name in names:
+ match = re.search(base_list, name)
+ newID = (
+ re.search(get_base_list(name, base), name).group()
+ if match is None
+ else match.group()
+ )
+ newIDs.append(newID)
+ prefixes.append(name.replace(newID, ""))
+
+ adata.obs_names = newIDs
+ if len(prefixes[0]) > 0 and len(np.unique(prefixes)) > 1:
+ adata.obs["sample_batch"] = (
+ pd.Categorical(prefixes)
+ if len(np.unique(prefixes)) < adata.n_obs
+ else prefixes
+ )
+
+ adata.obs_names_make_unique()
+ return adata if copy else None
+
+
+def cleanup(
+ data: AnnData,
+ clean: Union[
+ Literal["layers", "obs", "var", "uns"],
+ List[Literal["layers", "obs", "var", "uns"]],
+ ] = "layers",
+ keep: Optional[Union[str, List[str]]] = None,
+ copy: bool = False,
+) -> Optional[AnnData]:
+ """Delete not needed attributes.
+
+ Arguments
+ ---------
+ data
+ Annotated data matrix.
+ clean
+ Which attributes to consider for freeing memory.
+ keep
+ Which attributes to keep.
+ copy
+ Return a copy instead of writing to adata.
+
+ Returns
+ -------
+ Optional[AnnData]
+ Returns or updates `adata` with selection of attributes kept.
+ """
+
+ adata = data.copy() if copy else data
+ verify_dtypes(adata)
+
+ keep = list([keep] if isinstance(keep, str) else {} if keep is None else keep)
+ keep.extend(["spliced", "unspliced", "Ms", "Mu", "clusters", "neighbors"])
+
+ ann_dict = {
+ "obs": adata.obs_keys(),
+ "var": adata.var_keys(),
+ "uns": adata.uns_keys(),
+ "layers": list(adata.layers.keys()),
+ }
+
+ if "all" not in clean:
+ ann_dict = {ann: values for (ann, values) in ann_dict.items() if ann in clean}
+
+ for (ann, values) in ann_dict.items():
+ for value in values:
+ if value not in keep:
+ del getattr(adata, ann)[value]
+
+ return adata if copy else None
+
+
+def get_df(
+ data: AnnData,
+ keys: Optional[Union[str, List[str]]] = None,
+ layer: Optional[str] = None,
+    index: Optional[List] = None,
+    columns: Optional[List] = None,
+    sort_values: Optional[Union[bool, str]] = None,
+    dropna: Literal["all", "any"] = "all",
+    precision: Optional[int] = None,
+) -> DataFrame:
+ """Get dataframe for a specified adata key.
+
+ Return values for specified key
+ (in obs, var, obsm, varm, obsp, varp, uns, or layers) as a dataframe.
+
+ Arguments
+ ---------
+ data
+ AnnData object or a numpy array to get values from.
+ keys
+ Keys from `.var_names`, `.obs_names`, `.var`, `.obs`,
+ `.obsm`, `.varm`, `.obsp`, `.varp`, `.uns`, or `.layers`.
+ layer
+ Layer of `adata` to use as expression values.
+ index
+ List to set as index.
+ columns
+ List to set as columns names.
+ sort_values
+        Whether to sort values by the first column (`sort_values=True`) or by a
+        specified column.
+ dropna
+ Drop columns/rows that contain NaNs in all ('all') or in any entry ('any').
+ precision
+ Set precision for pandas dataframe.
+
+ Returns
+ -------
+ :class:`pd.DataFrame`
+ A dataframe.
+ """
+
+ if precision is not None:
+ pd.set_option("precision", precision)
+
+ if isinstance(data, AnnData):
+ keys, keys_split = (
+ keys.split("*") if isinstance(keys, str) and "*" in keys else (keys, None)
+ )
+ keys, key_add = (
+ keys.split("/") if isinstance(keys, str) and "/" in keys else (keys, None)
+ )
+ keys = [keys] if isinstance(keys, str) else keys
+ key = keys[0]
+
+ s_keys = ["obs", "var", "obsm", "varm", "uns", "layers"]
+ d_keys = [
+ data.obs.keys(),
+ data.var.keys(),
+ data.obsm.keys(),
+ data.varm.keys(),
+ data.uns.keys(),
+ data.layers.keys(),
+ ]
+
+ if hasattr(data, "obsp") and hasattr(data, "varp"):
+ s_keys.extend(["obsp", "varp"])
+ d_keys.extend([data.obsp.keys(), data.varp.keys()])
+
+ if keys is None:
+ df = data.to_df()
+ elif key in data.var_names:
+ df = obs_df(data, keys, layer=layer)
+ elif key in data.obs_names:
+ df = var_df(data, keys, layer=layer)
+ else:
+ if keys_split is not None:
+ keys = [
+ k
+ for k in list(data.obs.keys()) + list(data.var.keys())
+ if key in k and keys_split in k
+ ]
+ key = keys[0]
+ s_key = [s for (s, d_key) in zip(s_keys, d_keys) if key in d_key]
+ if len(s_key) == 0:
+ raise ValueError(f"'{key}' not found in any of {', '.join(s_keys)}.")
+ if len(s_key) > 1:
+ logg.warn(f"'{key}' found multiple times in {', '.join(s_key)}.")
+
+ s_key = s_key[-1]
+ df = getattr(data, s_key)[keys if len(keys) > 1 else key]
+ if key_add is not None:
+ df = df[key_add]
+ if index is None:
+ index = (
+ data.var_names
+ if s_key == "varm"
+ else data.obs_names
+ if s_key in {"obsm", "layers"}
+ else None
+ )
+ if index is None and s_key == "uns" and hasattr(df, "shape"):
+ key_cats = np.array(
+ [
+ key
+ for key in data.obs.keys()
+ if is_categorical_dtype(data.obs[key])
+ ]
+ )
+ num_cats = [
+ len(data.obs[key].cat.categories) == df.shape[0]
+ for key in key_cats
+ ]
+ if np.sum(num_cats) == 1:
+ index = data.obs[key_cats[num_cats][0]].cat.categories
+ if (
+ columns is None
+ and len(df.shape) > 1
+ and df.shape[0] == df.shape[1]
+ ):
+ columns = index
+ elif isinstance(index, str) and index in data.obs.keys():
+ index = pd.Categorical(data.obs[index]).categories
+ if columns is None and s_key == "layers":
+ columns = data.var_names
+ elif isinstance(columns, str) and columns in data.obs.keys():
+ columns = pd.Categorical(data.obs[columns]).categories
+ elif isinstance(data, pd.DataFrame):
+ if isinstance(keys, str) and "*" in keys:
+ keys, keys_split = keys.split("*")
+ keys = [k for k in data.columns if keys in k and keys_split in k]
+ df = data[keys] if keys is not None else data
+ else:
+ df = data
+
+ if issparse(df):
+ df = np.array(df.A)
+ if columns is None and hasattr(df, "names"):
+ columns = df.names
+
+ df = pd.DataFrame(df, index=index, columns=columns)
+
+ if dropna:
+ df.replace("", np.nan, inplace=True)
+ how = dropna if isinstance(dropna, str) else "any" if dropna is True else "all"
+ df.dropna(how=how, axis=0, inplace=True)
+ df.dropna(how=how, axis=1, inplace=True)
+
+ if sort_values:
+ sort_by = (
+ sort_values
+ if isinstance(sort_values, str) and sort_values in df.columns
+ else df.columns[0]
+ )
+ df = df.sort_values(by=sort_by, ascending=False)
+
+ if hasattr(data, "var_names"):
+ if df.index[0] in data.var_names:
+ df.var_names = df.index
+ elif df.columns[0] in data.var_names:
+ df.var_names = df.columns
+ if hasattr(data, "obs_names"):
+ if df.index[0] in data.obs_names:
+ df.obs_names = df.index
+ elif df.columns[0] in data.obs_names:
+ df.obs_names = df.columns
+
+ return df
+
+
+# TODO: Generalize to arbitrary modality
+def get_initial_size(
+ adata: AnnData, layer: Optional[str] = None, by_total_size: bool = False
+) -> Optional[ndarray]:
+ """Get initial counts per observation of a layer.
+
+ Arguments
+ ---------
+ adata
+ Annotated data matrix.
+ layer
+ Name of layer for which to retrieve initial size.
+ by_total_size
+ Whether or not to return the combined initial size of the spliced and unspliced
+ layers.
+
+ Returns
+ -------
+ np.ndarray
+ Initial counts per observation in the specified layer.
+ """
+
+ if by_total_size:
+ sizes = [
+ adata.obs[f"initial_size_{layer}"]
+ for layer in {"spliced", "unspliced"}
+ if f"initial_size_{layer}" in adata.obs.keys()
+ ]
+ return np.sum(sizes, axis=0)
+ elif layer in adata.layers.keys():
+ return (
+ np.array(adata.obs[f"initial_size_{layer}"])
+ if f"initial_size_{layer}" in adata.obs.keys()
+ else get_size(adata, layer)
+ )
+ elif layer is None or layer == "X":
+ return (
+ np.array(adata.obs["initial_size"])
+ if "initial_size" in adata.obs.keys()
+ else get_size(adata)
+ )
+ else:
+ return None
+
+
+def get_modality(adata: AnnData, modality: str) -> Union[ndarray, spmatrix]:
+ """Extract data of one modality.
+
+ Arguments
+ ---------
+ adata
+ Annotated data to extract modality from.
+ modality
+ Modality for which data is needed.
+
+ Returns
+ -------
+ Union[ndarray, spmatrix]
+ Retrieved modality from :class:`~anndata.AnnData` object.
+ """
+
+ if modality == "X":
+ return adata.X
+ elif modality in adata.layers.keys():
+ return adata.layers[modality]
+ elif modality in adata.obsm.keys():
+ if isinstance(adata.obsm[modality], DataFrame):
+ return adata.obsm[modality].values
+ else:
+ return adata.obsm[modality]
+
+
+# TODO: Generalize to arbitrary modality
+def get_size(adata: AnnData, layer: Optional[str] = None) -> ndarray:
+ """Get counts per observation in a layer.
+
+ Arguments
+ ---------
+ adata
+ Annotated data matrix.
+ layer
+        Name of layer for which to retrieve the size.
+
+ Returns
+ -------
+ np.ndarray
+        Counts per observation in the specified layer.
+ """
+
+ X = adata.X if layer is None else adata.layers[layer]
+ return sum(X, axis=1)
+
+
+def make_dense(
+ adata: AnnData, modalities: Union[List[str], str], inplace: bool = True
+) -> Optional[AnnData]:
+ """Densify sparse AnnData entry.
+
+ Arguments
+ ---------
+ adata
+ Annotated data object.
+    modalities
+        Modalities to make dense.
+ inplace
+ Boolean flag to perform operations inplace or not. Defaults to `True`.
+
+ Returns
+ -------
+ Optional[AnnData]
+        Copy of annotated data `adata` with dense modalities if `inplace=False`.
+ """
+
+ if not inplace:
+ adata = adata.copy()
+
+ if isinstance(modalities, str):
+ modalities = [modalities]
+
+ # Densify modalities
+ for modality in modalities:
+ count_data = get_modality(adata=adata, modality=modality)
+ if issparse(count_data):
+ set_modality(adata=adata, modality=modality, new_value=count_data.A)
+
+ return adata if not inplace else None
+
+
+# TODO: Allow choosing format of sparse matrix i.e., csr, csc, ...
+def make_sparse(
+ adata: AnnData, modalities: Union[List[str], str], inplace: bool = True
+) -> Optional[AnnData]:
+ """Make AnnData entry sparse.
+
+ Arguments
+ ---------
+ adata
+ Annotated data object.
+    modalities
+        Modalities to make sparse.
+ inplace
+ Boolean flag to perform operations inplace or not. Defaults to `True`.
+
+ Returns
+ -------
+ Optional[AnnData]
+        Copy of annotated data `adata` with sparse modalities if `inplace=False`.
+ """
+
+ if not inplace:
+ adata = adata.copy()
+
+ if isinstance(modalities, str):
+ modalities = [modalities]
+
+ # Make modalities sparse
+ for modality in modalities:
+ count_data = get_modality(adata=adata, modality=modality)
+ if modality == "X":
+ logg.warn("Making `X` sparse is not supported.")
+ elif not issparse(count_data):
+ set_modality(
+ adata=adata, modality=modality, new_value=csr_matrix(count_data)
+ )
+
+ return adata if not inplace else None
+
+
+def merge(adata: AnnData, ldata: AnnData, copy: bool = True) -> Optional[AnnData]:
+ """Merge two annotated data matrices.
+
+ Arguments
+ ---------
+ adata
+ Annotated data matrix (reference data set).
+ ldata
+ Annotated data matrix (to be merged into adata).
+ copy
+ Boolean flag to manipulate original AnnData or a copy of it.
+
+ Returns
+ -------
+ Optional[:class:`anndata.AnnData`]
+ Returns a :class:`~anndata.AnnData` object
+ """
+
+ adata.var_names_make_unique()
+ ldata.var_names_make_unique()
+
+ if (
+ "spliced" in ldata.layers.keys()
+ and "initial_size_spliced" not in ldata.obs.keys()
+ ):
+ set_initial_size(ldata)
+ elif (
+ "spliced" in adata.layers.keys()
+ and "initial_size_spliced" not in adata.obs.keys()
+ ):
+ set_initial_size(adata)
+
+ common_obs = pd.unique(adata.obs_names.intersection(ldata.obs_names))
+ common_vars = pd.unique(adata.var_names.intersection(ldata.var_names))
+
+ if len(common_obs) == 0:
+ clean_obs_names(adata)
+ clean_obs_names(ldata)
+ common_obs = adata.obs_names.intersection(ldata.obs_names)
+
+ if copy:
+ _adata = adata[common_obs].copy()
+ _ldata = ldata[common_obs].copy()
+ else:
+ adata._inplace_subset_obs(common_obs)
+ _adata, _ldata = adata, ldata[common_obs].copy()
+
+ _adata.var_names_make_unique()
+ _ldata.var_names_make_unique()
+
+ same_vars = len(_adata.var_names) == len(_ldata.var_names) and np.all(
+ _adata.var_names == _ldata.var_names
+ )
+ join_vars = len(common_vars) > 0
+
+ if join_vars and not same_vars:
+ _adata._inplace_subset_var(common_vars)
+ _ldata._inplace_subset_var(common_vars)
+
+ for attr in _ldata.obs.keys():
+ if attr not in _adata.obs.keys():
+ _adata.obs[attr] = _ldata.obs[attr]
+ for attr in _ldata.obsm.keys():
+ if attr not in _adata.obsm.keys():
+ _adata.obsm[attr] = _ldata.obsm[attr]
+ for attr in _ldata.uns.keys():
+ if attr not in _adata.uns.keys():
+ _adata.uns[attr] = _ldata.uns[attr]
+ if join_vars:
+ for attr in _ldata.layers.keys():
+ if attr not in _adata.layers.keys():
+ _adata.layers[attr] = _ldata.layers[attr]
+
+ if _adata.shape[1] == _ldata.shape[1]:
+ same_vars = len(_adata.var_names) == len(_ldata.var_names) and np.all(
+ _adata.var_names == _ldata.var_names
+ )
+ if same_vars:
+ for attr in _ldata.var.keys():
+ if attr not in _adata.var.keys():
+ _adata.var[attr] = _ldata.var[attr]
+ for attr in _ldata.varm.keys():
+ if attr not in _adata.varm.keys():
+ _adata.varm[attr] = _ldata.varm[attr]
+ else:
+ raise ValueError("Variable names are not identical.")
+
+ return _adata if copy else None
+
+
+def obs_df(adata: AnnData, keys: List[str], layer: Optional[str] = None) -> DataFrame:
+ """Extract layer as Pandas DataFrame indexed by observation.
+
+ Arguments
+ ---------
+ adata
+ Annotated data matrix (reference data set).
+ keys
+ Variables for which to extract data.
+ layer
+ Name of layer to turn into a Pandas DataFrame.
+
+ Returns
+ -------
+ DataFrame
+ DataFrame indexed by observations. Columns correspond to variables of specified
+ layer.
+ """
+
+ lookup_keys = [k for k in keys if k in adata.var_names]
+ if len(lookup_keys) < len(keys):
+ logg.warn(
+ f"Keys {[k for k in keys if k not in adata.var_names]} "
+ f"were not found in `adata.var_names`."
+ )
+
+ df = pd.DataFrame(index=adata.obs_names)
+ for lookup_key in lookup_keys:
+ df[lookup_key] = adata.obs_vector(lookup_key, layer=layer)
+ return df
+
+
+# TODO: Generalize to arbitrary modality
+def set_initial_size(adata: AnnData, layers: Optional[str] = None) -> None:
+ """Set current counts per observation of a layer as its initial size.
+
+ The initial size is only set if it does not already exist.
+
+ Arguments
+ ---------
+ adata
+ Annotated data matrix.
+ layers
+ Name of layers for which to calculate initial size.
+
+ Returns
+ -------
+ None
+ """
+
+ if layers is None:
+ layers = ["spliced", "unspliced"]
+ verify_dtypes(adata)
+ layers = [
+ layer
+ for layer in layers
+ if layer in adata.layers.keys()
+ and f"initial_size_{layer}" not in adata.obs.keys()
+ ]
+ for layer in layers:
+ adata.obs[f"initial_size_{layer}"] = get_size(adata, layer)
+ if "initial_size" not in adata.obs.keys():
+ adata.obs["initial_size"] = get_size(adata)
+
+
+def set_modality(
+ adata: AnnData,
+ new_value: Union[ndarray, spmatrix, DataFrame],
+ modality: Optional[str] = None,
+ inplace: bool = True,
+) -> Optional[AnnData]:
+ """Set modality of annotated data object to new value.
+
+ Arguments
+ ---------
+ adata
+ Annotated data object.
+ new_value
+ New value of modality.
+ modality
+ Modality to overwrite with new value. Defaults to `None`.
+ inplace
+ Boolean flag to indicate whether setting of modality should be inplace or
+ not. Defaults to `True`.
+
+ Returns
+ -------
+ Optional[AnnData]
+        Copy of annotated data `adata` with updated modality if `inplace=False`.
+ """
+
+ if not inplace:
+ adata = adata.copy()
+
+ if (modality == "X") or (modality is None):
+ adata.X = new_value
+ elif modality in adata.layers.keys():
+ adata.layers[modality] = new_value
+ elif modality in adata.obsm.keys():
+ adata.obsm[modality] = new_value
+
+ if not inplace:
+ return adata
+
+
+def show_proportions(
+ adata: AnnData, layers: Optional[str] = None, use_raw: bool = True
+) -> None:
+ """Proportions of abundances of modalities in layers.
+
+ The proportions are printed.
+
+ Arguments
+ ---------
+ adata
+ Annotated data matrix.
+ layers
+ Layers to consider.
+ use_raw
+ Use initial sizes, i.e., raw data, to determine proportions.
+
+ Returns
+ -------
+ None
+ """
+
+ if layers is None:
+        layers = ["spliced", "unspliced", "ambiguous"]
+ layers_keys = [key for key in layers if key in adata.layers.keys()]
+ counts_layers = [sum(adata.layers[key], axis=1) for key in layers_keys]
+ if use_raw:
+ size_key, obs = "initial_size_", adata.obs
+ counts_layers = [
+ obs[size_key + layer] if size_key + layer in obs.keys() else c
+ for layer, c in zip(layers_keys, counts_layers)
+ ]
+
+ counts_per_cell_sum = np.sum(counts_layers, 0)
+ counts_per_cell_sum += counts_per_cell_sum == 0
+
+ mean_abundances = [
+ np.mean(counts_per_cell / counts_per_cell_sum)
+ for counts_per_cell in counts_layers
+ ]
+
+ print(f"Abundance of {layers_keys}: {np.round(mean_abundances, 2)}")
+
+
+def var_df(adata: AnnData, keys: List[str], layer: Optional[str] = None):
+ """Extract layer as Pandas DataFrame indexed by features.
+
+ Arguments
+ ---------
+ adata
+ Annotated data matrix (reference data set).
+ keys
+ Observations for which to extract data.
+ layer
+ Name of layer to turn into a Pandas DataFrame.
+
+ Returns
+ -------
+ DataFrame
+ DataFrame indexed by features. Columns correspond to observations of specified
+ layer.
+ """
+
+ lookup_keys = [k for k in keys if k in adata.obs_names]
+ if len(lookup_keys) < len(keys):
+ logg.warn(
+ f"Keys {[k for k in keys if k not in adata.obs_names]} "
+ f"were not found in `adata.obs_names`."
+ )
+
+ df = pd.DataFrame(index=adata.var_names)
+ for lookup_key in lookup_keys:
+ df[lookup_key] = adata.var_vector(lookup_key, layer=layer)
+ return df
+
+
+# TODO: Find better function name
+def verify_dtypes(adata: AnnData) -> None:
+ """Verify that AnnData object is not corrupted.
+
+ Arguments
+ ---------
+ adata
+ Annotated data matrix to check.
+
+ Returns
+ -------
+ None
+ """
+
+ try:
+ _ = adata[:, 0]
+ except Exception:
+ uns = adata.uns
+ adata.uns = {}
+ try:
+ _ = adata[:, 0]
+ logg.warn(
+ "Safely deleted unstructured annotations (adata.uns), \n"
+ "as these do not comply with permissible anndata datatypes."
+ )
+ except Exception:
+ logg.warn(
+ "The data might be corrupted. Please verify all annotation datatypes."
+ )
+ adata.uns = uns
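A minimal usage sketch of the helpers above, on a toy `AnnData` object (illustrative values):

```python
import numpy as np
from anndata import AnnData
from scipy.sparse import csr_matrix

from scvelo.core import get_df, make_dense, set_initial_size

# Toy object: 3 cells x 2 genes with sparse spliced/unspliced layers.
adata = AnnData(
    X=np.ones((3, 2), dtype=np.float32),
    layers={
        "spliced": csr_matrix(np.ones((3, 2))),
        "unspliced": csr_matrix(np.ones((3, 2))),
    },
)

set_initial_size(adata)        # stores counts per cell in adata.obs["initial_size_*"]
make_dense(adata, "spliced")   # densify the layer in place
df = get_df(adata, "spliced")  # layer as a DataFrame (cells x genes)
```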
diff --git a/scvelo/core/_arithmetic.py b/scvelo/core/_arithmetic.py
new file mode 100644
index 00000000..037de5f2
--- /dev/null
+++ b/scvelo/core/_arithmetic.py
@@ -0,0 +1,103 @@
+import warnings
+from typing import Optional, Union
+
+import numpy as np
+from numpy import ndarray
+from scipy.sparse import issparse, spmatrix
+
+
+def clipped_log(x: ndarray, lb: float = 0, ub: float = 1, eps: float = 1e-6) -> ndarray:
+ """Logarithmize between [lb + epsilon, ub - epsilon].
+
+ Arguments
+ ---------
+ x
+        Array to logarithmize.
+ lb
+ Lower bound of interval to which array entries are clipped.
+ ub
+ Upper bound of interval to which array entries are clipped.
+ eps
+ Offset of boundaries of clipping interval.
+
+ Returns
+ -------
+ ndarray
+ Logarithm of clipped array.
+ """
+
+ return np.log(np.clip(x, lb + eps, ub - eps))
+
+
+def invert(x: ndarray) -> ndarray:
+ """Invert array and set infinity to NaN.
+
+ Arguments
+ ---------
+ x
+ Array to invert.
+
+ Returns
+ -------
+ ndarray
+ Inverted array.
+ """
+
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ x_inv = 1 / x * (x != 0)
+ return x_inv
+
+
+def prod_sum(
+ a1: Union[ndarray, spmatrix], a2: Union[ndarray, spmatrix], axis: Optional[int]
+) -> ndarray:
+ """Take sum of product of two arrays along given axis.
+
+ Arguments
+ ---------
+ a1
+ First array.
+ a2
+ Second array.
+ axis
+        Axis along which to sum elements. If `None`, all elements will be summed.
+
+ Returns
+ -------
+ ndarray
+ Sum of product of arrays along given axis.
+ """
+
+ if issparse(a1):
+ return a1.multiply(a2).sum(axis=axis).A1
+ elif axis == 0:
+ return np.einsum("ij, ij -> j", a1, a2) if a1.ndim > 1 else (a1 * a2).sum()
+ elif axis == 1:
+ return np.einsum("ij, ij -> i", a1, a2) if a1.ndim > 1 else (a1 * a2).sum()
+
+
+def sum(a: Union[ndarray, spmatrix], axis: Optional[int] = None) -> ndarray:
+ """Sum array elements over a given axis.
+
+ Arguments
+ ---------
+ a
+ Elements to sum.
+ axis
+ Axis along which to sum elements. If `None`, all elements will be summed.
+ Defaults to `None`.
+
+ Returns
+ -------
+ ndarray
+ Sum of array along given axis.
+ """
+
+ if a.ndim == 1:
+ axis = 0
+
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ return a.sum(axis=axis).A1 if issparse(a) else a.sum(axis=axis)
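A short sketch of the arithmetic helpers (illustrative values):

```python
import numpy as np
from scipy.sparse import csr_matrix

# Note: `sum` shadows the builtin and handles sparse and dense input alike.
from scvelo.core import prod_sum, sum

a = np.array([[1.0, 2.0], [3.0, 4.0]])

print(sum(a, axis=0))                      # [4. 6.]
print(prod_sum(csr_matrix(a), a, axis=0))  # [10. 20.]: column-wise sum of a * a
```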
diff --git a/scvelo/core/_base.py b/scvelo/core/_base.py
new file mode 100644
index 00000000..4990eaaa
--- /dev/null
+++ b/scvelo/core/_base.py
@@ -0,0 +1,54 @@
+from abc import ABC, abstractmethod
+from typing import Dict, Tuple, Union
+
+from numpy import ndarray
+
+
+class DynamicsBase(ABC):
+ @abstractmethod
+ def get_solution(
+        self, t: ndarray, stacked: bool = True, with_keys: bool = False
+ ) -> Union[Dict, Tuple[ndarray], ndarray]:
+ """Calculate solution of dynamics.
+
+ Arguments
+ ---------
+ t
+ Time steps at which to evaluate solution.
+ stacked
+ Whether to stack states or return them individually. Defaults to `True`.
+ with_keys
+ Whether to return solution labelled by variables in form of a dictionary.
+ Defaults to `False`.
+
+ Returns
+ -------
+ Union[Dict, Tuple[ndarray], ndarray]
+ Solution of system. If `with_keys=True`, the solution is returned in form of
+ a dictionary with variables as keys. Otherwise, the solution is given as
+ a `numpy.ndarray` of form `(n_steps, n_vars)`.
+ """
+
+ return
+
+ @abstractmethod
+ def get_steady_states(
+        self, stacked: bool = True, with_keys: bool = False
+ ) -> Union[Dict[str, ndarray], Tuple[ndarray], ndarray]:
+ """Return steady state of system.
+
+ Arguments
+ ---------
+ stacked
+ Whether to stack states or return them individually. Defaults to `True`.
+ with_keys
+ Whether to return solution labelled by variables in form of a dictionary.
+ Defaults to `False`.
+
+ Returns
+ -------
+ Union[Dict[str, ndarray], Tuple[ndarray], ndarray]
+ Steady state of system.
+ """
+
+ return
diff --git a/scvelo/core/_linear_models.py b/scvelo/core/_linear_models.py
new file mode 100644
index 00000000..89796761
--- /dev/null
+++ b/scvelo/core/_linear_models.py
@@ -0,0 +1,150 @@
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+from numpy import ndarray
+from scipy.sparse import csr_matrix, issparse
+
+from ._arithmetic import prod_sum, sum
+
+
+class LinearRegression:
+ """Extreme quantile and constraint least square linear regression.
+
+ Arguments
+ ---------
+ percentile
+ Percentile of data on which linear regression line is fit. If `None`, all data
+ is used, if a single value is given, it is interpreted as the upper quantile.
+ Defaults to `None`.
+ fit_intercept
+ Whether to calculate the intercept for model. Defaults to `False`.
+ positive_intercept
+ Whether the intercept it constraint to positive values. Only plays a role when
+ `fit_intercept=True`. Defaults to `True`.
+ constrain_ratio
+ Ratio to which coefficients are clipped. If `None`, the coefficients are not
+ constraint. Defaults to `None`.
+
+ Attributes
+ ----------
+ coef_
+ Estimated coefficients of the linear regression line.
+
+ intercept_
+ Fitted intercept of linear model. Set to `0.0` if `fit_intercept=False`.
+
+ """
+
+ def __init__(
+ self,
+ percentile: Optional[Union[Tuple, int, float]] = None,
+ fit_intercept: bool = False,
+ positive_intercept: bool = True,
+ constrain_ratio: Optional[Union[Tuple, float]] = None,
+ ):
+ if not fit_intercept and isinstance(percentile, (list, tuple)):
+ self.percentile = percentile[1]
+ else:
+ self.percentile = percentile
+ self.fit_intercept = fit_intercept
+ self.positive_intercept = positive_intercept
+
+ if constrain_ratio is None:
+ self.constrain_ratio = [-np.inf, np.inf]
+ elif len(constrain_ratio) == 1:
+ self.constrain_ratio = [-np.inf, constrain_ratio]
+ else:
+ self.constrain_ratio = constrain_ratio
+
+ def _trim_data(self, data: List) -> List:
+ """Trim data to extreme values.
+
+ Arguments
+ ---------
+ data
+ Data to be trimmed to extreme quantiles.
+
+ Returns
+ -------
+ List
+ Number of non-trivial entries per column and trimmed data.
+ """
+
+ if not isinstance(data, List):
+ data = [data]
+
+ data = np.array(
+ [data_mat.A if issparse(data_mat) else data_mat for data_mat in data]
+ )
+
+        # Normalize each matrix column-wise by its (clipped) maximum and sum the
+        # matrices, so extreme entries are identified jointly across all of them
+ normalized_data = np.sum(
+ data / data.max(axis=1, keepdims=True).clip(1e-3, None), axis=0
+ )
+
+ bound = np.percentile(normalized_data, self.percentile, axis=0)
+
+ if bound.ndim == 1:
+ trimmer = csr_matrix(normalized_data >= bound).astype(bool)
+ else:
+ trimmer = csr_matrix(
+ (normalized_data <= bound[0]) | (normalized_data >= bound[1])
+ ).astype(bool)
+
+ return [trimmer.getnnz(axis=0)] + [
+ trimmer.multiply(data_mat).tocsr() for data_mat in data
+ ]
+
+ def fit(self, x: ndarray, y: ndarray):
+ """Fit linear model per column.
+
+ Arguments
+ ---------
+ x
+ Training data of shape `(n_obs, n_vars)`.
+ y
+ Target values of shape `(n_obs, n_vars)`.
+
+ Returns
+ -------
+ self
+ Returns an instance of self.
+ """
+
+ n_obs = x.shape[0]
+
+ if self.percentile is not None:
+ n_obs, x, y = self._trim_data(data=[x, y])
+
+ _xx = prod_sum(x, x, axis=0)
+ _xy = prod_sum(x, y, axis=0)
+
+ if self.fit_intercept:
+ _x = sum(x, axis=0) / n_obs
+ _y = sum(y, axis=0) / n_obs
+ self.coef_ = (_xy / n_obs - _x * _y) / (_xx / n_obs - _x ** 2)
+ self.intercept_ = _y - self.coef_ * _x
+
+ if self.positive_intercept:
+ idx = self.intercept_ < 0
+ if self.coef_.ndim > 0:
+ self.coef_[idx] = _xy[idx] / _xx[idx]
+ else:
+ self.coef_ = _xy / _xx
+ self.intercept_ = np.clip(self.intercept_, 0, None)
+ else:
+ self.coef_ = _xy / _xx
+ self.intercept_ = np.zeros(x.shape[1]) if x.ndim > 1 else 0
+
+ if not np.isscalar(self.coef_):
+ self.coef_[np.isnan(self.coef_)] = 0
+ self.intercept_[np.isnan(self.intercept_)] = 0
+ else:
+ if np.isnan(self.coef_):
+ self.coef_ = 0
+ if np.isnan(self.intercept_):
+ self.intercept_ = 0
+
+ self.coef_ = np.clip(self.coef_, *self.constrain_ratio)
+
+ return self
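A minimal usage sketch of `LinearRegression` (illustrative data):

```python
import numpy as np

from scvelo.core import LinearRegression

# One column of data following y = 2 * x; fit without intercept (the default).
x = np.linspace(0, 1, 50).reshape(-1, 1)
y = 2 * x

reg = LinearRegression().fit(x, y)
print(reg.coef_)  # approximately [2.]
```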
diff --git a/scvelo/core/_metrics.py b/scvelo/core/_metrics.py
new file mode 100644
index 00000000..539e4489
--- /dev/null
+++ b/scvelo/core/_metrics.py
@@ -0,0 +1,32 @@
+from typing import Union
+
+import numpy as np
+from numpy import ndarray
+from scipy.sparse import issparse, spmatrix
+
+
+# TODO: Add case `axis == None`
+def l2_norm(x: Union[ndarray, spmatrix], axis: int = 1) -> Union[float, ndarray]:
+ """Calculate l2 norm along a given axis.
+
+ Arguments
+ ---------
+ x
+ Array to calculate l2 norm of.
+ axis
+ Axis along which to calculate l2 norm.
+
+ Returns
+ -------
+ Union[float, ndarray]
+ L2 norm along a given axis.
+ """
+
+ if issparse(x):
+ return np.sqrt(x.multiply(x).sum(axis=axis).A1)
+ elif x.ndim == 1:
+ return np.sqrt(np.einsum("i, i -> ", x, x))
+ elif axis == 0:
+ return np.sqrt(np.einsum("ij, ij -> j", x, x))
+ elif axis == 1:
+ return np.sqrt(np.einsum("ij, ij -> i", x, x))
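For example (illustrative values):

```python
import numpy as np

from scvelo.core import l2_norm

x = np.array([[3.0, 4.0], [6.0, 8.0]])
print(l2_norm(x, axis=1))  # [ 5. 10.]
```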
diff --git a/scvelo/core/_models.py b/scvelo/core/_models.py
new file mode 100644
index 00000000..c38c6688
--- /dev/null
+++ b/scvelo/core/_models.py
@@ -0,0 +1,143 @@
+from typing import Dict, List, Tuple, Union
+
+import numpy as np
+from numpy import ndarray
+
+from ._arithmetic import invert
+from ._base import DynamicsBase
+
+
+# TODO: Improve parameter names: alpha -> transcription_rate; beta -> splicing_rate;
+# gamma -> degradation_rate
+# TODO: Handle cases beta = 0, gamma == 0, beta == gamma
+class SplicingDynamics(DynamicsBase):
+ """Splicing dynamics.
+
+ Arguments
+ ---------
+ alpha
+ Transcription rate.
+    beta
+        Splicing rate.
+    gamma
+        Degradation rate.
+ initial_state
+ Initial state of system. Defaults to `[0, 0]`.
+
+ Attributes
+ ----------
+ alpha
+ Transcription rate.
+    beta
+        Splicing rate.
+    gamma
+        Degradation rate.
+ initial_state
+ Initial state of system. Defaults to `[0, 0]`.
+ u0
+ Initial abundance of unspliced RNA.
+ s0
+ Initial abundance of spliced RNA.
+
+ """
+
+ def __init__(
+ self,
+ alpha: float,
+ beta: float,
+ gamma: float,
+ initial_state: Union[List, ndarray] = [0, 0],
+ ):
+
+ self.alpha = alpha
+ self.beta = beta
+ self.gamma = gamma
+
+ self.initial_state = initial_state
+
+ @property
+ def initial_state(self):
+ return self._initial_state
+
+ @initial_state.setter
+ def initial_state(self, val):
+ if isinstance(val, list) or (isinstance(val, ndarray) and (val.ndim == 1)):
+ self.u0 = val[0]
+ self.s0 = val[1]
+ else:
+ self.u0 = val[:, 0]
+ self.s0 = val[:, 1]
+ self._initial_state = val
+
+ def get_solution(
+ self, t: ndarray, stacked: bool = True, with_keys: bool = False
+ ) -> Union[Dict, ndarray]:
+ """Calculate solution of dynamics.
+
+ Arguments
+ ---------
+ t
+ Time steps at which to evaluate solution.
+ stacked
+ Whether to stack states or return them individually. Defaults to `True`.
+ with_keys
+ Whether to return solution labelled by variables in form of a dictionary.
+ Defaults to `False`.
+
+ Returns
+ -------
+ Union[Dict, ndarray]
+ Solution of system. If `with_keys=True`, the solution is returned in form of
+ a dictionary with variables as keys. Otherwise, the solution is given as
+ a `numpy.ndarray` of form `(n_steps, 2)`.
+ """
+
+ expu = np.exp(-self.beta * t)
+ exps = np.exp(-self.gamma * t)
+
+ unspliced = self.u0 * expu + self.alpha / self.beta * (1 - expu)
+ c = (self.alpha - self.u0 * self.beta) * invert(self.gamma - self.beta)
+ spliced = (
+ self.s0 * exps + self.alpha / self.gamma * (1 - exps) + c * (exps - expu)
+ )
+
+ if with_keys:
+ return {"u": unspliced, "s": spliced}
+ elif not stacked:
+ return unspliced, spliced
+ else:
+ if isinstance(t, np.ndarray) and t.ndim == 2:
+ return np.stack([unspliced, spliced], axis=2)
+ else:
+ return np.column_stack([unspliced, spliced])
+
+ # TODO: Handle cases `beta = 0`, `gamma = 0`
+ def get_steady_states(
+ self, stacked: bool = True, with_keys: bool = False
+ ) -> Union[Dict[str, ndarray], Tuple[ndarray], ndarray]:
+ """Return steady state of system.
+
+ Arguments
+ ---------
+ stacked
+ Whether to stack states or return them individually. Defaults to `True`.
+ with_keys
+ Whether to return solution labelled by variables in form of a dictionary.
+ Defaults to `False`.
+
+ Returns
+ -------
+ Union[Dict[str, ndarray], Tuple[ndarray], ndarray]
+ Steady state of system.
+ """
+
+ if (self.beta > 0) and (self.gamma > 0):
+ unspliced = self.alpha / self.beta
+ spliced = self.alpha / self.gamma
+
+ if with_keys:
+ return {"u": unspliced, "s": spliced}
+ elif not stacked:
+ return unspliced, spliced
+ else:
+ return np.array([unspliced, spliced])
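A minimal usage sketch of `SplicingDynamics` (illustrative rate parameters):

```python
import numpy as np

from scvelo.core import SplicingDynamics

dynamics = SplicingDynamics(alpha=5.0, beta=0.5, gamma=0.2)

t = np.linspace(0, 50, 6)
solution = dynamics.get_solution(t)  # shape (6, 2): columns u(t), s(t)

# Steady states: (alpha / beta, alpha / gamma) = (10.0, 25.0)
u_inf, s_inf = dynamics.get_steady_states()
```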
diff --git a/scvelo/core/_parallelize.py b/scvelo/core/_parallelize.py
new file mode 100644
index 00000000..85b7120f
--- /dev/null
+++ b/scvelo/core/_parallelize.py
@@ -0,0 +1,169 @@
+import os
+from multiprocessing import Manager
+from threading import Thread
+from typing import Any, Callable, Optional, Sequence, Union
+
+from joblib import delayed, Parallel
+
+import numpy as np
+from scipy.sparse import issparse, spmatrix
+
+from scvelo import logging as logg
+
+_msg_shown = False
+
+
+def get_n_jobs(n_jobs):
+ if n_jobs is None or (n_jobs < 0 and os.cpu_count() + 1 + n_jobs <= 0):
+ return 1
+ elif n_jobs > os.cpu_count():
+ return os.cpu_count()
+ elif n_jobs < 0:
+ return os.cpu_count() + 1 + n_jobs
+ else:
+ return n_jobs
+
+
+def parallelize(
+ callback: Callable[[Any], Any],
+ collection: Union[spmatrix, Sequence[Any]],
+ n_jobs: Optional[int] = None,
+ n_split: Optional[int] = None,
+ unit: str = "",
+ as_array: bool = True,
+ use_ixs: bool = False,
+ backend: str = "loky",
+ extractor: Optional[Callable[[Any], Any]] = None,
+ show_progress_bar: bool = True,
+) -> Union[np.ndarray, Any]:
+ """
+ Parallelize function call over a collection of elements.
+
+ Parameters
+ ----------
+ callback
+ Function to parallelize.
+ collection
+ Sequence of items which to chunkify.
+ n_jobs
+ Number of parallel jobs.
+ n_split
+ Split :paramref:`collection` into :paramref:`n_split` chunks.
+ If `None`, split into :paramref:`n_jobs` chunks.
+ unit
+ Unit of the progress bar.
+ as_array
+        Whether to convert the results to :class:`numpy.ndarray`.
+ use_ixs
+ Whether to pass indices to the callback.
+ backend
+ Which backend to use for multiprocessing. See :class:`joblib.Parallel` for valid
+ options.
+ extractor
+ Function to apply to the result after all jobs have finished.
+ show_progress_bar
+ Whether to show a progress bar.
+
+ Returns
+ -------
+ :class:`numpy.ndarray`
+ Result depending on :paramref:`extractor` and :paramref:`as_array`.
+ """
+
+ if show_progress_bar:
+ try:
+ try:
+ from tqdm.notebook import tqdm
+ except ImportError:
+ from tqdm import tqdm_notebook as tqdm
+ import ipywidgets # noqa
+ except ImportError:
+ global _msg_shown
+ tqdm = None
+
+ if not _msg_shown:
+ logg.warn(
+ "Unable to create progress bar. "
+ "Consider installing `tqdm` as `pip install tqdm` "
+ "and `ipywidgets` as `pip install ipywidgets`,\n"
+ "or disable the progress bar using `show_progress_bar=False`."
+ )
+ _msg_shown = True
+ else:
+ tqdm = None
+
+ def update(pbar, queue, n_total):
+ n_finished = 0
+ while n_finished < n_total:
+ try:
+ res = queue.get()
+ except EOFError as e:
+                if n_finished != n_total:
+                    raise RuntimeError(
+                        f"Finished only `{n_finished}` out of `{n_total}` tasks."
+                    ) from e
+ break
+            assert res in (None, (1, None), 1)  # (1, None) means one finished task
+ if res == (1, None):
+ n_finished += 1
+ if pbar is not None:
+ pbar.update()
+ elif res is None:
+ n_finished += 1
+ elif pbar is not None:
+ pbar.update()
+
+ if pbar is not None:
+ pbar.close()
+
+ def wrapper(*args, **kwargs):
+ if pass_queue and show_progress_bar:
+ pbar = None if tqdm is None else tqdm(total=col_len, unit=unit)
+ queue = Manager().Queue()
+ thread = Thread(target=update, args=(pbar, queue, len(collections)))
+ thread.start()
+ else:
+ pbar, queue, thread = None, None, None
+
+ res = Parallel(n_jobs=n_jobs, backend=backend)(
+ delayed(callback)(
+ *((i, cs) if use_ixs else (cs,)),
+ *args,
+ **kwargs,
+ queue=queue,
+ )
+ for i, cs in enumerate(collections)
+ )
+
+ res = np.array(res) if as_array else res
+ if thread is not None:
+ thread.join()
+
+ return res if extractor is None else extractor(res)
+
+ col_len = collection.shape[0] if issparse(collection) else len(collection)
+
+ if n_split is None:
+ n_split = get_n_jobs(n_jobs=n_jobs)
+
+ if issparse(collection):
+ if n_split == collection.shape[0]:
+ collections = [collection[[ix], :] for ix in range(collection.shape[0])]
+ else:
+ step = collection.shape[0] // n_split
+
+ ixs = [
+ np.arange(i * step, min((i + 1) * step, collection.shape[0]))
+ for i in range(n_split)
+ ]
+ ixs[-1] = np.append(
+ ixs[-1], np.arange(ixs[-1][-1] + 1, collection.shape[0])
+ )
+
+ collections = [collection[ix, :] for ix in filter(len, ixs)]
+ else:
+ collections = list(filter(len, np.array_split(collection, n_split)))
+
+ pass_queue = not hasattr(callback, "py_func") # we'd be inside a numba function
+
+ return wrapper
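A minimal sketch of how the returned wrapper is used, with a hypothetical `_square_sum` callback; the progress convention (put `1` per processed item and `None` once a chunk is done) mirrors the `update` loop above:

```python
import numpy as np

from scvelo.core._parallelize import parallelize


def _square_sum(chunk, queue=None):
    """Hypothetical callback: per-row sum of squares for one chunk."""
    out = []
    for row in chunk:
        out.append((row ** 2).sum())
        if queue is not None:
            queue.put(1)  # advance the progress bar by one item
    if queue is not None:
        queue.put(None)  # signal that this chunk has finished
    return out


data = np.random.rand(1000, 10)
result = parallelize(
    _square_sum,
    data,
    n_jobs=2,  # like joblib, negative values count back from os.cpu_count()
    unit="rows",
    as_array=False,
    extractor=np.concatenate,  # flatten the per-chunk lists into one array
)()
assert result.shape == (1000,)
```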
diff --git a/scvelo/core/tests/__init__.py b/scvelo/core/tests/__init__.py
new file mode 100644
index 00000000..6d4c3dfd
--- /dev/null
+++ b/scvelo/core/tests/__init__.py
@@ -0,0 +1,3 @@
+from .test_base import get_adata, TestBase
+
+__all__ = ["get_adata", "TestBase"]
diff --git a/scvelo/core/tests/test_anndata.py b/scvelo/core/tests/test_anndata.py
new file mode 100644
index 00000000..b73cc941
--- /dev/null
+++ b/scvelo/core/tests/test_anndata.py
@@ -0,0 +1,136 @@
+import hypothesis.strategies as st
+from hypothesis import given
+
+import numpy as np
+from numpy.testing import assert_array_equal
+from scipy.sparse import issparse
+
+from anndata import AnnData
+
+from scvelo.core import get_modality, make_dense, make_sparse, set_modality
+from .test_base import get_adata, TestBase
+
+
+class TestGetModality(TestBase):
+ @given(adata=get_adata())
+ def test_get_modality(self, adata: AnnData):
+ modality_to_get = self._subset_modalities(adata, 1)[0]
+ modality_retrieved = get_modality(adata=adata, modality=modality_to_get)
+
+ if modality_to_get == "X":
+ assert_array_equal(adata.X, modality_retrieved)
+ elif modality_to_get in adata.layers:
+ assert_array_equal(adata.layers[modality_to_get], modality_retrieved)
+ else:
+ assert_array_equal(adata.obsm[modality_to_get], modality_retrieved)
+
+
+class TestMakeDense(TestBase):
+ @given(
+ adata=get_adata(sparse_entries=True),
+ inplace=st.booleans(),
+ n_modalities=st.integers(min_value=0),
+ )
+ def test_make_dense(self, adata: AnnData, inplace: bool, n_modalities: int):
+ modalities_to_densify = self._subset_modalities(adata, n_modalities)
+
+ returned_adata = make_dense(
+ adata=adata, modalities=modalities_to_densify, inplace=inplace
+ )
+
+ if inplace:
+ assert returned_adata is None
+ assert np.all(
+ [
+ not issparse(get_modality(adata=adata, modality=modality))
+ for modality in modalities_to_densify
+ ]
+ )
+ else:
+ assert isinstance(returned_adata, AnnData)
+ assert np.all(
+ [
+ not issparse(get_modality(adata=returned_adata, modality=modality))
+ for modality in modalities_to_densify
+ ]
+ )
+ assert np.all(
+ [
+ issparse(get_modality(adata=adata, modality=modality))
+ for modality in modalities_to_densify
+ ]
+ )
+
+
+class TestMakeSparse(TestBase):
+ @given(
+ adata=get_adata(),
+ inplace=st.booleans(),
+ n_modalities=st.integers(min_value=0),
+ )
+ def test_make_sparse(self, adata: AnnData, inplace: bool, n_modalities: int):
+ modalities_to_make_sparse = self._subset_modalities(adata, n_modalities)
+
+ returned_adata = make_sparse(
+ adata=adata, modalities=modalities_to_make_sparse, inplace=inplace
+ )
+
+ if inplace:
+ assert returned_adata is None
+ assert np.all(
+ [
+ issparse(get_modality(adata=adata, modality=modality))
+ for modality in modalities_to_make_sparse
+ if modality != "X"
+ ]
+ )
+ else:
+ assert isinstance(returned_adata, AnnData)
+ assert np.all(
+ [
+ issparse(get_modality(adata=returned_adata, modality=modality))
+ for modality in modalities_to_make_sparse
+ if modality != "X"
+ ]
+ )
+ assert np.all(
+ [
+ not issparse(get_modality(adata=adata, modality=modality))
+ for modality in modalities_to_make_sparse
+ if modality != "X"
+ ]
+ )
+
+
+class TestSetModality(TestBase):
+ @given(adata=get_adata(), inplace=st.booleans())
+ def test_set_modality(self, adata: AnnData, inplace: bool):
+ modality_to_set = self._subset_modalities(adata, 1)[0]
+
+ if (modality_to_set == "X") or (modality_to_set in adata.layers):
+ new_value = np.random.randn(adata.n_obs, adata.n_vars)
+ else:
+ new_value = np.random.randn(
+ adata.n_obs, np.random.randint(low=1, high=10000)
+ )
+
+ returned_adata = set_modality(
+ adata=adata, new_value=new_value, modality=modality_to_set, inplace=inplace
+ )
+
+ if inplace:
+ assert returned_adata is None
+ if modality_to_set == "X":
+ assert_array_equal(adata.X, new_value)
+ elif modality_to_set in adata.layers:
+ assert_array_equal(adata.layers[modality_to_set], new_value)
+ else:
+ assert_array_equal(adata.obsm[modality_to_set], new_value)
+ else:
+ assert isinstance(returned_adata, AnnData)
+ if modality_to_set == "X":
+ assert_array_equal(returned_adata.X, new_value)
+ elif modality_to_set in adata.layers:
+ assert_array_equal(returned_adata.layers[modality_to_set], new_value)
+ else:
+ assert_array_equal(returned_adata.obsm[modality_to_set], new_value)
diff --git a/scvelo/core/tests/test_arithmetic.py b/scvelo/core/tests/test_arithmetic.py
new file mode 100644
index 00000000..d8f12092
--- /dev/null
+++ b/scvelo/core/tests/test_arithmetic.py
@@ -0,0 +1,197 @@
+from typing import List
+
+from hypothesis import given
+from hypothesis import strategies as st
+from hypothesis.extra.numpy import arrays
+
+import numpy as np
+from numpy import ndarray
+from numpy.testing import assert_almost_equal, assert_array_equal
+
+from scvelo.core import clipped_log, invert, prod_sum, sum
+
+
+class TestClippedLog:
+ @given(
+ a=arrays(
+ float,
+ shape=st.integers(min_value=1, max_value=100),
+ elements=st.floats(
+ min_value=-1e3, max_value=1e3, allow_infinity=False, allow_nan=False
+ ),
+ ),
+ bounds=st.lists(
+ st.floats(
+ min_value=0, max_value=100, allow_infinity=False, allow_nan=False
+ ),
+ min_size=2,
+ max_size=2,
+ unique=True,
+ ),
+ eps=st.floats(
+ min_value=1e-6, max_value=1, allow_infinity=False, allow_nan=False
+ ),
+ )
+ def test_flat_arrays(self, a: ndarray, bounds: List[float], eps: float):
+ lb = min(bounds)
+ ub = max(bounds) + 2 * eps
+
+ a_logged = clipped_log(a, lb=lb, ub=ub, eps=eps)
+
+ assert a_logged.shape == a.shape
+ if (a <= lb).any():
+ assert_almost_equal(np.abs(a_logged - np.log(lb + eps)).min(), 0)
+ else:
+ assert (a_logged >= np.log(lb + eps)).all()
+ if (a >= ub).any():
+ assert_almost_equal(np.abs(a_logged - np.log(ub - eps)).min(), 0)
+ else:
+ assert (a_logged <= np.log(ub - eps)).all()
+
+ @given(
+ a=arrays(
+ float,
+ shape=st.tuples(
+ st.integers(min_value=1, max_value=100),
+ st.integers(min_value=1, max_value=100),
+ ),
+ elements=st.floats(
+ min_value=-1e3, max_value=1e3, allow_infinity=False, allow_nan=False
+ ),
+ ),
+ bounds=st.lists(
+ st.floats(
+ min_value=0, max_value=100, allow_infinity=False, allow_nan=False
+ ),
+ min_size=2,
+ max_size=2,
+ unique=True,
+ ),
+ eps=st.floats(
+ min_value=1e-6, max_value=1, allow_infinity=False, allow_nan=False
+ ),
+ )
+ def test_2d_arrays(self, a: ndarray, bounds: List[float], eps: float):
+ lb = min(bounds)
+ ub = max(bounds) + 2 * eps
+
+ a_logged = clipped_log(a, lb=lb, ub=ub, eps=eps)
+
+ assert a_logged.shape == a.shape
+ if (a <= lb).any():
+ assert_almost_equal(np.abs(a_logged - np.log(lb + eps)).min(), 0)
+ else:
+ assert (a_logged >= np.log(lb + eps)).all()
+ if (a >= ub).any():
+ assert_almost_equal(np.abs(a_logged - np.log(ub - eps)).min(), 0)
+ else:
+ assert (a_logged <= np.log(ub - eps)).all()
+
+
+class TestInvert:
+ @given(
+ a=arrays(
+ float,
+ shape=st.integers(min_value=1, max_value=100),
+ elements=st.floats(max_value=1e3, allow_infinity=False, allow_nan=False),
+ )
+ )
+ def test_flat_arrays(self, a: ndarray):
+ a_inv = invert(a)
+
+ if a[a != 0].size == 0:
+ assert a_inv[a != 0].size == 0
+ else:
+ assert_array_equal(a_inv[a != 0], 1 / a[a != 0])
+
+ if 0 in a:
+ assert np.isnan(a_inv[a == 0]).all()
+ else:
+ assert set(a_inv[a == 0]) == set()
+
+ @given(
+ a=arrays(
+ float,
+ shape=st.tuples(
+ st.integers(min_value=1, max_value=100),
+ st.integers(min_value=1, max_value=100),
+ ),
+ elements=st.floats(max_value=1e3, allow_infinity=False, allow_nan=False),
+ )
+ )
+ def test_2d_arrays(self, a: ndarray):
+ a_inv = invert(a)
+
+ if a[a != 0].size == 0:
+ assert a_inv[a != 0].size == 0
+ else:
+ assert_array_equal(a_inv[a != 0], 1 / a[a != 0])
+
+ if 0 in a:
+ assert np.isnan(a_inv[a == 0]).all()
+ else:
+ assert set(a_inv[a == 0]) == set()
+
+
+# TODO: Extend test to generate sparse inputs as well
+# TODO: Make test to generate two different arrays a1, a2
+# TODO: Check why tests fail with assert_almost_equal
+class TestProdSum:
+ @given(
+ a=arrays(
+ float,
+ shape=st.integers(min_value=1, max_value=100),
+ elements=st.floats(max_value=1e3, allow_infinity=False, allow_nan=False),
+ ),
+ axis=st.integers(min_value=0, max_value=1),
+ )
+ def test_flat_array(self, a: ndarray, axis: int):
+ assert np.allclose((a * a).sum(axis=0), prod_sum(a, a, axis=axis))
+
+ @given(
+ a=arrays(
+ float,
+ shape=st.tuples(
+ st.integers(min_value=1, max_value=100),
+ st.integers(min_value=1, max_value=100),
+ ),
+ elements=st.floats(max_value=1e3, allow_infinity=False, allow_nan=False),
+ ),
+ axis=st.integers(min_value=0, max_value=1),
+ )
+ def test_2d_array(self, a: ndarray, axis: int):
+ assert np.allclose((a * a).sum(axis=axis), prod_sum(a, a, axis=axis))
+
+
+# TODO: Extend test to generate sparse inputs as well
+class TestSum:
+ @given(
+ a=arrays(
+ float,
+ shape=st.integers(min_value=1, max_value=100),
+ elements=st.floats(max_value=1e3, allow_infinity=False, allow_nan=False),
+ ),
+ )
+ def test_flat_arrays(self, a: ndarray):
+ a_summed = sum(a=a, axis=0)
+
+ assert_array_equal(a_summed, a.sum(axis=0))
+
+ @given(
+ a=arrays(
+ float,
+ shape=st.tuples(
+ st.integers(min_value=1, max_value=100),
+ st.integers(min_value=1, max_value=100),
+ ),
+ elements=st.floats(max_value=1e3, allow_infinity=False, allow_nan=False),
+ ),
+ axis=st.integers(min_value=0, max_value=1),
+ )
+ def test_2d_arrays(self, a: ndarray, axis: int):
+ a_summed = sum(a=a, axis=axis)
+
+ if a.ndim == 1:
+ axis = 0
+
+ assert_array_equal(a_summed, a.sum(axis=axis))
diff --git a/scvelo/core/tests/test_base.py b/scvelo/core/tests/test_base.py
new file mode 100644
index 00000000..13095a52
--- /dev/null
+++ b/scvelo/core/tests/test_base.py
@@ -0,0 +1,196 @@
+import random
+from typing import List, Optional, Union
+
+import hypothesis.strategies as st
+from hypothesis import given
+from hypothesis.extra.numpy import arrays
+
+import numpy as np
+from scipy.sparse import csr_matrix, issparse
+
+from anndata import AnnData
+
+
+# TODO: Add possibility to generate adata object with floats as counts
+@st.composite
+def get_adata(
+ draw,
+ n_obs: Optional[int] = None,
+ n_vars: Optional[int] = None,
+ min_obs: Optional[int] = 1,
+ max_obs: Optional[int] = 100,
+ min_vars: Optional[int] = 1,
+ max_vars: Optional[int] = 100,
+ layer_keys: Optional[Union[List, str]] = None,
+ min_layers: Optional[int] = 2,
+ max_layers: int = 2,
+ obsm_keys: Optional[Union[List, str]] = None,
+ min_obsm: Optional[int] = 2,
+ max_obsm: Optional[int] = 2,
+ sparse_entries: bool = False,
+) -> AnnData:
+ """Generate an AnnData object.
+
+    Numerical entries are drawn as integers in `[0, 1e2]`.
+
+ Arguments
+ ---------
+ n_obs:
+ Number of observations. If set to `None`, a random integer between `1` and
+ `max_obs` will be drawn. Defaults to `None`.
+ n_vars:
+ Number of variables. If set to `None`, a random integer between `1` and
+ `max_vars` will be drawn. Defaults to `None`.
+ min_obs:
+ Minimum number of observations. If set to `None`, there is no lower limit.
+ Defaults to `1`.
+ max_obs:
+ Maximum number of observations. If set to `None`, there is no upper limit.
+ Defaults to `100`.
+ min_vars:
+ Minimum number of variables. If set to `None`, there is no lower limit.
+ Defaults to `1`.
+ max_vars:
+ Maximum number of variables. If set to `None`, there is no upper limit.
+ Defaults to `100`.
+ layer_keys:
+ Names of layers. If set to `None`, layers will be named at random. Defaults
+ to `None`.
+ min_layers:
+ Minimum number of layers. Is set to the number of provided layer names if
+ `layer_keys` is not `None`. Defaults to `2`.
+    max_layers:
+        Maximum number of layers. Is set to the number of provided layer names if
+        `layer_keys` is not `None`. Defaults to `2`.
+ obsm_keys:
+ Names of multi-dimensional observations annotation. If set to `None`, names
+ will be generated at random. Defaults to `None`.
+ min_obsm:
+ Minimum number of multi-dimensional observations annotation. Is set to the
+ number of keys if `obsm_keys` is not `None`. Defaults to `2`.
+ max_obsm:
+ Maximum number of multi-dimensional observations annotation. Is set to the
+ number of keys if `obsm_keys` is not `None`. Defaults to `2`.
+ sparse_entries:
+ Whether or not to make AnnData entries sparse.
+
+ Returns
+ -------
+ AnnData
+ Generated :class:`~anndata.AnnData` object.
+ """
+
+ if n_obs is None:
+ n_obs = draw(st.integers(min_value=min_obs, max_value=max_obs))
+ if n_vars is None:
+ n_vars = draw(st.integers(min_value=min_vars, max_value=max_vars))
+
+ if isinstance(layer_keys, str):
+ layer_keys = [layer_keys]
+ if isinstance(obsm_keys, str):
+ obsm_keys = [obsm_keys]
+
+ if layer_keys is not None:
+ min_layers = len(layer_keys)
+ max_layers = len(layer_keys)
+ if obsm_keys is not None:
+ min_obsm = len(obsm_keys)
+ max_obsm = len(obsm_keys)
+
+ X = draw(
+ arrays(
+ dtype=int,
+ elements=st.integers(min_value=0, max_value=1e2),
+ shape=(n_obs, n_vars),
+ )
+ )
+
+ layers = draw(
+ st.dictionaries(
+ st.text(min_size=1) if layer_keys is None else st.sampled_from(layer_keys),
+ arrays(
+ dtype=int,
+ elements=st.integers(min_value=0, max_value=1e2),
+ shape=(n_obs, n_vars),
+ ),
+ min_size=min_layers,
+ max_size=max_layers,
+ )
+ )
+
+ obsm = draw(
+ st.dictionaries(
+ st.text(min_size=1) if obsm_keys is None else st.sampled_from(obsm_keys),
+ arrays(
+ dtype=int,
+ elements=st.integers(min_value=0, max_value=1e2),
+ shape=st.tuples(
+ st.integers(min_value=n_obs, max_value=n_obs),
+ st.integers(min_value=min_vars, max_value=max_vars),
+ ),
+ ),
+ min_size=min_obsm,
+ max_size=max_obsm,
+ )
+ )
+
+ # Make keys for layers and obsm unique
+ for key in set(layers.keys()).intersection(obsm.keys()):
+ layers[f"{key}_"] = layers.pop(key)
+
+ if sparse_entries:
+ layers = {key: csr_matrix(val) for key, val in layers.items()}
+ obsm = {key: csr_matrix(val) for key, val in obsm.items()}
+ return AnnData(X=csr_matrix(X), layers=layers, obsm=obsm)
+ else:
+ return AnnData(X=X, layers=layers, obsm=obsm)
+
+
+class TestAdataGeneration:
+ @given(adata=get_adata())
+ def test_default_adata_generation(self, adata: AnnData):
+ assert type(adata) is AnnData
+
+ @given(adata=get_adata(sparse_entries=True))
+ def test_sparse_adata_generation(self, adata: AnnData):
+ assert type(adata) is AnnData
+ assert issparse(adata.X)
+ assert np.all([issparse(adata.layers[layer]) for layer in adata.layers])
+ assert np.all([issparse(adata.obsm[name]) for name in adata.obsm])
+
+ @given(
+ adata=get_adata(
+ n_obs=2, n_vars=2, layer_keys=["unspliced", "spliced"], obsm_keys="X_umap"
+ )
+ )
+ def test_custom_adata_generation(self, adata: AnnData):
+ assert adata.X.shape == (2, 2)
+ assert len(adata.layers) == 2
+ assert len(adata.obsm) == 1
+ assert set(adata.layers.keys()) == {"unspliced", "spliced"}
+ assert set(adata.obsm.keys()) == {"X_umap"}
+
+
+class TestBase:
+ def _subset_modalities(
+ self,
+ adata: AnnData,
+ n_modalities: int,
+ from_layers: bool = True,
+ from_obsm: bool = True,
+ ):
+ """Subset modalities of an AnnData object."""
+
+ modalities = ["X"]
+ if from_layers:
+ modalities += list(adata.layers.keys())
+ if from_obsm:
+ modalities += list(adata.obsm.keys())
+ return random.sample(modalities, min(len(modalities), n_modalities))
+
+ def _convert_to_float(self, adata: AnnData):
+ """Convert AnnData entries in `layer` and `obsm` into floats."""
+
+ for layer in adata.layers:
+ adata.layers[layer] = adata.layers[layer].astype(float)
+ for obs in adata.obsm:
+ adata.obsm[obs] = adata.obsm[obs].astype(float)
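Since `get_adata` is an `st.composite` strategy, it is normally drawn inside `@given` tests as in the classes above; for interactive inspection, a single draw can be pulled with `.example()` (which Hypothesis discourages inside tests, so this is for debugging only):

```python
from scvelo.core.tests import get_adata

# One random AnnData with fixed dimensions and named layers / obsm keys.
adata = get_adata(
    n_obs=5, n_vars=3, layer_keys=["unspliced", "spliced"], obsm_keys="X_umap"
).example()
print(adata)
```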
diff --git a/scvelo/core/tests/test_linear_models.py b/scvelo/core/tests/test_linear_models.py
new file mode 100644
index 00000000..49dcca65
--- /dev/null
+++ b/scvelo/core/tests/test_linear_models.py
@@ -0,0 +1,86 @@
+import pytest
+from hypothesis import given
+from hypothesis import strategies as st
+from hypothesis.extra.numpy import arrays
+
+import numpy as np
+from numpy import ndarray
+from numpy.testing import assert_almost_equal, assert_array_equal
+
+from scvelo.core import LinearRegression
+
+
+class TestLinearRegression:
+ @given(
+ x=arrays(
+ float,
+ shape=st.integers(min_value=1, max_value=100),
+ elements=st.floats(
+ min_value=-1e3, max_value=1e3, allow_infinity=False, allow_nan=False
+ ),
+ ),
+ coef=st.floats(
+ min_value=-1000, max_value=1000, allow_infinity=False, allow_nan=False
+ ),
+ )
+ def test_perfect_fit(self, x: ndarray, coef: float):
+ lr = LinearRegression()
+ lr.fit(x, x * coef)
+
+ assert lr.intercept_ == 0
+ if set(x) != {0}: # fit is only unique if x is non-trivial
+ assert_almost_equal(lr.coef_, coef)
+
+ @given(
+ x=arrays(
+ float,
+ shape=st.tuples(
+ st.integers(min_value=1, max_value=100),
+ st.integers(min_value=1, max_value=100),
+ ),
+ elements=st.floats(
+ min_value=-1e3, max_value=1e3, allow_infinity=False, allow_nan=False
+ ),
+ ),
+ coef=arrays(
+ float,
+ shape=100,
+ elements=st.floats(
+ min_value=-1000, max_value=1000, allow_infinity=False, allow_nan=False
+ ),
+ ),
+ )
+ # TODO: Extend test to use `percentile`. Zero columns (after trimming) make the
+ # previous implementation of the unit test fail
+ # TODO: Check why test fails if number of columns is increased to e.g. 1000 (500)
+ def test_perfect_fit_2d(self, x: ndarray, coef: ndarray):
+ coef = coef[: x.shape[1]]
+ lr = LinearRegression()
+ lr.fit(x, x * coef)
+
+ assert lr.coef_.shape == (x.shape[1],)
+ assert lr.intercept_.shape == (x.shape[1],)
+ assert_array_equal(lr.intercept_, np.zeros(x.shape[1]))
+ if set(x.flatten()) != {0}: # fit is only unique if x is non-trivial
+ assert_almost_equal(lr.coef_, coef)
+
+ # TODO: Use hypothesis
+ # TODO: Integrate into `test_perfect_fit_2d`
+ @pytest.mark.parametrize(
+ "x, coef, intercept",
+ [
+ (np.array([[0], [1], [2], [3]]), 0, 1),
+ (np.array([[0], [1], [2], [3]]), 2, 1),
+ (np.array([[0], [1], [2], [3]]), 2, -1),
+ ],
+ )
+ def test_perfect_fit_with_intercept(
+ self, x: ndarray, coef: float, intercept: float
+ ):
+ lr = LinearRegression(fit_intercept=True, positive_intercept=False)
+ lr.fit(x, x * coef + intercept)
+
+ assert lr.coef_.shape == (x.shape[1],)
+ assert lr.intercept_.shape == (x.shape[1],)
+ assert_array_equal(lr.intercept_, intercept)
+ assert_array_equal(lr.coef_, coef)
diff --git a/scvelo/core/tests/test_metrics.py b/scvelo/core/tests/test_metrics.py
new file mode 100644
index 00000000..81f54070
--- /dev/null
+++ b/scvelo/core/tests/test_metrics.py
@@ -0,0 +1,24 @@
+from hypothesis import given
+from hypothesis import strategies as st
+from hypothesis.extra.numpy import arrays
+
+import numpy as np
+from numpy import ndarray
+
+from scvelo.core import l2_norm
+
+
+# TODO: Extend test to generate sparse inputs as well
+@given(
+ a=arrays(
+ float,
+ shape=st.integers(min_value=1, max_value=100),
+ elements=st.floats(max_value=1e3, allow_infinity=False, allow_nan=False),
+ ),
+ axis=st.integers(min_value=0, max_value=1),
+)
+def test_l2_norm(a: ndarray, axis: int):
+ if a.ndim == 1:
+        assert np.allclose(np.linalg.norm(a), l2_norm(a, axis=axis))
+    else:
+        assert np.allclose(np.linalg.norm(a, axis=axis), l2_norm(a, axis=axis))
diff --git a/scvelo/core/tests/test_models.py b/scvelo/core/tests/test_models.py
new file mode 100644
index 00000000..bca814fb
--- /dev/null
+++ b/scvelo/core/tests/test_models.py
@@ -0,0 +1,90 @@
+from typing import List
+
+import pytest
+from hypothesis import given
+from hypothesis import strategies as st
+from hypothesis.extra.numpy import arrays
+
+import numpy as np
+from numpy import ndarray
+from scipy.integrate import odeint
+
+from scvelo.core import SplicingDynamics
+
+
+class TestSplicingDynamics:
+ @given(
+ alpha=st.floats(min_value=0, allow_infinity=False),
+ beta=st.floats(min_value=0, max_value=1, exclude_min=True),
+ gamma=st.floats(min_value=0, max_value=1, exclude_min=True),
+ initial_state=st.lists(
+ st.floats(min_value=0, allow_infinity=False), min_size=2, max_size=2
+ ),
+ t=arrays(
+ float,
+ shape=st.integers(min_value=1, max_value=100),
+ elements=st.floats(
+ min_value=0, max_value=1e3, allow_infinity=False, allow_nan=False
+ ),
+ ),
+ with_keys=st.booleans(),
+ )
+ def test_output_form(
+ self,
+ alpha: float,
+ beta: float,
+ gamma: float,
+ initial_state: List[float],
+ t: ndarray,
+ with_keys: bool,
+ ):
+ if beta == gamma:
+ gamma = gamma + 1e-6
+
+ splicing_dynamics = SplicingDynamics(
+ alpha=alpha, beta=beta, gamma=gamma, initial_state=initial_state
+ )
+ solution = splicing_dynamics.get_solution(t=t, with_keys=with_keys)
+
+ if not with_keys:
+ assert type(solution) == ndarray
+ assert solution.shape == (len(t), 2)
+ else:
+ assert len(solution) == 2
+ assert type(solution) == dict
+ assert list(solution.keys()) == ["u", "s"]
+ assert all([len(var) == len(t) for var in solution.values()])
+
+ # TODO: Check how / if hypothesis can be used instead.
+ @pytest.mark.parametrize(
+ "alpha, beta, gamma, initial_state",
+ [
+ (5, 0.5, 0.4, [0, 1]),
+ ],
+ )
+ def test_solution(self, alpha, beta, gamma, initial_state):
+ def model(y, t, alpha, beta, gamma):
+ dydt = np.zeros(2)
+ dydt[0] = alpha - beta * y[0]
+ dydt[1] = beta * y[0] - gamma * y[1]
+
+ return dydt
+
+ t = np.linspace(0, 20, 10000)
+ splicing_dynamics = SplicingDynamics(
+ alpha=alpha, beta=beta, gamma=gamma, initial_state=initial_state
+ )
+ exact_solution = splicing_dynamics.get_solution(t=t)
+
+ numerical_solution = odeint(
+ model,
+ np.array(initial_state),
+ t,
+ args=(
+ alpha,
+ beta,
+ gamma,
+ ),
+ )
+
+ assert np.allclose(numerical_solution, exact_solution)
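For reference, the closed-form expressions implemented by `SplicingDynamics.get_solution` and verified here against `odeint` (valid for beta != gamma):

```latex
% ODE system with u(0) = u_0, s(0) = s_0:
\frac{du}{dt} = \alpha - \beta u, \qquad \frac{ds}{dt} = \beta u - \gamma s

% Closed-form solution (for \beta \neq \gamma):
u(t) = u_0 e^{-\beta t} + \frac{\alpha}{\beta}\bigl(1 - e^{-\beta t}\bigr)

s(t) = s_0 e^{-\gamma t} + \frac{\alpha}{\gamma}\bigl(1 - e^{-\gamma t}\bigr)
     + \frac{\alpha - \beta u_0}{\gamma - \beta}\bigl(e^{-\gamma t} - e^{-\beta t}\bigr)
```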
diff --git a/scvelo/datasets.py b/scvelo/datasets.py
index 5dc265fb..484d4bd0 100644
--- a/scvelo/datasets.py
+++ b/scvelo/datasets.py
@@ -1,11 +1,14 @@
"""Builtin Datasets.
"""
-from .read_load import read, load
-from .preprocessing.utils import cleanup
-from anndata import AnnData
+
import numpy as np
import pandas as pd
+from anndata import AnnData
+from scanpy import read
+
+from scvelo.core import cleanup, SplicingDynamics
+from .read_load import load
url_datadir = "https://github.com/theislab/scvelo_notebooks/raw/master/"
@@ -24,8 +27,6 @@ def toy_data(n_obs=None):
Returns `adata` object
"""
- """Random samples from Dentate Gyrus.
- """
adata_dg = dentategyrus()
if n_obs is not None:
@@ -40,7 +41,7 @@ def toy_data(n_obs=None):
def dentategyrus(adjusted=True):
"""Dentate Gyrus neurogenesis.
- Data from `Hochgerner et al. (2018) `_.
+ Data from `Hochgerner et al. (2018) `_.
Dentate gyrus (DG) is part of the hippocampus involved in learning, episodic memory
formation and spatial coding. The experiment from the developing DG comprises two
@@ -58,7 +59,7 @@ def dentategyrus(adjusted=True):
Returns
-------
Returns `adata` object
- """
+ """ # noqa E501
if adjusted:
filename = "data/DentateGyrus/10X43_1.h5ad"
@@ -89,12 +90,16 @@ def dentategyrus(adjusted=True):
def forebrain():
"""Developing human forebrain.
- Forebrain tissue of a week 10 embryo, focusing on glutamatergic neuronal lineage.
+ From `La Manno et al. (2018) `_.
+
+ Forebrain tissue of a human week 10 embryo, focusing on glutamatergic neuronal
+ lineage, obtained from elective routine abortions (10 weeks post-conception).
Returns
-------
Returns `adata` object
- """
+ """ # noqa E501
+
filename = "data/ForebrainGlut/hgForebrainGlut.loom"
url = "http://pklab.med.harvard.edu/velocyto/hgForebrainGlut/hgForebrainGlut.loom"
adata = read(filename, backup_url=url, cleanup=True, sparse=True, cache=True)
@@ -103,9 +108,9 @@ def forebrain():
def pancreas():
- """Pancreatic endocrinogenesis
+ """Pancreatic endocrinogenesis.
- Data from `Bastidas-Ponce et al. (2019) `_.
+ Data from `Bastidas-Ponce et al. (2019) `_.
Pancreatic epithelial and Ngn3-Venus fusion (NVF) cells during secondary transition
with transcriptome profiles sampled from embryonic day 15.5.
@@ -121,7 +126,8 @@ def pancreas():
Returns
-------
Returns `adata` object
- """
+ """ # noqa E501
+
filename = "data/Pancreas/endocrinogenesis_day15.h5ad"
url = f"{url_datadir}data/Pancreas/endocrinogenesis_day15.h5ad"
adata = read(filename, backup_url=url, sparse=True, cache=True)
@@ -132,6 +138,173 @@ def pancreas():
pancreatic_endocrinogenesis = pancreas # restore old conventions
+def dentategyrus_lamanno():
+ """Dentate Gyrus neurogenesis.
+
+ From `La Manno et al. (2018) `_.
+
+ The experiment from the developing mouse hippocampus comprises two time points
+ (P0 and P5) and reveals the complex manifold with multiple branching lineages
+ towards astrocytes, oligodendrocyte precursors (OPCs), granule neurons and pyramidal neurons.
+
+ .. image:: https://user-images.githubusercontent.com/31883718/118401264-49bce380-b665-11eb-8678-e7570ede13d6.png
+ :width: 600px
+
+ Returns
+ -------
+ Returns `adata` object
+ """ # noqa E501
+ filename = "data/DentateGyrus/DentateGyrus.loom"
+ url = "http://pklab.med.harvard.edu/velocyto/DentateGyrus/DentateGyrus.loom"
+ adata = read(filename, backup_url=url, sparse=True, cache=True)
+ adata.var_names_make_unique()
+ adata.obsm["X_tsne"] = np.column_stack([adata.obs["TSNE1"], adata.obs["TSNE2"]])
+ adata.obs["clusters"] = adata.obs["ClusterName"]
+ cleanup(adata, clean="obs", keep=["Age", "clusters"])
+
+ adata.uns["clusters_colors"] = {
+ "RadialGlia": [0.95, 0.6, 0.1],
+ "RadialGlia2": [0.85, 0.3, 0.1],
+ "ImmAstro": [0.8, 0.02, 0.1],
+ "GlialProg": [0.81, 0.43, 0.72352941],
+ "OPC": [0.61, 0.13, 0.72352941],
+ "nIPC": [0.9, 0.8, 0.3],
+ "Nbl1": [0.7, 0.82, 0.6],
+ "Nbl2": [0.448, 0.85490196, 0.95098039],
+ "ImmGranule1": [0.35, 0.4, 0.82],
+ "ImmGranule2": [0.23, 0.3, 0.7],
+ "Granule": [0.05, 0.11, 0.51],
+ "CA": [0.2, 0.53, 0.71],
+ "CA1-Sub": [0.1, 0.45, 0.3],
+ "CA2-3-4": [0.3, 0.35, 0.5],
+ }
+ return adata
+
+
+def gastrulation():
+ """Mouse gastrulation.
+
+ Data from `Pijuan-Sala et al. (2019) `_.
+
+ Gastrulation represents a key developmental event during which embryonic pluripotent
+ cells diversify into lineage-specific precursors that will generate the adult organism.
+
+ The experiment reveals the molecular map of mouse gastrulation and early organogenesis.
+ It comprises transcriptional profiles of 116,312 single cells from mouse embryos
+ collected at nine sequential time points ranging from 6.5 to 8.5 days post-fertilization.
+ It served to explore the complex events involved in the convergence of visceral and primitive streak-derived endoderm.
+
+ .. image:: https://user-images.githubusercontent.com/31883718/130636066-3bae153e-1626-4d11-8f38-6efab5b81c1c.png
+ :width: 600px
+
+ Returns
+ -------
+ Returns `adata` object
+ """ # noqa E501
+ filename = "data/Gastrulation/gastrulation.h5ad"
+ url = "https://ndownloader.figshare.com/files/28095525"
+ adata = read(filename, backup_url=url, sparse=True, cache=True)
+ adata.var_names_make_unique()
+ return adata
+
+
+def gastrulation_e75():
+ """Mouse gastrulation subset to E7.5.
+
+ Data from `Pijuan-Sala et al. (2019) `_.
+
+ Gastrulation represents a key developmental event during which embryonic pluripotent
+ cells diversify into lineage-specific precursors that will generate the adult organism.
+
+ .. image:: https://user-images.githubusercontent.com/31883718/130636292-7f2a599b-ded4-4616-99d7-604d2f324531.png
+ :width: 600px
+
+ Returns
+ -------
+ Returns `adata` object
+ """ # noqa E501
+ filename = "data/Gastrulation/gastrulation_e75.h5ad"
+ url = "https://ndownloader.figshare.com/files/30439878"
+ adata = read(filename, backup_url=url, sparse=True, cache=True)
+ adata.var_names_make_unique()
+ return adata
+
+
+def gastrulation_erythroid():
+ """Mouse gastrulation subset to erythroid lineage.
+
+ Data from `Pijuan-Sala et al. (2019) `_.
+
+ Gastrulation represents a key developmental event during which embryonic pluripotent
+ cells diversify into lineage-specific precursors that will generate the adult organism.
+
+ .. image:: https://user-images.githubusercontent.com/31883718/118402002-40814600-b668-11eb-8bfc-dbece2b2b34e.png
+ :width: 600px
+
+ Returns
+ -------
+ Returns `adata` object
+ """ # noqa E501
+ filename = "data/Gastrulation/erythroid_lineage.h5ad"
+ url = "https://ndownloader.figshare.com/files/27686871"
+ adata = read(filename, backup_url=url, sparse=True, cache=True)
+ adata.var_names_make_unique()
+ return adata
+
+
+def bonemarrow():
+ """Human bone marrow.
+
+ Data from `Setty et al. (2019) `_.
+
+ The bone marrow is the primary site of new blood cell production or haematopoiesis.
+ It is composed of hematopoietic cells, marrow adipose tissue, and supportive stromal cells.
+
+ This dataset served to detect important landmarks of hematopoietic differentiation, to
+ identify key transcription factors that drive lineage fate choice and to closely track when cells lose plasticity.
+
+ .. image:: https://user-images.githubusercontent.com/31883718/118402252-68bd7480-b669-11eb-9ef3-5f992b74a2d3.png
+ :width: 600px
+
+ Returns
+ -------
+ Returns `adata` object
+ """ # noqa E501
+ filename = "data/BoneMarrow/human_cd34_bone_marrow.h5ad"
+ url = "https://ndownloader.figshare.com/files/27686835"
+ adata = read(filename, backup_url=url, sparse=True, cache=True)
+ adata.var_names_make_unique()
+ return adata
+
+
+def pbmc68k():
+ """Peripheral blood mononuclear cells.
+
+ Data from `Zheng et al. (2017) `_.
+
+ This experiment contains 68k peripheral blood mononuclear cells (PBMC) measured using 10X.
+
+ PBMCs are a diverse mixture of highly specialized immune cells.
+ They originate from hematopoietic stem cells (HSCs) that reside in the bone marrow
+ and give rise to all blood cells of the immune system (hematopoiesis).
+ HSCs give rise to myeloid (monocytes, macrophages, granulocytes, megakaryocytes, dendritic cells, erythrocytes)
+ and lymphoid (T cells, B cells, NK cells) lineages.
+
+ .. image:: https://user-images.githubusercontent.com/31883718/118402351-e1243580-b669-11eb-8256-4a49c299da3d.png
+ :width: 600px
+
+ Returns
+ -------
+ Returns `adata` object
+ """ # noqa E501
+ filename = "data/PBMC/pbmc68k.h5ad"
+ url = "https://ndownloader.figshare.com/files/27686886"
+ adata = read(filename, backup_url=url, sparse=True, cache=True)
+ adata.var_names_make_unique()
+ return adata
+
+
def simulation(
n_obs=300,
n_vars=None,
@@ -159,20 +332,23 @@ def simulation(
Returns
-------
Returns `adata` object
- """
- from .tools.dynamical_model_utils import vectorize, mRNA
+ """ # noqa E501
+
+ from .tools.dynamical_model_utils import vectorize
np.random.seed(random_seed)
def draw_poisson(n):
- from random import uniform, seed # draw from poisson
+ from random import seed, uniform # draw from poisson
seed(random_seed)
t = np.cumsum([-0.1 * np.log(uniform(0, 1)) for _ in range(n - 1)])
return np.insert(t, 0, 0) # prepend t0=0
def simulate_dynamics(tau, alpha, beta, gamma, u0, s0, noise_model, noise_level):
- ut, st = mRNA(tau, u0, s0, alpha, beta, gamma)
+ ut, st = SplicingDynamics(
+ alpha=alpha, beta=beta, gamma=gamma, initial_state=[u0, s0]
+ ).get_solution(tau, stacked=False)
if noise_model == "normal": # add noise
ut += np.random.normal(
scale=noise_level * np.percentile(ut, 99) / 10, size=len(ut)
@@ -237,10 +413,13 @@ def cycle(array, n_vars=None):
U = np.zeros(shape=(len(t), n_vars))
S = np.zeros(shape=(len(t), n_vars))
+ def is_list(x):
+ return isinstance(x, (tuple, list, np.ndarray))
+
for i in range(n_vars):
- alpha_i = alpha[i] if isinstance(alpha, (tuple, list, np.ndarray)) else alpha
- beta_i = beta[i] if isinstance(beta, (tuple, list, np.ndarray)) else beta
- gamma_i = gamma[i] if isinstance(gamma, (tuple, list, np.ndarray)) else gamma
+ alpha_i = alpha[i] if is_list(alpha) and len(alpha) != n_obs else alpha
+ beta_i = beta[i] if is_list(beta) and len(beta) != n_obs else beta
+ gamma_i = gamma[i] if is_list(gamma) and len(gamma) != n_obs else gamma
tau, alpha_vec, u0_vec, s0_vec = vectorize(
t, t_[i], alpha_i, beta_i, gamma_i, alpha_=alpha_, u0=0, s0=0
)
@@ -259,6 +438,13 @@ def cycle(array, n_vars=None):
noise_level[i],
)
+ if is_list(alpha) and len(alpha) == n_obs:
+ alpha = np.nan
+ if is_list(beta) and len(beta) == n_obs:
+ beta = np.nan
+ if is_list(gamma) and len(gamma) == n_obs:
+ gamma = np.nan
+
obs = {"true_t": t.round(2)}
var = {
"true_t_": t_[:n_vars],
diff --git a/scvelo/logging.py b/scvelo/logging.py
index f140a48b..7e79c1c1 100644
--- a/scvelo/logging.py
+++ b/scvelo/logging.py
@@ -1,14 +1,15 @@
"""Logging and Profiling
"""
-
-from . import settings
-from sys import stdout
from datetime import datetime
-from time import time as get_time
from platform import python_version
+from sys import stdout
+from time import time as get_time
+
+from packaging.version import parse
+
from anndata.logging import get_memory_usage
-from anndata.logging import print_memory_usage
+from scvelo import settings
_VERBOSITY_LEVELS_FROM_STRINGS = {"error": 0, "warn": 1, "info": 2, "hint": 3}
@@ -89,7 +90,7 @@ def msg(
if reset:
try:
settings._previous_memory_usage, _ = get_memory_usage()
- except:
+ except Exception:
pass
settings._previous_time = get_time()
if time:
@@ -164,7 +165,7 @@ def __init__(self):
def run(self):
try:
self.result = func(*args, **kwargs)
- except:
+ except Exception:
pass
it = InterruptableThread()
@@ -174,7 +175,7 @@ def run(self):
def get_latest_pypi_version():
- from subprocess import check_output, CalledProcessError
+ from subprocess import CalledProcessError, check_output
try: # needs to work offline as well
result = check_output(["pip", "search", "scvelo"])
@@ -189,7 +190,7 @@ def check_if_latest_version():
latest_version = timeout(
get_latest_pypi_version, timeout_duration=2, default="0.0.0"
)
- if __version__.rsplit(".dev")[0] < latest_version.rsplit(".dev")[0]:
+ if parse(__version__.rsplit(".dev")[0]) < parse(latest_version.rsplit(".dev")[0]):
warn(
"There is a newer scvelo version available on PyPI:\n",
"Your version: \t\t",
@@ -297,7 +298,8 @@ def profiler(command, filename="profile.stats", n_stats=10):
n_stats: int or None
Number of top stats to show.
"""
- import cProfile, pstats
+ import cProfile
+ import pstats
cProfile.run(command, filename)
stats = pstats.Stats(filename).strip_dirs().sort_stats("time")
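The switch to `packaging.version.parse` matters because the previous comparison ordered version strings lexicographically; a small illustration:

```python
from packaging.version import parse

# String comparison gets multi-digit components wrong:
assert "0.9.0" > "0.10.0"                  # lexicographic: incorrect ordering
assert parse("0.9.0") < parse("0.10.0")    # semantic: correct ordering
```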
diff --git a/scvelo/pl.py b/scvelo/pl.py
index 3b819fc4..f6c2497a 100644
--- a/scvelo/pl.py
+++ b/scvelo/pl.py
@@ -1 +1 @@
-from scvelo.plotting import *
+from scvelo.plotting import * # noqa
diff --git a/scvelo/plotting/__init__.py b/scvelo/plotting/__init__.py
index 13eeb61f..8ca74562 100644
--- a/scvelo/plotting/__init__.py
+++ b/scvelo/plotting/__init__.py
@@ -1,14 +1,40 @@
-from .scatter import scatter, umap, tsne, diffmap, phate, draw_graph, pca
+from scanpy.plotting import paga_compare, rank_genes_groups
+
from .gridspec import gridspec
+from .heatmap import heatmap
+from .paga import paga
+from .proportions import proportions
+from .scatter import diffmap, draw_graph, pca, phate, scatter, tsne, umap
+from .simulation import simulation
+from .summary import summary
+from .utils import hist, plot
from .velocity import velocity
from .velocity_embedding import velocity_embedding
from .velocity_embedding_grid import velocity_embedding_grid
from .velocity_embedding_stream import velocity_embedding_stream
from .velocity_graph import velocity_graph
-from .heatmap import heatmap
-from .proportions import proportions
-from .utils import hist, plot
-from .simulation import simulation
-from scanpy.plotting import paga_compare, rank_genes_groups
-from .summary import summary
-from .paga import paga
+
+__all__ = [
+ "diffmap",
+ "draw_graph",
+ "gridspec",
+ "heatmap",
+ "hist",
+ "paga",
+ "paga_compare",
+ "pca",
+ "phate",
+ "plot",
+ "proportions",
+ "rank_genes_groups",
+ "scatter",
+ "simulation",
+ "summary",
+ "tsne",
+ "umap",
+ "velocity",
+ "velocity_embedding",
+ "velocity_embedding_grid",
+ "velocity_embedding_stream",
+ "velocity_graph",
+]
diff --git a/scvelo/plotting/docs.py b/scvelo/plotting/docs.py
index 015e8d92..3cd0e44d 100644
--- a/scvelo/plotting/docs.py
+++ b/scvelo/plotting/docs.py
@@ -17,7 +17,8 @@ def dec(obj):
doc_scatter = """\
basis: `str` or list of `str` (default: `None`)
- Key for embedding. If not specified, use 'umap', 'tsne' or 'pca' (ordered by preference).
+ Key for embedding. If not specified, use 'umap', 'tsne' or 'pca' (ordered by
+ preference).
vkey: `str` or list of `str` (default: `None`)
Key for velocity / steady-state ratio to be visualized.
color: `str`, list of `str` or `None` (default: `None`)
@@ -44,10 +45,10 @@ def dec(obj):
Specify percentile for continuous coloring.
groups: `str` or list of `str` (default: `all groups`)
Restrict to a few categories in categorical observation annotation.
- Multiple categories can be passed as list with ['cluster_1', 'cluster_3'],
+ Multiple categories can be passed as list with ['cluster_1', 'cluster_3'],
or as string with 'cluster_1, cluster_3'.
sort_order: `bool` (default: `True`)
- For continuous annotations used as color parameter,
+ For continuous annotations used as color parameter,
plot data points with higher values on top of others.
components: `str` or list of `str` (default: '1,2')
For instance, ['1,2', '2,3'].
@@ -60,11 +61,15 @@ def dec(obj):
Legend font size.
legend_fontweight: {'normal', 'bold', ...} (default: `None`)
Legend font weight. A numeric value in range 0-1000 or a string.
- Defaults to 'bold' if `legend_loc = 'on data'`, otherwise to 'normal'.
+ Defaults to 'bold' if `legend_loc = 'on data'`, otherwise to 'normal'.
Available are `['light', 'normal', 'medium', 'semibold', 'bold', 'heavy', 'black']`.
-legend_fontoutline
+legend_fontoutline: float (default: `None`)
Line width of the legend font outline in pt. Draws a white outline using
the path effect :class:`~matplotlib.patheffects.withStroke`.
+legend_align_text: bool or str (default: `None`)
+ Aligns the positions of the legend texts. Set the axis along which the best
+ alignment should be determined. This can be 'y' or True (vertically),
+ 'x' (horizontally), or 'xy' (best alignment in both directions).
right_margin: `float` or list of `float` (default: `None`)
Adjust the width of the space right of each plotting panel.
left_margin: `float` or list of `float` (default: `None`)
@@ -84,42 +89,46 @@ def dec(obj):
ylim: tuple, e.g. [0,1] or `None` (default: `None`)
Restrict y-limits of the axis.
add_density: `bool` or `str` or `None` (default: `None`)
- Whether to show density of values along x and y axes.
+ Whether to show density of values along x and y axes.
Color of the density plot can also be passed as `str`.
add_assignments: `bool` or `str` or `None` (default: `None`)
- Whether to add assignments to the model curve.
+ Whether to add assignments to the model curve.
Color of the assignments can also be passed as `str`.
add_linfit: `bool` or `str` or `None` (default: `None`)
- Whether to add linear regression fit to the data points.
+ Whether to add linear regression fit to the data points.
Color of the line can also be passed as `str`.
- Fitting with or without an intercept by passing `'intercept'` or `'no_intercept'`.
+ Fitting with or without an intercept by passing `'intercept'` or `'no_intercept'`.
A colored regression line with intercept is obtained with `'intercept, blue'`.
add_polyfit: `bool` or `str` or `int` or `None` (default: `None`)
- Whether to add polynomial fit to the data points. Color of the polyfit plot can also
- be passed as `str`. The degree of the polynomial fit can be passed as `int`
- (default is 2 for quadratic fit).
- Fitting with or without an intercept by passing `'intercept'` or `'no_intercept'`.
+ Whether to add polynomial fit to the data points. Color of the polyfit plot can also
+ be passed as `str`. The degree of the polynomial fit can be passed as `int`
+ (default is 2 for quadratic fit).
+ Fitting with or without an intercept by passing `'intercept'` or `'no_intercept'`.
A colored regression line with intercept is obtained with `'intercept, blue'`.
add_rug: `str` or `None` (default: `None`)
- If categorical observation annotation (e.g. 'clusters') is given, a rugplot is
+ If categorical observation annotation (e.g. 'clusters') is given, a rugplot is
attached to the x-axis showing the data membership to each of the categories.
add_text: `str` (default: `None`)
Text to be added to the plot, passed as `str`.
-add_text_pos: `tuple`, e.g. [0.05, 0.95] (defaut: `[0.05, 0.95]`)
+add_text_pos: `tuple`, e.g. [0.05, 0.95] (default: `[0.05, 0.95]`)
    Text position. Default is `[0.05, 0.95]`, positioning the text at the top left corner.
+add_margin: `float` (default: `None`)
+ A value between [-1, 1] to add (positive) and reduce (negative) figure margins.
add_outline: `bool` or `str` (default: `False`)
- Whether to show an outline around scatter plot dots.
+ Whether to show an outline around scatter plot dots.
Alternatively a string of cluster names can be passed, e.g. 'cluster_1, clusters_3'.
outline_width: tuple type `scalar` or `None` (default: `(0.3, 0.05)`)
Width of the inner and outer outline
outline_color: tuple of type `str` or `None` (default: `('black', 'white')`)
Inner and outer matplotlib color of the outline
n_convolve: `int` or `None` (default: `None`)
- If `int` is given, data is smoothed by convolution
+ If `int` is given, data is smoothed by convolution
along the x-axis with kernel size `n_convolve`.
smooth: `bool` or `int` (default: `None`)
- Whether to convolve/average the color values over the nearest neighbors.
+ Whether to convolve/average the color values over the nearest neighbors.
If `int`, it specifies number of neighbors.
+normalize_data: `bool` (default: `None`)
+ Whether to rescale values for x, y to [0,1].
rescale_color: `tuple` (default: `None`)
Boundaries for color rescaling, e.g. [0, 1], setting min/max values of the colorbar.
color_gradients: `str` or `np.ndarray` (default: `None`)
@@ -139,7 +148,7 @@ def dec(obj):
show: `bool`, optional (default: `None`)
Show the plot, do not return axis.
save: `bool` or `str`, optional (default: `None`)
- If `True` or a `str`, save the figure. A string is appended to the default filename.
+ If `True` or a `str`, save the figure. A string is appended to the default filename.
Infer the filetype if ending on {'.pdf', '.png', '.svg'}.
ax: `matplotlib.Axes`, optional (default: `None`)
A matplotlib axes object. Only works if plotting a single component.\
diff --git a/scvelo/plotting/gridspec.py b/scvelo/plotting/gridspec.py
index ee440dfd..a0285c61 100644
--- a/scvelo/plotting/gridspec.py
+++ b/scvelo/plotting/gridspec.py
@@ -1,13 +1,14 @@
+from functools import partial
+
+import matplotlib.pyplot as pl
+
# todo: auto-complete and docs wrapper
from .scatter import scatter
+from .utils import get_figure_params, hist
from .velocity_embedding import velocity_embedding
from .velocity_embedding_grid import velocity_embedding_grid
from .velocity_embedding_stream import velocity_embedding_stream
from .velocity_graph import velocity_graph
-from .utils import hist, get_figure_params
-
-import matplotlib.pyplot as pl
-from functools import partial
def _wraps_plot(wrapper, func):
diff --git a/scvelo/plotting/heatmap.py b/scvelo/plotting/heatmap.py
index c817de19..aeca9968 100644
--- a/scvelo/plotting/heatmap.py
+++ b/scvelo/plotting/heatmap.py
@@ -1,14 +1,16 @@
import numpy as np
-import matplotlib.pyplot as pl
-from matplotlib import rcParams
-from matplotlib.colors import ColorConverter
import pandas as pd
-from pandas import unique, isnull
from scipy.sparse import issparse
-from .. import logging as logg
-from .utils import is_categorical, interpret_colorkey, savefig_or_show, to_list
-from .utils import set_colors_for_categorical_obs, strings_to_categoricals
+from scvelo import logging as logg
+from .utils import (
+ interpret_colorkey,
+ is_categorical,
+ savefig_or_show,
+ set_colors_for_categorical_obs,
+ strings_to_categoricals,
+ to_list,
+)
def heatmap(
@@ -25,8 +27,9 @@ def heatmap(
colorbar=None,
col_cluster=False,
row_cluster=False,
- figsize=(8, 4),
+ context=None,
font_scale=None,
+ figsize=(8, 4),
show=None,
save=None,
**kwargs,
@@ -48,6 +51,8 @@ def heatmap(
String denoting matplotlib color map.
col_color: `str` or list of `str` (default: `None`)
String denoting matplotlib color map to use along the columns.
+ palette: list of `str` (default: `'viridis'`)
+ Colors to use for plotting groups (categorical annotation).
n_convolve: `int` or `None` (default: `30`)
If `int` is given, data is smoothed by convolution
along the x-axis with kernel size n_convolve.
@@ -58,9 +63,13 @@ def heatmap(
        Whether to sort the expression values given by xkey.
colorbar: `bool` or `None` (default: `None`)
Whether to show colorbar.
- {row,col}_cluster : bool, optional
+ {row,col}_cluster : `bool` or `None`
If True, cluster the {rows, columns}.
- figsize: tuple (default: `(7,5)`)
+ context : `None`, or one of {paper, notebook, talk, poster}
+ A dictionary of parameters or the name of a preconfigured set.
+ font_scale : float, optional
+ Scaling factor to scale the size of the font elements.
+ figsize: tuple (default: `(8,4)`)
Figure size.
show: `bool`, optional (default: `None`)
Show the plot, do not return axis.
@@ -73,7 +82,7 @@ def heatmap(
Returns
-------
- If `show==False` a `matplotlib.Axis`
+    If `show==False` a `matplotlib.Axes`
"""
import seaborn as sns
@@ -98,7 +107,7 @@ def heatmap(
for gene in var_names:
try:
df[gene] = np.convolve(df[gene].values, weights, mode="same")
- except:
+ except Exception:
pass # e.g. all-zero counts or nans cannot be convolved
if sort:
@@ -118,8 +127,6 @@ def heatmap(
set_colors_for_categorical_obs(adata, col, palette)
col_color.append(interpret_colorkey(adata, col)[np.argsort(time)])
- if font_scale is not None:
- sns.set(font_scale=font_scale)
if "dendrogram_ratio" not in kwargs:
kwargs["dendrogram_ratio"] = (
0.1 if row_cluster else 0,
@@ -139,454 +146,21 @@ def heatmap(
figsize=figsize,
)
)
- try:
- cm = sns.clustermap(df.T, **kwargs)
- except:
- logg.warn("Please upgrade seaborn with `pip install -U seaborn`.")
- kwargs.pop("dendrogram_ratio")
- kwargs.pop("cbar_pos")
- cm = sns.clustermap(df.T, **kwargs)
+
+ args = {}
+ if font_scale is not None:
+ args = {"font_scale": font_scale}
+ context = context or "notebook"
+
+ with sns.plotting_context(context=context, **args):
+ try:
+ cm = sns.clustermap(df.T, **kwargs)
+ except Exception:
+ logg.warn("Please upgrade seaborn with `pip install -U seaborn`.")
+ kwargs.pop("dendrogram_ratio")
+ kwargs.pop("cbar_pos")
+ cm = sns.clustermap(df.T, **kwargs)
savefig_or_show("heatmap", save=save, show=show)
if show is False:
return cm
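The `context`/`font_scale` handling above scopes seaborn styling to this one figure instead of mutating global state via `sns.set(font_scale=...)`; a minimal sketch of the same pattern on stand-in data:

```python
import numpy as np
import pandas as pd
import seaborn as sns

df = pd.DataFrame(np.random.rand(50, 10))  # stand-in for the time-sorted matrix

# Styling applies only inside the `with` block; global rcParams stay untouched.
with sns.plotting_context(context="talk", font_scale=1.2):
    cm = sns.clustermap(df.T)
```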
-
-
-def heatmap_deprecated(
- adata,
- var_names,
- groups=None,
- groupby=None,
- annotations=None,
- use_raw=False,
- layers=None,
- color_map=None,
- color_map_anno=None,
- colorbar=True,
- row_width=None,
- xlabel=None,
- title=None,
- figsize=None,
- dpi=None,
- show=True,
- save=None,
- ax=None,
- **kwargs,
-):
-
- """\
- Plot pseudotimeseries for genes as heatmap.
-
- Arguments
- ---------
- adata: :class:`~anndata.AnnData`
- Annotated data matrix.
- var_names: `str`, list of `str`
- Names of variables to use for the plot.
- groups: `str`, list of `str` or `None` (default: `None`)
- Groups selected to plot. Must be an element of adata.obs[groupby].
- groupby: `str` or `None` (default: `None`)
- Key in adata.obs. Indicates how to group the plot.
- annotations: `str`, list of `str` or `None` (default: `None`)
- Key in adata.obs. Annotations are plotted in the last row.
- use_raw: `bool` (default: `False`)
- If true, moments are used instead of raw data.
- layers: `str`, list of `str` or `None` (default: `['X']`)
- Selected layers.
- color_map: `str`, list of `str` or `None` (default: `None`)
- String denoting matplotlib color map for the heat map.
- There must be one list entry for each layer.
- color_map_anno: `str`, list of `str` or `None` (default: `None`)
- String denoting matplotlib color map for the annotations.
- There must be one list entry for each annotation.
- colorbar: `bool` (default: `True`)
- If True, a colormap for each layer is added on the right bottom corner.
- row_width: `float` (default: `None`)
- Constant width of all rows.
- xlabel:
- Label for the x-axis.
- title: `str` or `None` (default: `None`)
- Main plot title.
- figsize: tuple (default: `(7,5)`)
- Figure size.
- dpi: `int` (default: 80)
- Figure dpi.
- show: `bool`, optional (default: `None`)
- Show the plot, do not return axis.
- save: `bool` or `str`, optional (default: `None`)
- If `True` or a `str`, save the figure.
- Infer the filetype if ending on {'.pdf', '.png', '.svg'}.
- ax: `matplotlib.Axes`, optional (default: `None`)
- A matplotlib axes object. Only works if plotting a single component.
-
- Returns
- -------
- If `show==False` a `matplotlib.Axis`
- """
-
- # catch
- if "velocity_pseudotime" not in adata.obs.keys():
- raise ValueError(
- "A function requires computation of the pseudotime"
- "for ordering at single-cell resolution"
- )
- if layers is None:
- layers = ["X"]
- if annotations is None:
- annotations = []
- if isinstance(var_names, str):
- var_names = [var_names]
- if len(var_names) == 0:
- var_names = np.arange(adata.X.shape[1])
- if var_names.ndim == 2:
- var_names = var_names[:, 0]
- var_names = [name for name in var_names if name in adata.var_names]
- if len(var_names) == 0:
- raise ValueError(
- "The specified var_names are all not" "contained in the adata.var_names."
- )
-
- if layers is None:
- layers = ["X"]
- if isinstance(layers, str):
- layers = [layers]
- layers = [layer for layer in layers if layer in adata.layers.keys() or layer == "X"]
- if len(layers) == 0:
- raise ValueError(
- "The selected layers are not contained" "in adata.layers.keys()."
- )
- if not use_raw:
- layers = np.array(layers)
- if "X" in layers:
- layers[np.array([layer == "X" for layer in layers])] = "Ms"
- if "spliced" in layers:
- layers[np.array([layer == "spliced" for layer in layers])] = "Ms"
- if "unspliced" in layers:
- layers[np.array([layer == "unspliced" for layer in layers])] = "Ms"
- layers = list(layers)
- if "Ms" in layers and "Ms" not in adata.layers.keys():
- raise ValueError(
- "Moments have to be computed before" "using this plot function."
- )
- if "Mu" in layers and "Mu" not in adata.layers.keys():
- raise ValueError(
- "Moments have to be computed before" "using this plot function."
- )
- layers = unique(layers)
-
- # Number of rows to plot
- tot_len = len(var_names) * len(layers) + len(annotations)
-
- # init main figure
- figsize = rcParams["figure.figsize"] if figsize is None else figsize
- if row_width is not None:
- figsize[1] = row_width * tot_len
- ax = pl.figure(figsize=figsize, dpi=dpi).gca() if ax is None else ax
- ax.set_yticks([])
- ax.set_xticks([])
-
- # groups bar
- ax_bounds = ax.get_position().bounds
- if groupby is not None:
- # catch
- if groupby not in adata.obs_keys():
- raise ValueError(
- "The selected groupby is not contained" "in adata.obs_keys()."
- )
- if groups is None: # Then use everything of that obs
- groups = unique(adata.obs.clusters.values)
-
- imlist = []
-
- for igroup, group in enumerate(groups):
- for ivar, var in enumerate(var_names):
- for ilayer, layer in enumerate(layers):
- groups_axis = pl.axes(
- [
- ax_bounds[0] + igroup * ax_bounds[2] / len(groups),
- ax_bounds[1]
- + ax_bounds[3]
- * (tot_len - ivar * len(layers) - ilayer - 1)
- / tot_len,
- ax_bounds[2] / len(groups),
- (ax_bounds[3] - ax_bounds[3] / tot_len * len(annotations))
- / (len(var_names) * len(layers)),
- ]
- )
-
- # Get data to fill and reshape
- dat = adata[:, var]
-
- idx_group = [adata.obs[groupby] == group]
- idx_group = np.array(idx_group[0].tolist())
- idx_var = [vn in var_names for vn in adata.var_names]
- idx_pt = np.array(adata.obs.velocity_pseudotime).argsort()
- idx_pt = idx_pt[
- np.array(
- isnull(np.array(dat.obs.velocity_pseudotime)[idx_pt])
- == False
- )
- ]
-
- if layer == "X":
- laydat = dat.X
- else:
- laydat = dat.layers[layer]
-
- t1, t2, t3 = idx_group, idx_var, idx_pt
- t1 = t1[t3]
- # laydat = laydat[:, t2] # select vars
- laydat = laydat[t3]
- laydat = laydat[t1] # select ordered groups
-
- if issparse(laydat):
- laydat = laydat.A
-
- # transpose X for ordering in direction var_names: up->downwards
- laydat = laydat.T[::-1]
- laydat = laydat.reshape((1, len(laydat))) # ensure 1dimty
-
- # plot
- im = groups_axis.imshow(
- laydat,
- aspect="auto",
- interpolation="nearest",
- cmap=color_map[ilayer],
- )
-
- # Frames
- if ilayer == 0:
- groups_axis.spines["bottom"].set_visible(False)
- elif ilayer == len(layer) - 1:
- groups_axis.spines["top"].set_visible(False)
- else:
- groups_axis.spines["top"].set_visible(False)
- groups_axis.spines["bottom"].set_visible(False)
-
- # Further visuals
- if igroup == 0:
- if colorbar:
- if len(layers) % 2 == 0:
- if ilayer == len(layers) / 2 - 1:
- pl.yticks([0.5], [var])
- else:
- groups_axis.set_yticks([])
- else:
- if ilayer == (len(layers) - 1) / 2:
- pl.yticks([0], [var])
- else:
- groups_axis.set_yticks([])
- else:
- pl.yticks([0], [f"{layer} {var}"])
- else:
- groups_axis.set_yticks([])
-
- groups_axis.set_xticks([])
- if ilayer == 0 and ivar == 0:
- groups_axis.set_title(f"{group}")
- groups_axis.grid(False)
-
- # handle needed as mappable for colorbar
- if igroup == len(groups) - 1:
- imlist.append(im)
-
- # further annotations for each group
- if annotations is not None:
- for ianno, anno in enumerate(annotations):
- anno_axis = pl.axes(
- [
- ax_bounds[0] + igroup * ax_bounds[2] / len(groups),
- ax_bounds[1]
- + ax_bounds[3] / tot_len * (len(annotations) - ianno - 1),
- ax_bounds[2] / len(groups),
- ax_bounds[3] / tot_len,
- ]
- )
- if is_categorical(adata, anno):
- colo = interpret_colorkey(adata, anno)[t3][t1]
- colo.reshape(1, len(colo))
- mapper = np.vectorize(ColorConverter.to_rgb)
- a = mapper(colo)
- a = np.array(a).T
- Y = a.reshape(1, len(colo), 3)
- else:
- Y = np.array(interpret_colorkey(adata, anno))[t3][t1]
- Y = Y.reshape(1, len(Y))
- img = anno_axis.imshow(
- Y, aspect="auto", interpolation="nearest", cmap=color_map_anno
- )
- if igroup == 0:
- anno_axis.set_yticklabels(
- ["", anno, ""]
- ) # , fontsize=ytick_fontsize)
- anno_axis.tick_params(axis="both", which="both", length=0)
- else:
- anno_axis.set_yticklabels([])
- anno_axis.set_yticks([])
- anno_axis.set_xticks([])
- anno_axis.set_xticklabels([])
- anno_axis.grid(False)
- pl.ylim([0.5, -0.5]) # center ticks
-
- else: # groupby is False
- imlist = []
- for ivar, var in enumerate(var_names):
- for ilayer, layer in enumerate(layers):
- ax_bounds = ax.get_position().bounds
- groups_axis = pl.axes(
- [
- ax_bounds[0],
- ax_bounds[1]
- + ax_bounds[3]
- * (tot_len - ivar * len(layers) - ilayer - 1)
- / tot_len,
- ax_bounds[2],
- (ax_bounds[3] - ax_bounds[3] / tot_len * len(annotations))
- / (len(var_names) * len(layers)),
- ]
- )
- # Get data to fill
- dat = adata[:, var]
- idx = np.array(dat.obs.velocity_pseudotime).argsort()
- idx = idx[
- np.array(
- isnull(np.array(dat.obs.velocity_pseudotime)[idx]) == False
- )
- ]
-
- if layer == "X":
- laydat = dat.X
- else:
- laydat = dat.layers[layer]
- laydat = laydat[idx]
- if issparse(laydat):
- laydat = laydat.A
-
- # transpose X for ordering in direction var_names: up->downwards
- laydat = laydat.T[::-1]
- laydat = laydat.reshape((1, len(laydat)))
-
- # plot
- im = groups_axis.imshow(
- laydat,
- aspect="auto",
- interpolation="nearest",
- cmap=color_map[ilayer],
- )
- imlist.append(im)
-
- # Frames
- if ilayer == 0:
- groups_axis.spines["bottom"].set_visible(False)
- elif ilayer == len(layer) - 1:
- groups_axis.spines["top"].set_visible(False)
- else:
- groups_axis.spines["top"].set_visible(False)
- groups_axis.spines["bottom"].set_visible(False)
-
- # Further visuals
- groups_axis.set_xticks([])
- groups_axis.grid(False)
- pl.ylim([0.5, -0.5]) # center
- if colorbar:
- if len(layers) % 2 == 0:
- if ilayer == len(layers) / 2 - 1:
- pl.yticks([0.5], [var])
- else:
- groups_axis.set_yticks([])
- else:
- if ilayer == (len(layers) - 1) / 2:
- pl.yticks([0], [var])
- else:
- groups_axis.set_yticks([])
- else:
- pl.yticks([0], [f"{layer} {var}"])
-
- # further annotations bars
- if annotations is not None:
- for ianno, anno in enumerate(annotations):
- anno_axis = pl.axes(
- [
- ax_bounds[0],
- ax_bounds[1]
- + ax_bounds[3] / tot_len * (len(annotations) - ianno - 1),
- ax_bounds[2],
- ax_bounds[3] / tot_len,
- ]
- )
- dat = adata[:, var_names]
- if is_categorical(dat, anno):
- colo = interpret_colorkey(dat, anno)[idx]
- colo.reshape(1, len(colo))
- mapper = np.vectorize(ColorConverter.to_rgb)
- a = mapper(colo)
- a = np.array(a).T
- Y = a.reshape(1, len(idx), 3)
- else:
- Y = np.array(interpret_colorkey(dat, anno)[idx]).reshape(
- 1, len(idx)
- )
- img = anno_axis.imshow(
- Y, aspect="auto", interpolation="nearest", cmap=color_map_anno
- )
-
- anno_axis.set_yticklabels(["", anno, ""]) # , fontsize=ytick_fontsize)
- anno_axis.tick_params(axis="both", which="both", length=0)
- anno_axis.grid(False)
- anno_axis.set_xticks([])
- anno_axis.set_xticklabels([])
- pl.ylim([-0.5, +0.5])
-
- # Colorbar
- if colorbar:
- if len(layers) > 1:
- # I must admit, this part is chaotic
- for ilayer, layer in enumerate(layers):
- w = 0.015 * 10 / figsize[0] # 0.02 * ax_bounds[2]
- x = ax_bounds[0] + ax_bounds[2] * 0.99 + 1.5 * w + w * 1.2 * ilayer
- y = ax_bounds[1]
- h = ax_bounds[3] * 0.3
- cbaxes = pl.axes([x, y, w, h])
- cb = pl.colorbar(mappable=imlist[ilayer], cax=cbaxes)
- pl.text(
- x - 40 * w,
- y + h * 4,
- layer,
- rotation=45,
- horizontalalignment="left",
- verticalalignment="bottom",
- )
- if ilayer == len(layers) - 1:
- ext = abs(cb.vmin - cb.vmax)
- cb.set_ticks([cb.vmin + 0.07 * ext, cb.vmax - 0.07 * ext])
- cb.ax.set_yticklabels(["Low", "High"]) # vertical colorbar
- else:
- cb.set_ticks([])
- else:
- cbaxes = pl.axes(
- [
- ax_bounds[0] + ax_bounds[2] + 0.01,
- ax_bounds[1],
- 0.02,
- ax_bounds[3] * 0.3,
- ]
- )
- cb = pl.colorbar(mappable=im, cax=cbaxes)
- cb.set_ticks([cb.vmin, cb.vmax])
- cb.ax.set_yticklabels(["Low", "High"])
-
- if xlabel is None:
- xlabel = "velocity" + " " + "pseudotime"
- if title is not None:
- ax.set_title(title, pad=30)
- if len(annotations) == 0:
- ax.set_xlabel(xlabel)
- ax.xaxis.labelpad = 20
-
- # set_label(xlabel, None, fontsize, basis)
- # set_title(title, None, None, fontsize)
- # update_axes(ax, fontsize)
-
- savefig_or_show("heatmap", dpi=dpi, save=save, show=show)
- if not show:
- return ax
diff --git a/scvelo/plotting/paga.py b/scvelo/plotting/paga.py
index 062c52a7..f078c7d6 100644
--- a/scvelo/plotting/paga.py
+++ b/scvelo/plotting/paga.py
@@ -1,20 +1,27 @@
-from .. import settings
-from .. import logging as logg
+import collections.abc as cabc
+import random
+from inspect import signature
-from ..tools.utils import groups_to_bool
-from ..tools.paga import get_igraph_from_adjacency
-from .utils import default_basis, default_size, default_color, get_components
-from .utils import make_unique_list, make_unique_valid_list, savefig_or_show
-from .scatter import scatter
-from .docs import doc_scatter, doc_params
+import numpy as np
+import matplotlib.pyplot as pl
from matplotlib import rcParams
from matplotlib.path import get_path_collection_extents
-import matplotlib.pyplot as pl
-import numpy as np
-from inspect import signature
-import collections.abc as cabc
-import random
+
+from scvelo import logging as logg
+from scvelo import settings
+from scvelo.tools.paga import get_igraph_from_adjacency
+from .docs import doc_params, doc_scatter
+from .scatter import scatter
+from .utils import (
+ default_basis,
+ default_color,
+ default_size,
+ get_components,
+ make_unique_list,
+ make_unique_valid_list,
+ savefig_or_show,
+)
@doc_params(scatter=doc_scatter)
@@ -280,11 +287,6 @@ def paga(
size = default_size(adata) / 2 if size is None else size
paga_groups = adata.uns["paga"]["groups"]
- _adata = (
- adata[groups_to_bool(adata, groups, groupby=paga_groups)]
- if groups is not None and paga_groups in adata.obs.keys()
- else adata
- )
if isinstance(node_colors, dict):
paga_kwargs["colorbar"] = False
@@ -376,7 +378,7 @@ def _compute_pos(
if layout == "fa":
try:
import fa2
- except:
+ except Exception:
logg.warn(
"Package 'fa2' is not installed, falling back to layout 'fr'."
"To use the faster and better ForceAtlas2 layout, "
@@ -718,15 +720,18 @@ def _paga_graph(
"""scanpy/_paga_graph with some adjustments for directional graphs.
To be moved back to scanpy once finalized.
"""
+ import warnings
+ from pathlib import Path
+
import networkx as nx
import pandas as pd
import scipy
- import warnings
+ from pandas.api.types import is_categorical_dtype
+
from matplotlib import patheffects
from matplotlib.colors import is_color_like
- from pathlib import Path
+
from scanpy.plotting._utils import add_colors_for_categorical_sample_annotation
- from pandas.api.types import is_categorical_dtype
node_labels = labels # rename for clarity
if (
@@ -825,8 +830,10 @@ def _paga_graph(
and colors in adata.obs
and is_categorical_dtype(adata.obs[colors])
):
- from scanpy._utils import compute_association_matrix_of_groups
- from scanpy._utils import get_associated_colors_of_groups
+ from scanpy._utils import (
+ compute_association_matrix_of_groups,
+ get_associated_colors_of_groups,
+ )
norm = "reference" if normalize_to_color else "prediction"
_, asso_matrix = compute_association_matrix_of_groups(
@@ -1042,7 +1049,9 @@ def transform_ax_coords(a, b):
color_single.append("grey")
fracs.append(1 - sum(fracs))
wedgeprops = dict(linewidth=0, edgecolor="k", antialiased=True)
- pie_axs[count].pie(fracs, colors=color_single, wedgeprops=wedgeprops)
+ pie_axs[count].pie(
+ fracs, colors=color_single, wedgeprops=wedgeprops, normalize=True
+ )
if node_labels is not None:
text_kwds.update(dict(verticalalignment="center", fontweight=fontweight))
text_kwds.update(dict(horizontalalignment="center", size=fontsize))
@@ -1087,6 +1096,6 @@ def getbb(sc, ax):
result = get_path_collection_extents(
transform.frozen(), [p], [t], [o], transOffset.frozen()
)
- bboxes.append(result.inverse_transformed(ax.transData))
+ bboxes.append(result.transformed(ax.transData.inverted()))
return bboxes
diff --git a/scvelo/plotting/palettes.py b/scvelo/plotting/palettes.py
index 5962a923..5bd9b379 100644
--- a/scvelo/plotting/palettes.py
+++ b/scvelo/plotting/palettes.py
@@ -1,7 +1,10 @@
-"""Color palettes in addition to matplotlib's palettes."""
+from typing import Mapping, Sequence
from matplotlib import cm, colors
+"""Color palettes in addition to matplotlib's palettes."""
+
+
# Colorblindness adjusted vega_10
# See https://github.com/theislab/scanpy/issues/387
vega_10 = list(map(colors.to_hex, cm.tab10.colors))
@@ -80,13 +83,11 @@
# fmt: on
-from typing import Mapping, Sequence
-
-
def _plot_color_cylce(clists: Mapping[str, Sequence[str]]):
import numpy as np
+
import matplotlib.pyplot as plt
- from matplotlib.colors import ListedColormap, BoundaryNorm
+ from matplotlib.colors import BoundaryNorm, ListedColormap
fig, axes = plt.subplots(nrows=len(clists)) # type: plt.Figure, plt.Axes
fig.subplots_adjust(top=0.95, bottom=0.01, left=0.3, right=0.99)
diff --git a/scvelo/plotting/proportions.py b/scvelo/plotting/proportions.py
index 41f5ef66..420a57f7 100644
--- a/scvelo/plotting/proportions.py
+++ b/scvelo/plotting/proportions.py
@@ -1,7 +1,9 @@
-from ..preprocessing.utils import sum_var
+import numpy as np
import matplotlib.pyplot as pl
-import numpy as np
+
+from scvelo.core import sum
+from .utils import savefig_or_show
def proportions(
@@ -16,6 +18,7 @@ def proportions(
dpi=100,
use_raw=True,
show=True,
+ save=None,
):
"""Plot pie chart of spliced/unspliced proprtions.
@@ -43,22 +46,26 @@ def proportions(
Use initial cell sizes before normalization and filtering.
show: `bool` (default: True)
Show the plot, do not return axis.
+ save: `bool` or `str`, optional (default: `None`)
+ If `True` or a `str`, save the figure. A string is appended to the default
+ filename. Infer the filetype if ending on {'.pdf', '.png', '.svg'}.
Returns
-------
    Plots the proportions of abundances as a pie chart.
"""
+
# get counts per cell for each layer
if layers is None:
layers = ["spliced", "unspliced", "ambigious"]
layers_keys = [key for key in layers if key in adata.layers.keys()]
- counts_layers = [sum_var(adata.layers[key]) for key in layers_keys]
+ counts_layers = [sum(adata.layers[key], axis=1) for key in layers_keys]
if use_raw:
ikey, obs = "initial_size_", adata.obs
counts_layers = [
- obs[ikey + l] if ikey + l in obs.keys() else c
- for l, c in zip(layers_keys, counts_layers)
+ obs[ikey + layer_key] if ikey + layer_key in obs.keys() else c
+ for layer_key, c in zip(layers_keys, counts_layers)
]
counts_total = np.sum(counts_layers, 0)
counts_total += counts_total == 0
@@ -71,7 +78,10 @@ def proportions(
ax = pl.subplot(gspec[0])
if highlight is None:
highlight = "none"
- explode = [0.1 if (l == highlight or l in highlight) else 0 for l in layers_keys]
+ explode = [
+ 0.1 if (layer_key == highlight or layer_key in highlight) else 0
+ for layer_key in layers_keys
+ ]
autopct = "%1.0f%%" if add_labels_pie else None
pie = ax.pie(
@@ -151,7 +161,7 @@ def proportions(
ax2.set_ylabel(groupby, fontweight="bold", fontsize=fontsize * 1.2)
ax2.tick_params(axis="both", which="major", labelsize=fontsize)
ax = [ax, ax2]
- if show:
- pl.show()
- else:
+
+ savefig_or_show("proportions", dpi=dpi, save=save, show=show)
+ if show is False:
return ax
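`proportions` now uses the axis-aware `sum` from `scvelo.core` instead of `sum_var`. A rough sketch of what such a helper has to handle, sparse and dense alike (my own minimal version, not the actual `scvelo.core.sum` implementation):

```python
import numpy as np
from scipy.sparse import csr_matrix, issparse


def axis_sum(a, axis=None):
    """Sum a dense or sparse matrix over an axis, returning a 1D ndarray."""
    if issparse(a):
        # sparse .sum() returns np.matrix; .A1 flattens it to a 1D ndarray
        return a.sum(axis=axis).A1
    return np.sum(a, axis=axis)


X = csr_matrix(np.arange(6).reshape(2, 3))
assert (axis_sum(X, axis=1) == np.array([3, 12])).all()  # counts per cell
assert (axis_sum(X, axis=0) == np.array([3, 5, 7])).all()  # counts per gene
```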
diff --git a/scvelo/plotting/pseudotime.py b/scvelo/plotting/pseudotime.py
index 6f52294d..33067212 100644
--- a/scvelo/plotting/pseudotime.py
+++ b/scvelo/plotting/pseudotime.py
@@ -1,4 +1,7 @@
-from .utils import *
+import numpy as np
+
+import matplotlib.pyplot as pl
+from matplotlib.ticker import MaxNLocator
def principal_curve(adata):
diff --git a/scvelo/plotting/scatter.py b/scvelo/plotting/scatter.py
index 7e1f0562..16aeda8f 100644
--- a/scvelo/plotting/scatter.py
+++ b/scvelo/plotting/scatter.py
@@ -1,10 +1,19 @@
-from .docs import doc_scatter, doc_params
-from .utils import *
-
from inspect import signature
-import matplotlib.pyplot as pl
+
import numpy as np
import pandas as pd
+from pandas import unique
+
+import matplotlib.pyplot as pl
+from matplotlib.colors import is_color_like
+
+from anndata import AnnData
+
+from scvelo import logging as logg
+from scvelo import settings
+from scvelo.preprocessing.neighbors import get_connectivities
+from .docs import doc_params, doc_scatter
+from .utils import *
@doc_params(scatter=doc_scatter)
@@ -34,6 +43,7 @@ def scatter(
legend_fontsize=None,
legend_fontweight=None,
legend_fontoutline=None,
+ legend_align_text=None,
xlabel=None,
ylabel=None,
title=None,
@@ -48,11 +58,13 @@ def scatter(
add_rug=None,
add_text=None,
add_text_pos=None,
+ add_margin=None,
add_outline=None,
outline_width=None,
outline_color=None,
n_convolve=None,
smooth=None,
+ normalize_data=None,
rescale_color=None,
color_gradients=None,
dpi=None,
@@ -82,8 +94,9 @@ def scatter(
Returns
-------
- If `show==False` a `matplotlib.Axis`
+ If `show==False` a `matplotlib.Axis`
"""
+
if adata is None and (x is not None and y is not None):
adata = AnnData(np.stack([x, y]).T)
@@ -97,7 +110,7 @@ def scatter(
# keys for figures (fkeys) and multiple plots (mkeys)
fkeys = ["adata", "show", "save", "groups", "ncols", "nrows", "wspace", "hspace"]
- fkeys += ["ax", "kwargs"]
+ fkeys += ["add_margin", "ax", "kwargs"]
mkeys = ["color", "layer", "basis", "components", "x", "y", "xlabel", "ylabel"]
mkeys += ["title", "color_map", "add_text"]
scatter_kwargs = {"show": False, "save": False}
@@ -158,7 +171,7 @@ def scatter(
nrows = int(np.ceil(len(multikey) / ncols))
else:
ncols = int(np.ceil(len(multikey) / nrows))
- if not frameon:
+ if not frameon or frameon == "artist":
lloc, llines = "legend_loc", "legend_loc_lines"
if lloc in scatter_kwargs and scatter_kwargs[lloc] is None:
scatter_kwargs[lloc] = "none"
@@ -303,16 +316,12 @@ def scatter(
basis = default_basis(adata)
if linewidth is None:
linewidth = 1
- if linecolor is None:
- linecolor = "k"
if frameon is None:
frameon = True if not is_embedding else settings._frameon
if isinstance(groups, str):
groups = [groups]
if use_raw is None and basis not in adata.var_names:
use_raw = layer is None and adata.raw is not None
- if projection == "3d":
- from mpl_toolkits.mplot3d import Axes3D
ax, show = get_ax(ax, show, figsize, dpi, projection)
@@ -494,7 +503,7 @@ def scatter(
try:
c += rescale_color[0] - np.nanmin(c)
c *= rescale_color[1] / np.nanmax(c)
- except:
+ except Exception:
logg.warn("Could not rescale colors. Pass a tuple, e.g. [0,1].")
# set vmid to 0 if color values obtained from velocity expression
@@ -521,12 +530,12 @@ def scatter(
if len(x) != len(y):
raise ValueError("x or y do not share the same dimension.")
+ if normalize_data:
+ x = (x - np.nanmin(x)) / (np.nanmax(x) - np.nanmin(x))
+            y = (y - np.nanmin(y)) / (np.nanmax(y) - np.nanmin(y))
+
if not isinstance(c, str):
c = np.ravel(c) if len(np.ravel(c)) == len(x) else c
- if len(c) != len(x):
- c = "grey"
- if not isinstance(color, str) or color != default_color(adata):
- logg.warn("Invalid color key. Using grey instead.")
# store original order of color values
color_array, scatter_array = c, np.stack([x, y]).T
@@ -576,6 +585,13 @@ def scatter(
title = groups[0]
else: # if nothing to be highlighted
add_linfit, add_polyfit, add_density = None, None, None
+ else:
+ idx = None
+
+ if not isinstance(c, str) and len(c) != len(x):
+ c = "grey"
+ if not isinstance(color, str) or color != default_color(adata):
+ logg.warn("Invalid color key. Using grey instead.")
# check if higher value points should be plotted on top
if not isinstance(c, str) and len(c) == len(x):
@@ -583,7 +599,9 @@ def scatter(
if sort_order and not is_categorical(adata, color):
order = np.argsort(c)
elif not sort_order and is_categorical(adata, color):
- counts = get_value_counts(adata, color)
+ counts = get_value_counts(
+ adata[idx] if idx is not None else adata, color
+ )
np.random.seed(0)
nums, p = np.arange(0, len(x)), counts / np.sum(counts)
order = np.random.choice(nums, len(x), replace=False, p=p)
@@ -592,8 +610,9 @@ def scatter(
if isinstance(kwargs["s"], np.ndarray): # sort sizes if array-type
kwargs["s"] = np.array(kwargs["s"])[order]
+ marker = kwargs.pop("marker", ".")
smp = ax.scatter(
- x, y, c=c, alpha=alpha, marker=".", zorder=zorder, **kwargs
+ x, y, c=c, alpha=alpha, marker=marker, zorder=zorder, **kwargs
)
outline_dtypes = (list, tuple, np.ndarray, int, np.int_, str)
@@ -655,6 +674,7 @@ def scatter(
legend_fontweight,
legend_fontsize,
legend_fontoutline,
+ legend_align_text,
groups,
)
if add_density:
@@ -710,7 +730,8 @@ def scatter(
set_label(xlabel, ylabel, fontsize, basis, ax=ax)
set_title(title, layer, color, fontsize, ax=ax)
update_axes(ax, xlim, ylim, fontsize, is_embedding, frameon, figsize)
-
+ if add_margin:
+ set_margin(ax, x, y, add_margin)
if colorbar is not False:
if not isinstance(c, str) and not is_categorical(adata, color):
labelsize = fontsize * 0.75 if fontsize is not None else None
@@ -744,6 +765,7 @@ def trimap(adata, **kwargs):
-------
If `show==False` a :class:`~matplotlib.axes.Axes` or a list of it.
"""
+
return scatter(adata, basis="trimap", **kwargs)
@@ -760,6 +782,7 @@ def umap(adata, **kwargs):
-------
If `show==False` a :class:`~matplotlib.axes.Axes` or a list of it.
"""
+
return scatter(adata, basis="umap", **kwargs)
@@ -776,6 +799,7 @@ def tsne(adata, **kwargs):
-------
If `show==False` a :class:`~matplotlib.axes.Axes` or a list of it.
"""
+
return scatter(adata, basis="tsne", **kwargs)
@@ -792,6 +816,7 @@ def diffmap(adata, **kwargs):
-------
If `show==False` a :class:`~matplotlib.axes.Axes` or a list of it.
"""
+
return scatter(adata, basis="diffmap", **kwargs)
@@ -808,6 +833,7 @@ def phate(adata, **kwargs):
-------
If `show==False` a :class:`~matplotlib.axes.Axes` or a list of it.
"""
+
return scatter(adata, basis="phate", **kwargs)
@@ -824,6 +850,7 @@ def draw_graph(adata, layout=None, **kwargs):
-------
If `show==False` a :class:`~matplotlib.axes.Axes` or a list of it.
"""
+
if layout is None:
layout = f"{adata.uns['draw_graph']['params']['layout']}"
basis = f"draw_graph_{layout}"
@@ -845,4 +872,5 @@ def pca(adata, **kwargs):
-------
If `show==False` a :class:`~matplotlib.axes.Axes` or a list of it.
"""
+
return scatter(adata, basis="pca", **kwargs)
diff --git a/scvelo/plotting/simulation.py b/scvelo/plotting/simulation.py
index a76817a9..6f3ea5a9 100644
--- a/scvelo/plotting/simulation.py
+++ b/scvelo/plotting/simulation.py
@@ -1,10 +1,12 @@
-from ..tools.dynamical_model_utils import unspliced, mRNA, vectorize, tau_inv, get_vars
-from .utils import make_dense
-
import numpy as np
+
import matplotlib.pyplot as pl
from matplotlib import rcParams
+from scvelo.core import SplicingDynamics
+from scvelo.tools.dynamical_model_utils import get_vars, tau_inv, unspliced, vectorize
+from .utils import make_dense
+
def get_dynamics(adata, key="fit", extrapolate=False, sorted=False, t=None):
alpha, beta, gamma, scaling, t_ = get_vars(adata, key=key)
@@ -18,8 +20,9 @@ def get_dynamics(adata, key="fit", extrapolate=False, sorted=False, t=None):
t = adata.obs[f"{key}_t"].values if key == "true" else adata.layers[f"{key}_t"]
tau, alpha, u0, s0 = vectorize(np.sort(t) if sorted else t, t_, alpha, beta, gamma)
- ut, st = mRNA(tau, u0, s0, alpha, beta, gamma)
-
+ ut, st = SplicingDynamics(
+ alpha=alpha, beta=beta, gamma=gamma, initial_state=[u0, s0]
+ ).get_solution(tau)
return alpha, ut, st
@@ -51,17 +54,26 @@ def compute_dynamics(
tau, alpha, u0, s0 = vectorize(np.sort(t) if sort else t, t_, alpha, beta, gamma)
- ut, st = mRNA(tau, u0, s0, alpha, beta, gamma)
+ ut, st = SplicingDynamics(
+ alpha=alpha, beta=beta, gamma=gamma, initial_state=[u0, s0]
+ ).get_solution(tau, stacked=False)
ut, st = ut * scaling + u0_offset, st + s0_offset
return alpha, ut, st
def show_full_dynamics(
- adata, basis, key="true", use_raw=False, linewidth=1, show_assignments=None, ax=None
+ adata,
+ basis,
+ key="true",
+ use_raw=False,
+ linewidth=1,
+ linecolor=None,
+ show_assignments=None,
+ ax=None,
):
if ax is None:
ax = pl.gca()
- color = "grey" if key == "true" else "purple"
+ color = linecolor if linecolor else "grey" if key == "true" else "purple"
linewidth = 0.5 * linewidth if key == "true" else linewidth
label = "learned dynamics" if key == "fit" else "true dynamics"
line = None
@@ -115,7 +127,7 @@ def simulation(
colors=None,
**kwargs,
):
- from ..tools.utils import make_dense
+ from scvelo.tools.utils import make_dense
from .scatter import scatter
if ykey is None:
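`mRNA` computed the closed-form solution of the splicing kinetics du/dt = alpha - beta*u, ds/dt = beta*u - gamma*s; `SplicingDynamics.get_solution` now encapsulates the same math in `scvelo.core`. A standalone sketch of that solution, assuming beta != gamma as the model does:

```python
import numpy as np


def splicing_solution(tau, u0, s0, alpha, beta, gamma):
    """Closed-form solution of du/dt = alpha - beta*u, ds/dt = beta*u - gamma*s.

    Mirrors what the removed `mRNA` helper returned and what
    `SplicingDynamics(...).get_solution(tau, stacked=False)` now provides.
    """
    expu, exps = np.exp(-beta * tau), np.exp(-gamma * tau)
    ut = u0 * expu + alpha / beta * (1 - expu)
    st = (
        s0 * exps
        + alpha / gamma * (1 - exps)
        + (alpha - beta * u0) / (gamma - beta) * (exps - expu)
    )
    return ut, st


# the steady state u -> alpha/beta, s -> alpha/gamma is approached for large tau
ut, st = splicing_solution(np.array([0.0, 100.0]), 0.0, 0.0, 2.0, 1.0, 0.5)
assert np.allclose([ut[-1], st[-1]], [2.0, 4.0])
```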
diff --git a/scvelo/plotting/summary.py b/scvelo/plotting/summary.py
index 57c9ecfb..bf40914f 100644
--- a/scvelo/plotting/summary.py
+++ b/scvelo/plotting/summary.py
@@ -1,13 +1,12 @@
-from ..tools.dynamical_model import latent_time
-from ..tools.velocity_pseudotime import velocity_pseudotime
-from ..tools.rank_velocity_genes import rank_velocity_genes
-from ..tools.score_genes_cell_cycle import score_genes_cell_cycle
+import numpy as np
+import pandas as pd
-from .utils import make_unique_list
+from scvelo.tools.dynamical_model import latent_time
+from scvelo.tools.rank_velocity_genes import rank_velocity_genes
+from scvelo.tools.score_genes_cell_cycle import score_genes_cell_cycle
+from scvelo.tools.velocity_pseudotime import velocity_pseudotime
from .gridspec import GridSpec
-
-import pandas as pd
-import numpy as np
+from .utils import make_unique_list
def summary(adata, basis="umap", color="clusters", n_top_genes=12, var_names=None):
diff --git a/scvelo/plotting/utils.py b/scvelo/plotting/utils.py
index 5f30d3d1..7f6ec1b7 100644
--- a/scvelo/plotting/utils.py
+++ b/scvelo/plotting/utils.py
@@ -1,28 +1,27 @@
-from .. import settings
-from .. import logging as logg
-from .. import AnnData
-from ..preprocessing.moments import get_connectivities
-from ..tools.utils import strings_to_categoricals
-from . import palettes
-
import os
+from collections import abc
+
+from cycler import Cycler, cycler
+
import numpy as np
import pandas as pd
+from pandas import Index
+from scipy import stats
+from scipy.sparse import issparse
+
import matplotlib.pyplot as pl
-from matplotlib.ticker import MaxNLocator
-from mpl_toolkits.axes_grid1.inset_locator import inset_axes
-from matplotlib.colors import is_color_like, ListedColormap, to_rgb, cnames
+import matplotlib.transforms as tx
+from matplotlib import patheffects, rcParams
from matplotlib.collections import LineCollection
+from matplotlib.colors import cnames, is_color_like, ListedColormap, to_rgb
from matplotlib.gridspec import SubplotSpec
-from matplotlib import patheffects
-import matplotlib.transforms as tx
-from matplotlib import rcParams
-from pandas import unique, Index
-from scipy.sparse import issparse
-from scipy.stats import pearsonr
-from cycler import Cycler, cycler
-from collections import abc
+from matplotlib.ticker import MaxNLocator
+from mpl_toolkits.axes_grid1.inset_locator import inset_axes
+from scvelo import logging as logg
+from scvelo import settings
+from scvelo.tools.utils import strings_to_categoricals
+from . import palettes
"""helper functions"""
@@ -48,7 +47,7 @@ def is_view(adata):
def is_categorical(data, c=None):
- from pandas.api.types import is_categorical as cat
+ from pandas.api.types import is_categorical_dtype as cat
if c is None:
return cat(data) # if data is categorical/array
@@ -81,7 +80,9 @@ def is_list_of_str(key, max_len=None):
def is_list_of_list(lst):
- return lst is not None and any(isinstance(l, list) for l in lst)
+ return lst is not None and any(
+ isinstance(list_element, list) for list_element in lst
+ )
def is_list_of_int(lst):
@@ -112,7 +113,9 @@ def get_ax(ax, show=None, figsize=None, dpi=None, projection=None):
figsize, _ = get_figure_params(figsize)
if ax is None:
projection = "3d" if projection == "3d" else None
- ax = pl.figure(None, figsize, dpi=dpi).gca(projection=projection)
+ _, ax = pl.subplots(
+ figsize=figsize, dpi=dpi, subplot_kw={"projection": projection}
+ )
elif isinstance(ax, SubplotSpec):
geo = ax.get_geometry()
if show is None:
@@ -347,7 +350,7 @@ def default_color_map(adata, c):
try:
if np.min(c) in [-1, 0, False] and np.max(c) in [1, True]:
cmap = "viridis_r"
- except:
+ except Exception:
cmap = None
return cmap
@@ -461,7 +464,7 @@ def set_artist_frame(ax, length=0.2, figsize=None):
figsize = rcParams["figure.figsize"] if figsize is None else figsize
aspect_ratio = figsize[0] / figsize[1]
ax.xaxis.set_label_coords(length * 0.45, -0.035)
- ax.yaxis.set_label_coords(-0.0175, length * aspect_ratio * 0.45)
+ ax.yaxis.set_label_coords(-0.025, length * aspect_ratio * 0.45)
ax.xaxis.label.set_size(ax.xaxis.label.get_size() / 1.2)
ax.yaxis.label.set_size(ax.yaxis.label.get_size() / 1.2)
@@ -539,6 +542,7 @@ def set_legend(
legend_fontweight,
legend_fontsize,
legend_fontoutline,
+ legend_align_text,
groups,
):
"""
@@ -577,10 +581,14 @@ def set_legend(
text = ax.text(x_pos, y_pos, label, path_effects=pe, **kwargs)
texts.append(text)
- # todo: adjust text positions to minimize overlaps,
- # e.g. using https://github.com/Phlya/adjustText
- # from adjustText import adjust_text
- # adjust_text(texts, ax=ax)
+ if legend_align_text:
+ autoalign = "y" if legend_align_text is True else legend_align_text
+ try:
+ from adjustText import adjust_text as adj_text
+
+ adj_text(texts, autoalign=autoalign, text_from_points=False, ax=ax)
+ except ImportError:
+ print("Please `pip install adjustText` for auto-aligning texts")
else:
for idx, label in enumerate(categories):
@@ -599,6 +607,16 @@ def set_legend(
ax.legend(loc=legend_loc, **kwargs)
+def set_margin(ax, x, y, add_margin):
+ add_margin = 0.1 if add_margin is True else add_margin
+ xmin, xmax = np.min(x), np.max(x)
+ ymin, ymax = np.min(y), np.max(y)
+ xmargin = (xmax - xmin) * add_margin
+ ymargin = (ymax - ymin) * add_margin
+ ax.set_xlim(xmin - xmargin, xmax + xmargin)
+ ax.set_ylim(ymin - ymargin, ymax + ymargin)
+
+
"""get color values"""
@@ -640,7 +658,7 @@ def interpret_colorkey(adata, c=None, layer=None, perc=None, use_raw=None):
if is_categorical(adata, c):
c = get_colors(adata, c)
elif isinstance(c, str):
- if is_color_like(c) and not c in adata.var_names:
+ if is_color_like(c) and c not in adata.var_names:
pass
elif c in adata.obs.keys(): # color by observation key
c = adata.obs[c]
@@ -649,15 +667,22 @@ def interpret_colorkey(adata, c=None, layer=None, perc=None, use_raw=None):
): # by gene
if layer in adata.layers.keys():
if perc is None and any(
- l in layer for l in ["spliced", "unspliced", "Ms", "Mu", "velocity"]
+ layer_name in layer
+ for layer_name in ["spliced", "unspliced", "Ms", "Mu", "velocity"]
):
perc = [1, 99] # to ignore outliers in non-logarithmized layers
c = adata.obs_vector(c, layer=layer)
elif layer is not None and np.any(
- [l in layer or "X" in layer for l in adata.layers.keys()]
+ [
+ layer_name in layer or "X" in layer
+ for layer_name in adata.layers.keys()
+ ]
):
l_array = np.hstack(
- [adata.obs_vector(c, layer=l)[:, None] for l in adata.layers.keys()]
+ [
+ adata.obs_vector(c, layer=layer)[:, None]
+ for layer in adata.layers.keys()
+ ]
)
l_array = pd.DataFrame(l_array, columns=adata.layers.keys())
l_array.insert(0, "X", adata.obs_vector(c))
@@ -712,9 +737,10 @@ def set_colors_for_categorical_obs(adata, value_to_plot, palette=None):
a sequence of colors (in a format that can be understood by matplotlib,
eg. RGB, RGBS, hex, or a cycler object with key='color'
"""
- from .palettes import additional_colors
from matplotlib.colors import to_hex
+ from .palettes import additional_colors
+
color_key = f"{value_to_plot}_colors"
valid = True
categories = adata.obs[value_to_plot].cat.categories
@@ -724,19 +750,19 @@ def set_colors_for_categorical_obs(adata, value_to_plot, palette=None):
palette = palettes.default_26 if length <= 28 else palettes.default_64
if isinstance(palette, str) and palette in adata.uns:
palette = (
- adata.uns[palette].values()
+ [adata.uns[palette][c] for c in categories]
if isinstance(adata.uns[palette], dict)
else adata.uns[palette]
)
-
if palette is None and color_key in adata.uns:
+ color_keys = adata.uns[color_key]
# Check if colors already exist in adata.uns and if they are a valid palette
- _palette = []
- color_keys = (
- adata.uns[color_key].values()
- if isinstance(adata.uns[color_key], dict)
- else adata.uns[color_key]
- )
+ if isinstance(color_keys, np.ndarray) and isinstance(color_keys[0], dict):
+ adata.uns[color_key] = adata.uns[color_key][0]
+        # Flatten the dict to a list (mainly for anndata compatibility)
+ if isinstance(adata.uns[color_key], dict):
+ adata.uns[color_key] = [adata.uns[color_key][c] for c in categories]
+ color_keys = adata.uns[color_key]
for color in color_keys:
if not is_color_like(color):
# check if valid color translate to a hex color value
@@ -760,6 +786,13 @@ def set_colors_for_categorical_obs(adata, value_to_plot, palette=None):
colors_list = [to_hex(x) for x in cmap(np.linspace(0, 1, length))]
else:
+ # check if palette is an array of length n_obs
+ if isinstance(palette, (list, np.ndarray)) or is_categorical(palette):
+ if len(adata.obs[value_to_plot]) == len(palette):
+ cats = pd.Categorical(adata.obs[value_to_plot])
+ colors = pd.Categorical(palette)
+ if len(cats) == len(colors):
+ palette = dict(zip(cats, colors))
# check if palette is as dict and convert it to an ordered list
if isinstance(palette, dict):
palette = [palette[c] for c in categories]
@@ -884,8 +917,9 @@ def rgb_custom_colormap(colors=None, alpha=None, N=256):
Returns
-------
- A ListedColormap
+ :class:`~matplotlib.colors.ListedColormap`
"""
+
if colors is None:
colors = ["royalblue", "white", "forestgreen"]
c = []
@@ -906,9 +940,11 @@ def rgb_custom_colormap(colors=None, alpha=None, N=256):
n = int(N / ints)
for j in range(ints):
+ start = n * j
+ end = n * (j + 1)
for i in range(3):
- vals[n * j : n * (j + 1), i] = np.linspace(c[j][i], c[j + 1][i], n)
- vals[n * j : n * (j + 1), -1] = np.linspace(alpha[j], alpha[j + 1], n)
+ vals[start:end, i] = np.linspace(c[j][i], c[j + 1][i], n)
+ vals[start:end, -1] = np.linspace(alpha[j], alpha[j + 1], n)
return ListedColormap(vals)
@@ -925,6 +961,8 @@ def savefig_or_show(writekey=None, show=None, dpi=None, ext=None, save=None):
save = save.replace(try_ext, "")
break
# append it
+ if "/" in save:
+ writekey = None
writekey = (
f"{writekey}_{save}" if writekey is not None and len(writekey) > 0 else save
)
@@ -955,15 +993,17 @@ def savefig_or_show(writekey=None, show=None, dpi=None, ext=None, save=None):
os.makedirs(settings.figdir)
if ext is None:
ext = settings.file_format_figs
- filename = f"{settings.figdir}{settings.plot_prefix}{writekey}"
+ filepath = f"{settings.figdir}{settings.plot_prefix}{writekey}"
+ if "/" in writekey:
+ filepath = f"{writekey}"
try:
- filename += f"{settings.plot_suffix}.{ext}"
+ filename = filepath + f"{settings.plot_suffix}.{ext}"
pl.savefig(filename, dpi=dpi, bbox_inches="tight")
- except: # save as .png if .pdf is not feasible (e.g. specific streamplots)
- logg.msg(f"figure cannot be saved as {ext}, using png instead.", v=1)
- filename = f"{settings.figdir}{settings.plot_prefix}{writekey}"
- filename += f"{settings.plot_suffix}.png"
+ except Exception:
+ # save as .png if .pdf is not feasible (e.g. specific streamplots)
+ filename = filepath + f"{settings.plot_suffix}.png"
pl.savefig(filename, dpi=dpi, bbox_inches="tight")
+ logg.msg(f"figure cannot be saved as {ext}, using png instead.", v=1)
logg.msg("saving figure to file", filename, v=1)
if show:
pl.show()
@@ -1008,14 +1048,14 @@ def plot_linfit(
if isinstance(add_linfit, str)
else color
if isinstance(color, str)
- else "grey"
+ else "k"
)
xnew = np.linspace(np.min(x), np.max(x) * 1.02)
ax.plot(xnew, offset + xnew * slope, linewidth=linewidth, color=color)
if add_legend:
kwargs = dict(ha="left", va="top", fontsize=fontsize)
bbox = dict(boxstyle="round", facecolor="wheat", alpha=0.2)
- txt = r"$\rho = $" + f"{np.round(pearsonr(x, y)[0], 2)}"
+ txt = r"$\rho = $" + f"{np.round(stats.pearsonr(x, y)[0], 2)}"
ax.text(0.05, 0.95, txt, transform=ax.transAxes, bbox=bbox, **kwargs)
@@ -1053,7 +1093,7 @@ def plot_polyfit(
if isinstance(add_polyfit, str)
else color
if isinstance(color, str)
- else "grey"
+ else "k"
)
xnew = np.linspace(np.min(x), np.max(x), num=100)
@@ -1128,16 +1168,25 @@ def plot_velocity_fits(
if "true_alpha" in adata.var.keys() and (
vkey is not None and "true_dynamics" in vkey
):
- line, fit = show_full_dynamics(adata, basis, "true", use_raw, linewidth, ax=ax)
+ line, fit = show_full_dynamics(
+ adata,
+ basis,
+ key="true",
+ use_raw=use_raw,
+ linewidth=linewidth,
+ linecolor=linecolor,
+ ax=ax,
+ )
fits.append(fit)
lines.append(line)
if "fit_alpha" in adata.var.keys() and (vkey is None or "dynamics" in vkey):
line, fit = show_full_dynamics(
adata,
basis,
- "fit",
- use_raw,
- linewidth,
+ key="fit",
+ use_raw=use_raw,
+ linewidth=linewidth,
+ linecolor=linecolor,
show_assignments=show_assignments,
ax=ax,
)
@@ -1359,7 +1408,7 @@ def hist(
Returns
-------
- If `show==False` a `matplotlib.Axis`
+ If `show==False` a `matplotlib.Axis`
"""
if ax is None:
@@ -1418,11 +1467,11 @@ def hist(
ax.fill_between(bins, 0, kde_bins, alpha=0.4, color=ci, label=li)
ylim = np.min(kde_bins) if ylim is None else ylim
if hist:
- ci, li = colors[i], labels[i] if labels is not None else None
+ ci, li = colors[i], labels[i] if labels is not None and not kde else None
kwargs.update({"color": ci, "label": li})
try:
ax.hist(x_vals, bins=bins, alpha=alpha, density=normed, **kwargs)
- except:
+ except Exception:
ax.hist(x_vals, bins=bins, alpha=alpha, **kwargs)
if xlabel is None:
xlabel = ""
@@ -1465,20 +1514,19 @@ def log_fmt(x, pos):
pdf = [pdf] if isinstance(pdf, str) else pdf
if pdf is not None:
fits = []
- for i, pd in enumerate(pdf):
- from scipy import stats
-
+ for i, pdf_name in enumerate(pdf):
xt = ax.get_xticks()
xmin, xmax = min(xt), max(xt)
lnspc = np.linspace(xmin, xmax, len(bins))
- if "(" in pd: # used passed parameters
- args, pd = eval(pd[pd.rfind("(") :]), pd[: pd.rfind("(")]
+ if "(" in pdf_name: # used passed parameters
+ start = pdf_name.rfind("(")
+ args, pdf_name = eval(pdf_name[start:]), pdf_name[:start]
else: # fit parameters
- args = eval(f"stats.{pd}.fit(x_vals)")
- pd_vals = eval(f"stats.{pd}.pdf(lnspc, *args)")
- logg.info("Fitting", pd, np.round(args, 4), ".")
- fit = ax.plot(lnspc, pd_vals, label=pd, color=colors[i])
+ args = getattr(stats, pdf_name).fit(x_vals)
+ pd_vals = getattr(stats, pdf_name).pdf(lnspc, *args)
+ logg.info("Fitting", pdf_name, np.round(args, 4), ".")
+ fit = ax.plot(lnspc, pd_vals, label=pdf_name, color=colors[i])
fits.extend(fit)
ax.legend(handles=fits, labels=pdf, fontsize=legend_fontsize)
@@ -1505,7 +1553,7 @@ def plot(
dpi=None,
show=True,
):
- ax = pl.figure(None, figsize, dpi=dpi) if ax is None else ax
+ ax, show = get_ax(ax, show, figsize, dpi)
arrays = np.array(arrays)
arrays = (
arrays if isinstance(arrays, (list, tuple)) or arrays.ndim > 1 else [arrays]
@@ -1517,21 +1565,19 @@ def plot(
for i, array in enumerate(arrays):
X = array[np.isfinite(array)]
X = X / np.max(X) if normalize else X
- pl.plot(X, color=colors[i], label=labels[i] if labels is not None else None)
+ ax.plot(X, color=colors[i], label=labels[i] if labels is not None else None)
- pl.xlabel(xlabel if xlabel is not None else "")
- pl.ylabel(ylabel if xlabel is not None else "")
+ ax.set_xlabel(xlabel if xlabel is not None else "")
+    ax.set_ylabel(ylabel if ylabel is not None else "")
if labels is not None:
- pl.legend()
+ ax.legend()
if xscale is not None:
- pl.xscale(xscale)
+        ax.set_xscale(xscale)
if yscale is not None:
- pl.yscale(yscale)
+        ax.set_yscale(yscale)
if not show:
return ax
- else:
- pl.show()
def fraction_timeseries(
@@ -1627,10 +1673,10 @@ def make_unique_valid_list(adata, keys):
def get_temporal_connectivities(adata, tkey, n_convolve=30):
- from ..tools.velocity_graph import vals_to_csr
- from ..tools.utils import normalize
+ from scvelo.tools.utils import normalize
+ from scvelo.tools.velocity_graph import vals_to_csr
- # from ..tools.utils import get_indices
+ # from scvelo.tools.utils import get_indices
# c_idx = get_indices(get_connectivities(adata, recurse_neighbors=True))[0]
# lspace = np.linspace(0, len(c_idx) - 1, len(c_idx), dtype=int)
# c_idx = np.hstack([c_idx, lspace[:, None]])
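Two patterns in the utils changes deserve a note. First, `hist` no longer builds `eval` strings like `f"stats.{pd}.fit(x_vals)"`; distribution objects are fetched with `getattr` on `scipy.stats`, which avoids evaluating arbitrary expressions for the distribution name (the parenthesized-parameter branch still uses `eval` on the argument tuple). The lookup pattern in isolation:

```python
import numpy as np
from scipy import stats

x_vals = np.random.default_rng(0).normal(loc=2.0, scale=0.5, size=1000)

pdf_name = "norm"  # any scipy.stats distribution name, e.g. "gamma", "lognorm"
dist = getattr(stats, pdf_name)  # replaces eval(f"stats.{pdf_name}")
args = dist.fit(x_vals)  # fitted shape/loc/scale parameters
lnspc = np.linspace(x_vals.min(), x_vals.max(), 100)
pd_vals = dist.pdf(lnspc, *args)  # density evaluated on the grid
```

Second, the new `set_margin` simply widens both axis limits by a fraction of the data range (10% when `add_margin=True`), so points and legend labels near the frame are no longer clipped.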
diff --git a/scvelo/plotting/velocity.py b/scvelo/plotting/velocity.py
index 27f67a5f..d68229a9 100644
--- a/scvelo/plotting/velocity.py
+++ b/scvelo/plotting/velocity.py
@@ -1,21 +1,21 @@
-from .. import settings
-from ..preprocessing.moments import second_order_moments
-from ..tools.rank_velocity_genes import rank_velocity_genes
+import numpy as np
+import pandas as pd
+from scipy.sparse import issparse
+
+import matplotlib.pyplot as pl
+from matplotlib import rcParams
+
+from scvelo.preprocessing.moments import second_order_moments
+from scvelo.tools.rank_velocity_genes import rank_velocity_genes
from .scatter import scatter
from .utils import (
- savefig_or_show,
default_basis,
default_size,
get_basis,
get_figure_params,
+ savefig_or_show,
)
-import numpy as np
-import pandas as pd
-import matplotlib.pyplot as pl
-from matplotlib import rcParams
-from scipy.sparse import issparse
-
def velocity(
adata,
@@ -193,8 +193,8 @@ def velocity(
)
# velocity and expression plots
- for l, layer in enumerate(layers):
- ax = pl.subplot(gs[v * nplts + l + 1])
+ for layer_id, layer in enumerate(layers):
+ ax = pl.subplot(gs[v * nplts + layer_id + 1])
title = "expression" if layer in ["X", skey] else layer
# _kwargs = {} if title == 'expression' else kwargs
cmap = color_map
diff --git a/scvelo/plotting/velocity_embedding.py b/scvelo/plotting/velocity_embedding.py
index 26725633..b9d86ad4 100644
--- a/scvelo/plotting/velocity_embedding.py
+++ b/scvelo/plotting/velocity_embedding.py
@@ -1,13 +1,30 @@
-from ..tools.velocity_embedding import velocity_embedding as compute_velocity_embedding
-from ..tools.utils import groups_to_bool
-from .utils import *
-from .scatter import scatter
-from .docs import doc_scatter, doc_params
+import numpy as np
+import matplotlib.pyplot as pl
from matplotlib import rcParams
from matplotlib.colors import is_color_like
-import matplotlib.pyplot as pl
-import numpy as np
+
+from scvelo.tools.utils import groups_to_bool
+from scvelo.tools.velocity_embedding import (
+ velocity_embedding as compute_velocity_embedding,
+)
+from .docs import doc_params, doc_scatter
+from .scatter import scatter
+from .utils import (
+ default_arrow,
+ default_basis,
+ default_color,
+ default_color_map,
+ default_size,
+ get_ax,
+ get_components,
+ get_figure_params,
+ interpret_colorkey,
+ make_unique_list,
+ make_unique_valid_list,
+ savefig_or_show,
+ velocity_embedding_changed,
+)
@doc_params(scatter=doc_scatter)
@@ -70,8 +87,9 @@ def velocity_embedding(
Returns
-------
- `matplotlib.Axis` if `show==False`
+ `matplotlib.Axis` if `show==False`
"""
+
if vkey == "all":
lkeys = list(adata.layers.keys())
vkey = [key for key in lkeys if "velocity" in key and "_u" not in key]
@@ -159,8 +177,6 @@ def velocity_embedding(
return ax
else:
- if projection == "3d":
- from mpl_toolkits.mplot3d import Axes3D
ax, show = get_ax(ax, show, figsize, dpi, projection)
color, layer, vkey, basis = colors[0], layers[0], vkeys[0], bases[0]
diff --git a/scvelo/plotting/velocity_embedding_grid.py b/scvelo/plotting/velocity_embedding_grid.py
index d81d77ce..c8bd86f9 100644
--- a/scvelo/plotting/velocity_embedding_grid.py
+++ b/scvelo/plotting/velocity_embedding_grid.py
@@ -1,14 +1,27 @@
-from ..tools.velocity_embedding import quiver_autoscale, velocity_embedding
-from ..tools.utils import groups_to_bool
-from .utils import *
-from .scatter import scatter
-from .docs import doc_scatter, doc_params
-
-from sklearn.neighbors import NearestNeighbors
+import numpy as np
from scipy.stats import norm as normal
-from matplotlib import rcParams
+from sklearn.neighbors import NearestNeighbors
+
import matplotlib.pyplot as pl
-import numpy as np
+from matplotlib import rcParams
+
+from scvelo.tools.utils import groups_to_bool
+from scvelo.tools.velocity_embedding import quiver_autoscale, velocity_embedding
+from .docs import doc_params, doc_scatter
+from .scatter import scatter
+from .utils import (
+ default_arrow,
+ default_basis,
+ default_color,
+ default_size,
+ get_ax,
+ get_basis,
+ get_components,
+ get_figure_params,
+ make_unique_list,
+ savefig_or_show,
+ velocity_embedding_changed,
+)
def compute_velocity_on_grid(
@@ -165,8 +178,9 @@ def velocity_embedding_grid(
Returns
-------
- `matplotlib.Axis` if `show==False`
+ `matplotlib.Axis` if `show==False`
"""
+
basis = default_basis(adata, **kwargs) if basis is None else get_basis(adata, basis)
if vkey == "all":
lkeys = list(adata.layers.keys())
diff --git a/scvelo/plotting/velocity_embedding_stream.py b/scvelo/plotting/velocity_embedding_stream.py
index 05a6d2a1..5a7acd89 100644
--- a/scvelo/plotting/velocity_embedding_stream.py
+++ b/scvelo/plotting/velocity_embedding_stream.py
@@ -1,13 +1,25 @@
-from ..tools.velocity_embedding import velocity_embedding
-from ..tools.utils import groups_to_bool
-from .utils import *
-from .velocity_embedding_grid import compute_velocity_on_grid
-from .scatter import scatter
-from .docs import doc_scatter, doc_params
+import numpy as np
-from matplotlib import rcParams
import matplotlib.pyplot as pl
-import numpy as np
+from matplotlib import rcParams
+
+from scvelo.tools.utils import groups_to_bool
+from scvelo.tools.velocity_embedding import velocity_embedding
+from .docs import doc_params, doc_scatter
+from .scatter import scatter
+from .utils import (
+ default_basis,
+ default_color,
+ default_size,
+ get_ax,
+ get_basis,
+ get_components,
+ get_figure_params,
+ make_unique_list,
+ savefig_or_show,
+ velocity_embedding_changed,
+)
+from .velocity_embedding_grid import compute_velocity_on_grid
@doc_params(scatter=doc_scatter)
@@ -15,11 +27,15 @@ def velocity_embedding_stream(
adata,
basis=None,
vkey="velocity",
- density=None,
+ density=2,
smooth=None,
min_mass=None,
cutoff_perc=None,
arrow_color=None,
+ arrow_size=1,
+ arrow_style="-|>",
+ max_length=4,
+ integration_direction="both",
linewidth=None,
n_neighbors=None,
recompute=None,
@@ -62,8 +78,11 @@ def velocity_embedding_stream(
---------
adata: :class:`~anndata.AnnData`
Annotated data matrix.
- density: `float` (default: 1)
- Amount of velocities to show - 0 none to 1 all
+ density: `float` (default: 2)
+ Controls the closeness of streamlines. When density = 2 (default), the domain
+ is divided into a 60x60 grid, whereas density linearly scales this grid.
+ Each cell in the grid can have, at most, one traversing streamline.
+ For different densities in each direction, use a tuple (density_x, density_y).
smooth: `float` (default: 0.5)
Multiplication factor for scale in Gaussian kernel around grid point.
min_mass: `float` (default: 1)
@@ -73,18 +92,30 @@ def velocity_embedding_stream(
If set, mask small velocities below a percentile threshold (between 0 and 100).
linewidth: `float` (default: 1)
Line width for streamplot.
+ arrow_color: `str` or 2D array (default: 'k')
+ The streamline color. If given an array, it must have the same shape as u and v.
+ arrow_size: `float` (default: 1)
+ Scaling factor for the arrow size.
+ arrow_style: `str` (default: '-|>')
+ Arrow style specification, '-|>' or '->'.
+ max_length: `float` (default: 4)
+ Maximum length of streamline in axes coordinates.
+ integration_direction: `str` (default: 'both')
+ Integrate the streamline in 'forward', 'backward' or 'both' directions.
n_neighbors: `int` (default: None)
Number of neighbors to consider around grid point.
X: `np.ndarray` (default: None)
- Embedding grid point coordinates
+        Embedding coordinates. Uses `adata.obsm['X_umap']` by default.
V: `np.ndarray` (default: None)
- Embedding grid velocity coordinates
+        Embedding velocity coordinates. Uses `adata.obsm['velocity_umap']` by default.
+
{scatter}
Returns
-------
- `matplotlib.Axis` if `show==False`
+ `matplotlib.Axis` if `show==False`
"""
+
basis = default_basis(adata, **kwargs) if basis is None else get_basis(adata, basis)
if vkey == "all":
lkeys = list(adata.layers.keys())
@@ -148,6 +179,17 @@ def velocity_embedding_stream(
"save": False,
}
+ stream_kwargs = {
+ "linewidth": linewidth,
+ "density": density or 2,
+ "zorder": 3,
+ "arrow_color": arrow_color or "k",
+ "arrowsize": arrow_size or 1,
+ "arrowstyle": arrow_style or "-|>",
+ "maxlength": max_length or 4,
+ "integration_direction": integration_direction or "both",
+ }
+
multikey = (
colors
if len(colors) > 1
@@ -175,11 +217,9 @@ def velocity_embedding_stream(
ax.append(
velocity_embedding_stream(
adata,
- density=density,
size=size,
smooth=smooth,
n_neighbors=n_neighbors,
- linewidth=linewidth,
ax=pl.subplot(gs),
color=colors[i] if len(colors) > 1 else color,
layer=layers[i] if len(layers) > 1 else layer,
@@ -188,6 +228,7 @@ def velocity_embedding_stream(
X_grid=None if len(vkeys) > 1 else X_grid,
V_grid=None if len(vkeys) > 1 else V_grid,
**scatter_kwargs,
+ **stream_kwargs,
**kwargs,
)
)
@@ -197,19 +238,14 @@ def velocity_embedding_stream(
else:
ax, show = get_ax(ax, show, figsize, dpi)
- density = 1 if density is None else density
- stream_kwargs = {
- "linewidth": linewidth,
- "density": 2 * density,
- "zorder": 3,
- "color": "k" if arrow_color is None else arrow_color,
- }
+
for arg in list(kwargs):
if arg in stream_kwargs:
stream_kwargs.update({arg: kwargs[arg]})
else:
scatter_kwargs.update({arg: kwargs[arg]})
+ stream_kwargs["color"] = stream_kwargs.pop("arrow_color", "k")
ax.streamplot(X_grid[0], X_grid[1], V_grid[0], V_grid[1], **stream_kwargs)
size = 8 * default_size(adata) if size is None else size
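The new stream arguments map one-to-one onto matplotlib's `streamplot` parameters (`arrowsize`, `arrowstyle`, `maxlength`, `integration_direction`), with `arrow_color` renamed to `color` just before the call. A hypothetical end-to-end invocation, following the standard scVelo pipeline:

```python
import scvelo as scv

adata = scv.datasets.pancreas()
scv.pp.filter_and_normalize(adata)
scv.pp.moments(adata)
scv.tl.velocity(adata)
scv.tl.velocity_graph(adata)

scv.pl.velocity_embedding_stream(
    adata,
    basis="umap",
    density=2,                        # default: a 60x60 streamline grid
    arrow_size=1.5,
    arrow_style="->",
    max_length=8,                     # allow longer streamlines
    integration_direction="forward",  # trace only along the velocity field
)
```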
diff --git a/scvelo/plotting/velocity_graph.py b/scvelo/plotting/velocity_graph.py
index 9d5cba0f..e0612156 100644
--- a/scvelo/plotting/velocity_graph.py
+++ b/scvelo/plotting/velocity_graph.py
@@ -1,14 +1,21 @@
-from .. import settings
-from ..preprocessing.neighbors import get_neighs
-from ..tools.transition_matrix import transition_matrix
-from .utils import savefig_or_show, default_basis, get_components
-from .utils import get_basis, groups_to_bool, default_size
-from .scatter import scatter
-from .docs import doc_scatter, doc_params
-
import warnings
+
import numpy as np
-from scipy.sparse import issparse, csr_matrix
+from scipy.sparse import csr_matrix, issparse
+
+from scvelo import settings
+from scvelo.preprocessing.neighbors import get_neighs
+from scvelo.tools.transition_matrix import transition_matrix
+from .docs import doc_params, doc_scatter
+from .scatter import scatter
+from .utils import (
+ default_basis,
+ default_size,
+ get_basis,
+ get_components,
+ groups_to_bool,
+ savefig_or_show,
+)
@doc_params(scatter=doc_scatter)
@@ -59,8 +66,9 @@ def velocity_graph(
Returns
-------
- `matplotlib.Axis` if `show==False`
+ `matplotlib.Axis` if `show==False`
"""
+
basis = default_basis(adata, **kwargs) if basis is None else get_basis(adata, basis)
kwargs.update(
{
@@ -76,7 +84,7 @@ def velocity_graph(
)
ax = scatter(adata, layer=layer, color=color, size=size, ax=ax, zorder=0, **kwargs)
- from networkx import Graph, DiGraph
+ from networkx import DiGraph, Graph
if which_graph in {"neighbors", "connectivities"}:
T = get_neighs(adata, "connectivities").copy()
@@ -165,11 +173,12 @@ def draw_networkx_edges(
):
"""Draw the edges of the graph G. Adjusted from networkx."""
try:
+ from numbers import Number
+
import matplotlib.pyplot as plt
- from matplotlib.colors import colorConverter, Colormap, Normalize
from matplotlib.collections import LineCollection
+ from matplotlib.colors import colorConverter, Colormap, Normalize
from matplotlib.patches import FancyArrowPatch
- from numbers import Number
except ImportError:
raise ImportError("Matplotlib required for draw()")
except RuntimeError:
diff --git a/scvelo/pp.py b/scvelo/pp.py
index 4db20b2e..c0870c14 100644
--- a/scvelo/pp.py
+++ b/scvelo/pp.py
@@ -1 +1 @@
-from scvelo.preprocessing import *
+from scvelo.preprocessing import * # noqa
diff --git a/scvelo/preprocessing/__init__.py b/scvelo/preprocessing/__init__.py
index 863f64a9..1ba31917 100644
--- a/scvelo/preprocessing/__init__.py
+++ b/scvelo/preprocessing/__init__.py
@@ -1,4 +1,27 @@
-from .utils import show_proportions, cleanup, filter_genes, filter_genes_dispersion
-from .utils import normalize_per_cell, filter_and_normalize, log1p, recipe_velocity
-from .neighbors import pca, neighbors, remove_duplicate_cells
from .moments import moments
+from .neighbors import neighbors, pca, remove_duplicate_cells
+from .utils import (
+ cleanup,
+ filter_and_normalize,
+ filter_genes,
+ filter_genes_dispersion,
+ log1p,
+ normalize_per_cell,
+ recipe_velocity,
+ show_proportions,
+)
+
+__all__ = [
+ "cleanup",
+ "filter_and_normalize",
+ "filter_genes",
+ "filter_genes_dispersion",
+ "log1p",
+ "moments",
+ "neighbors",
+ "normalize_per_cell",
+ "pca",
+ "recipe_velocity",
+ "remove_duplicate_cells",
+ "show_proportions",
+]
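Defining `__all__` makes the star-import surface explicit, which matters because `scvelo/pp.py` (above) re-exports this package with `import *`. A quick check of the effect:

```python
from scvelo.preprocessing import *  # noqa: F403

# only the names listed in __all__ are re-exported by the star import
print(sorted(set(dir()) & {"log1p", "moments", "neighbors", "pca"}))
# ['log1p', 'moments', 'neighbors', 'pca']
```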
diff --git a/scvelo/preprocessing/moments.py b/scvelo/preprocessing/moments.py
index 312ad9ed..494eef17 100644
--- a/scvelo/preprocessing/moments.py
+++ b/scvelo/preprocessing/moments.py
@@ -1,10 +1,10 @@
-from .. import settings
-from .. import logging as logg
-from .utils import not_yet_normalized, normalize_per_cell
-from .neighbors import neighbors, get_connectivities, get_n_neighs, verify_neighbors
-
-from scipy.sparse import csr_matrix, issparse
import numpy as np
+from scipy.sparse import csr_matrix, issparse
+
+from scvelo import logging as logg
+from scvelo import settings
+from .neighbors import get_connectivities, get_n_neighs, neighbors, verify_neighbors
+from .utils import normalize_per_cell, not_yet_normalized
def moments(
@@ -48,12 +48,12 @@ def moments(
Returns
-------
- Returns or updates `adata` with the attributes
Ms: `.layers`
dense matrix with first order moments of spliced counts.
Mu: `.layers`
dense matrix with first order moments of unspliced counts.
"""
+
adata = data.copy() if copy else data
layers = [layer for layer in {"spliced", "unspliced"} if layer in adata.layers]
@@ -114,6 +114,7 @@ def second_order_moments(adata, adjusted=False):
Mss: Second order moments for spliced abundances
Mus: Second order moments for spliced with unspliced abundances
"""
+
if "neighbors" not in adata.uns:
raise ValueError(
"You need to run `pp.neighbors` first to compute a neighborhood graph."
@@ -143,6 +144,7 @@ def second_order_moments_u(adata):
-------
Muu: Second order moments for unspliced abundances
"""
+
if "neighbors" not in adata.uns:
raise ValueError(
"You need to run `pp.neighbors` first to compute a neighborhood graph."
@@ -186,10 +188,12 @@ def get_moments(
Whether to compute centered (=variance) or uncentered second order moments.
mode: `'connectivities'` or `'distances'` (default: `'connectivities'`)
Distance metric to use for moment computation.
+
Returns
-------
Mx: first or second order moments
"""
+
if "neighbors" not in adata.uns:
raise ValueError(
"You need to run `pp.neighbors` first to compute a neighborhood graph."
diff --git a/scvelo/preprocessing/neighbors.py b/scvelo/preprocessing/neighbors.py
index a0063e31..b7f9e850 100644
--- a/scvelo/preprocessing/neighbors.py
+++ b/scvelo/preprocessing/neighbors.py
@@ -1,13 +1,17 @@
import warnings
+from collections import Counter
+
import numpy as np
+import pandas as pd
+from scipy.sparse import coo_matrix, issparse
+
from anndata import AnnData
from scanpy import Neighbors
from scanpy.preprocessing import pca
-from scipy.sparse import issparse, coo_matrix
-from .utils import get_initial_size
-from .. import logging as logg
-from .. import settings
+from scvelo import logging as logg
+from scvelo import settings
+from scvelo.core import get_initial_size
def neighbors(
@@ -71,16 +75,16 @@ def neighbors(
Number of threads to be used (for runtime).
copy
Return a copy instead of writing to adata.
+
Returns
-------
- Depending on `copy`, updates or returns `adata` with the following:
- connectivities : sparse matrix (`.uns['neighbors']`, dtype `float32`)
- Weighted adjacency matrix of the neighborhood graph of data
+ connectivities : `.obsp`
+ Sparse weighted adjacency matrix of the neighborhood graph of data
points. Weights should be interpreted as connectivities.
- distances : sparse matrix (`.uns['neighbors']`, dtype `float32`)
- Instead of decaying weights, this stores distances for each pair of
- neighbors.
+ distances : `.obsp`
+ Sparse matrix of distances for each pair of neighbors.
"""
+
adata = adata.copy() if copy else adata
if use_rep is None:
@@ -179,7 +183,7 @@ def neighbors(
adata.obsp["connectivities"] = neighbors.connectivities
adata.uns["neighbors"]["connectivities_key"] = "connectivities"
adata.uns["neighbors"]["distances_key"] = "distances"
- except:
+ except Exception:
adata.uns["neighbors"]["distances"] = neighbors.distances
adata.uns["neighbors"]["connectivities"] = neighbors.connectivities
@@ -427,15 +431,15 @@ def compute_connectivities_umap(
def get_duplicate_cells(data):
if isinstance(data, AnnData):
X = data.X
- l = list(np.sum(np.abs(data.obsm["X_pca"]), 1) + get_initial_size(data))
+ lst = list(np.sum(np.abs(data.obsm["X_pca"]), 1) + get_initial_size(data))
else:
X = data
- l = list(np.sum(X, 1).A1 if issparse(X) else np.sum(X, 1))
+ lst = list(np.sum(X, 1).A1 if issparse(X) else np.sum(X, 1))
- l_set = set(l)
idx_dup = []
- if len(l_set) < len(l):
- idx_dup = np.array([i for i, x in enumerate(l) if l.count(x) > 1])
+ if len(set(lst)) < len(lst):
+ vals = [val for val, count in Counter(lst).items() if count > 1]
+ idx_dup = np.where(pd.Series(lst).isin(vals))[0]
X_new = np.array(X[idx_dup].A if issparse(X) else X[idx_dup])
sorted_idx = np.lexsort(X_new.T)
@@ -443,7 +447,7 @@ def get_duplicate_cells(data):
row_mask = np.invert(np.append([True], np.any(np.diff(sorted_data, axis=0), 1)))
idx = sorted_idx[row_mask]
- idx_dup = idx_dup[idx]
+ idx_dup = np.array(idx_dup)[idx]
return idx_dup
@@ -456,4 +460,5 @@ def remove_duplicate_cells(adata):
mask[idx_duplicates] = 0
logg.info("Removed", len(idx_duplicates), "duplicate cells.")
adata._inplace_subset_obs(mask)
- neighbors(adata)
+ if "neighbors" in adata.uns.keys():
+ neighbors(adata)
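`get_duplicate_cells` previously called `lst.count(x)` inside a comprehension, an O(n) scan per element and O(n^2) overall; the rewrite does one `Counter` pass and a vectorized membership test. The two approaches side by side on a toy list:

```python
from collections import Counter

import numpy as np
import pandas as pd

lst = [3.0, 1.0, 3.0, 2.0, 1.0, 4.0]

# old: quadratic -- lst.count(x) rescans the list for every element
idx_quadratic = np.array([i for i, x in enumerate(lst) if lst.count(x) > 1])

# new: linear -- count once, then test membership in the duplicated values
vals = [val for val, count in Counter(lst).items() if count > 1]
idx_linear = np.where(pd.Series(lst).isin(vals))[0]

assert (idx_quadratic == idx_linear).all()  # indices [0, 1, 2, 4]
```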
diff --git a/scvelo/preprocessing/utils.py b/scvelo/preprocessing/utils.py
index ca6ff3e0..fca3db11 100644
--- a/scvelo/preprocessing/utils.py
+++ b/scvelo/preprocessing/utils.py
@@ -1,172 +1,124 @@
+import warnings
+
import numpy as np
import pandas as pd
from scipy.sparse import issparse
from sklearn.utils import sparsefuncs
+
from anndata import AnnData
-import warnings
-from .. import logging as logg
+from scvelo import logging as logg
+from scvelo.core import cleanup as _cleanup
+from scvelo.core import get_initial_size as _get_initial_size
+from scvelo.core import get_size as _get_size
+from scvelo.core import set_initial_size as _set_initial_size
+from scvelo.core import show_proportions as _show_proportions
+from scvelo.core import sum
+from scvelo.core._anndata import verify_dtypes as _verify_dtypes
def sum_obs(A):
"""summation over axis 0 (obs) equivalent to np.sum(A, 0)"""
- with warnings.catch_warnings():
- warnings.simplefilter("ignore")
- return A.sum(0).A1 if issparse(A) else np.sum(A, axis=0)
+
+ warnings.warn(
+ "`sum_obs` is deprecated since scVelo v0.2.4 and will be removed in a future "
+ "version. Please use `sum(A, axis=0)` from `scvelo/core/` instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+
+ return sum(A, axis=0)
def sum_var(A):
"""summation over axis 1 (var) equivalent to np.sum(A, 1)"""
- with warnings.catch_warnings():
- warnings.simplefilter("ignore")
- return A.sum(1).A1 if issparse(A) else np.sum(A, axis=1)
+ warnings.warn(
+ "`sum_var` is deprecated since scVelo v0.2.4 and will be removed in a future "
+ "version. Please use `sum(A, axis=1)` from `scvelo/core/` instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
-def show_proportions(adata, layers=None, use_raw=True):
- """Proportions of spliced/unspliced abundances
+ return sum(A, axis=1)
- Arguments
- ---------
- adata: :class:`~anndata.AnnData`
- Annotated data matrix.
- Returns
- -------
- Prints the fractions of abundances.
- """
- if layers is None:
- layers = ["spliced", "unspliced", "ambigious"]
- layers_keys = [key for key in layers if key in adata.layers.keys()]
- counts_layers = [sum_var(adata.layers[key]) for key in layers_keys]
- if use_raw:
- size_key, obs = "initial_size_", adata.obs
- counts_layers = [
- obs[size_key + l] if size_key + l in obs.keys() else c
- for l, c in zip(layers_keys, counts_layers)
- ]
-
- counts_per_cell_sum = np.sum(counts_layers, 0)
- counts_per_cell_sum += counts_per_cell_sum == 0
-
- mean_abundances = [
- np.mean(counts_per_cell / counts_per_cell_sum)
- for counts_per_cell in counts_layers
- ]
+def show_proportions(adata, layers=None, use_raw=True):
+ warnings.warn(
+ "`scvelo.preprocessing.show_proportions` is deprecated since scVelo v0.2.4 "
+ "and will be removed in a future version. Please use "
+ "`scvelo.core.show_proportions` instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
- print(f"Abundance of {layers_keys}: {np.round(mean_abundances, 2)}")
+ _show_proportions(adata=adata, layers=layers, use_raw=use_raw)
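# --- Editor's example (not part of the diff) ---------------------------------
# All wrappers in this file follow the same pattern: `stacklevel=2` attributes
# the DeprecationWarning to the *caller* of the deprecated function rather than
# to the wrapper itself. A self-contained check of that behavior:
import warnings


def old_api():
    warnings.warn("`old_api` is deprecated.", DeprecationWarning, stacklevel=2)


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    old_api()  # the recorded warning points at this call site, not at old_api
assert caught[0].category is DeprecationWarning
# ------------------------------------------------------------------------------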
def verify_dtypes(adata):
- try:
- _ = adata[:, 0]
- except:
- uns = adata.uns
- adata.uns = {}
- try:
- _ = adata[:, 0]
- logg.warn(
- "Safely deleted unstructured annotations (adata.uns), \n"
- "as these do not comply with permissible anndata datatypes."
- )
- except:
- logg.warn(
- "The data might be corrupted. Please verify all annotation datatypes."
- )
- adata.uns = uns
-
-
-def cleanup(data, clean="layers", keep=None, copy=False):
- """Deletes attributes not needed.
-
- Arguments
- ---------
- data: :class:`~anndata.AnnData`
- Annotated data matrix.
- clean: `str` or list of `str` (default: `layers`)
- Which attributes to consider for freeing memory.
- keep: `str` or list of `str` (default: None)
- Which attributes to keep.
- copy: `bool` (default: `False`)
- Return a copy instead of writing to adata.
-
- Returns
- -------
- Returns or updates `adata` with selection of attributes kept.
- """
- adata = data.copy() if copy else data
- verify_dtypes(adata)
-
- keep = list([keep] if isinstance(keep, str) else {} if keep is None else keep)
- keep.extend(["spliced", "unspliced", "Ms", "Mu", "clusters", "neighbors"])
+ warnings.warn(
+ "`scvelo.preprocessing.utils.verify_dtypes` is deprecated since scVelo v0.2.4 "
+ "and will be removed in a future version. Please use "
+ "`scvelo.core._anndata.verify_dtypes` instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
- ann_dict = {
- "obs": adata.obs_keys(),
- "var": adata.var_keys(),
- "uns": adata.uns_keys(),
- "layers": list(adata.layers.keys()),
- }
+ return _verify_dtypes(adata=adata)
- if "all" not in clean:
- ann_dict = {ann: values for (ann, values) in ann_dict.items() if ann in clean}
- for (ann, values) in ann_dict.items():
- for value in values:
- if value not in keep:
- del getattr(adata, ann)[value]
+def cleanup(data, clean="layers", keep=None, copy=False):
+ warnings.warn(
+ "`scvelo.preprocessing.cleanup` is deprecated since scVelo v0.2.4 and will be "
+ "removed in a future version. Please use `scvelo.core.cleanup` instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
- return adata if copy else None
+ return _cleanup(data=data, clean=clean, keep=keep, copy=copy)
def get_size(adata, layer=None):
- X = adata.X if layer is None else adata.layers[layer]
- return sum_var(X)
+ warnings.warn(
+ "`scvelo.preprocessing.utils.get_size` is deprecated since scVelo v0.2.4 and "
+ "will be removed in a future version. Please use `scvelo.core.get_size` "
+ "instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+
+ return _get_size(adata=adata, layer=layer)
def set_initial_size(adata, layers=None):
- if layers is None:
- layers = ["spliced", "unspliced"]
- verify_dtypes(adata)
- layers = [
- layer
- for layer in layers
- if layer in adata.layers.keys()
- and f"initial_size_{layer}" not in adata.obs.keys()
- ]
- for layer in layers:
- adata.obs[f"initial_size_{layer}"] = get_size(adata, layer)
- if "initial_size" not in adata.obs.keys():
- adata.obs["initial_size"] = get_size(adata)
+ warnings.warn(
+ "`scvelo.preprocessing.utils.set_initial_size` is deprecated since scVelo "
+ "v0.2.4 and will be removed in a future version. Please use "
+ "`scvelo.core.set_initial_size` instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+
+ return _set_initial_size(adata=adata, layers=layers)
def get_initial_size(adata, layer=None, by_total_size=None):
- if by_total_size:
- sizes = [
- adata.obs[f"initial_size_{layer}"]
- for layer in {"spliced", "unspliced"}
- if f"initial_size_{layer}" in adata.obs.keys()
- ]
- return np.sum(sizes, axis=0)
- elif layer in adata.layers.keys():
- return (
- np.array(adata.obs[f"initial_size_{layer}"])
- if f"initial_size_{layer}" in adata.obs.keys()
- else get_size(adata, layer)
- )
- elif layer is None or layer == "X":
- return (
- np.array(adata.obs["initial_size"])
- if "initial_size" in adata.obs.keys()
- else get_size(adata)
- )
- else:
- return None
+ warnings.warn(
+ "`scvelo.preprocessing.get_initial_size` is deprecated since scVelo v0.2.4 and "
+ "will be removed in a future version. Please use "
+ "`scvelo.core.get_initial_size` instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+
+ return _get_initial_size(adata=adata, layer=layer, by_total_size=by_total_size)
def _filter(X, min_counts=None, min_cells=None, max_counts=None, max_cells=None):
counts = (
- sum_obs(X)
+ sum(X, axis=0)
if (min_counts is not None or max_counts is not None)
- else sum_obs(X > 0)
+ else sum(X > 0, axis=0)
)
lb = (
min_counts
@@ -241,10 +193,11 @@ def filter_genes(
-------
Filters the object and adds `n_counts` to `adata.var`.
"""
+
adata = data.copy() if copy else data
# set initial cell sizes before filtering
- set_initial_size(adata)
+ _set_initial_size(adata)
layers = [
layer for layer in ["spliced", "unspliced"] if layer in adata.layers.keys()
@@ -429,8 +382,9 @@ def filter_genes_dispersion(
If an AnnData `adata` is passed, returns or updates `adata` depending on \
`copy`. It filters the `adata` and adds the annotations
"""
+
adata = data.copy() if copy else data
- set_initial_size(adata)
+ _set_initial_size(adata)
mean, var = materialize_as_ndarray(get_mean_var(adata.X))
@@ -558,13 +512,13 @@ def csr_vcorrcoef(X, y):
def counts_per_cell_quantile(X, max_proportion_per_cell=0.05, counts_per_cell=None):
if counts_per_cell is None:
- counts_per_cell = sum_var(X)
+ counts_per_cell = sum(X, axis=1)
gene_subset = np.all(
X <= counts_per_cell[:, None] * max_proportion_per_cell, axis=0
)
if issparse(X):
gene_subset = gene_subset.A1
- return sum_var(X[:, gene_subset])
+ return sum(X[:, gene_subset], axis=1)
def not_yet_normalized(X):
@@ -621,6 +575,7 @@ def normalize_per_cell(
-------
Returns or updates `adata` with normalized counts.
"""
+
adata = data.copy() if copy else data
if layers is None:
layers = ["spliced", "unspliced"]
@@ -633,7 +588,7 @@ def normalize_per_cell(
if isinstance(counts_per_cell, str):
if counts_per_cell not in adata.obs.keys():
- set_initial_size(adata, layers)
+ _set_initial_size(adata, layers)
counts_per_cell = (
adata.obs[counts_per_cell].values
if counts_per_cell in adata.obs.keys()
@@ -648,9 +603,9 @@ def normalize_per_cell(
counts = (
counts_per_cell
if counts_per_cell is not None
- else get_initial_size(adata, layer)
+ else _get_initial_size(adata, layer)
if use_initial_size
- else get_size(adata, layer)
+ else _get_size(adata, layer)
)
if max_proportion_per_cell is not None and (
0 < max_proportion_per_cell < 1
@@ -681,7 +636,7 @@ def normalize_per_cell(
adata.var["gene_count_corr"] = np.round(
csr_vcorrcoef(X.T, np.ravel((X > 0).sum(1))), 4
)
- except:
+ except Exception:
pass
else:
logg.warn(
@@ -689,7 +644,7 @@ def normalize_per_cell(
"To enforce normalization, set `enforce=True`."
)
- adata.obs["n_counts" if key_n_counts is None else key_n_counts] = get_size(adata)
+ adata.obs["n_counts" if key_n_counts is None else key_n_counts] = _get_size(adata)
if len(modified_layers) > 0:
logg.info("Normalized count data:", f"{', '.join(modified_layers)}.")
@@ -705,10 +660,12 @@ def log1p(data, copy=False):
Annotated data matrix.
copy: `bool` (default: `False`)
Return a copy of `adata` instead of updating it.
+
Returns
-------
Returns or updates `adata` depending on `copy`.
"""
+
adata = data.copy() if copy else data
X = (
(adata.X.data if issparse(adata.X) else adata.X)
@@ -793,6 +750,7 @@ def filter_and_normalize(
-------
Returns or updates `adata` depending on `copy`.
"""
+
adata = data.copy() if copy else data
if "spliced" not in adata.layers.keys() or "unspliced" not in adata.layers.keys():
diff --git a/scvelo/read_load.py b/scvelo/read_load.py
index 53b615e1..a883321e 100644
--- a/scvelo/read_load.py
+++ b/scvelo/read_load.py
@@ -1,15 +1,16 @@
-from . import logging as logg
-from .preprocessing.utils import set_initial_size
+import os
+import warnings
+from pathlib import Path
+from urllib.request import urlretrieve
-import os, re
import numpy as np
import pandas as pd
-from pandas.api.types import is_categorical_dtype
-from urllib.request import urlretrieve
-from pathlib import Path
-from scipy.sparse import issparse
-from anndata import AnnData
-from scanpy import read, read_loom
+
+from scvelo.core import clean_obs_names as _clean_obs_names
+from scvelo.core import get_df as _get_df
+from scvelo.core import merge as _merge
+from scvelo.core._anndata import obs_df as _obs_df
+from scvelo.core._anndata import var_df as _var_df
def load(filename, backup_url=None, header="infer", index_col="infer", **kwargs):
@@ -50,7 +51,8 @@ def load(filename, backup_url=None, header="infer", index_col="infer", **kwargs)
else:
raise ValueError(
f"'{filename}' does not end on a valid extension.\n"
- f"Please, provide one of the available extensions.\n{numpy_ext | pandas_ext}\n"
+ "Please, provide one of the available extensions.\n"
+ f"{numpy_ext | pandas_ext}\n"
)
@@ -58,194 +60,50 @@ def load(filename, backup_url=None, header="infer", index_col="infer", **kwargs)
def clean_obs_names(data, base="[AGTCBDHKMNRSVWY]", ID_length=12, copy=False):
- """Clean up the obs_names.
-
- For example an obs_name 'sample1_AGTCdate' is changed to 'AGTC' of the sample
- 'sample1_date'. The sample name is then saved in obs['sample_batch'].
- The genetic codes are identified according to according to
- https://www.neb.com/tools-and-resources/usage-guidelines/the-genetic-code.
-
- Arguments
- ---------
- adata: :class:`~anndata.AnnData`
- Annotated data matrix.
- base: `str` (default: `[AGTCBDHKMNRSVWY]`)
- Genetic code letters to be identified.
- ID_length: `int` (default: 12)
- Length of the Genetic Codes in the samples.
- copy: `bool` (default: `False`)
- Return a copy instead of writing to adata.
-
- Returns
- -------
- Returns or updates `adata` with the attributes
- obs_names: list
- updated names of the observations
- sample_batch: `.obs`
- names of the identified sample batches
- """
-
- def get_base_list(name, base):
- base_list = base
- while re.search(base_list + base, name) is not None:
- base_list += base
- if len(base_list) == 0:
- raise ValueError("Encountered an invalid ID in obs_names: ", name)
- return base_list
-
- adata = data.copy() if copy else data
-
- names = adata.obs_names
- base_list = get_base_list(names[0], base)
-
- if len(np.unique([len(name) for name in adata.obs_names])) == 1:
- start, end = re.search(base_list, names[0]).span()
- newIDs = [name[start:end] for name in names]
- start, end = 0, len(newIDs[0])
- for i in range(end - ID_length):
- if np.any([ID[i] not in base for ID in newIDs]):
- start += 1
- if np.any([ID[::-1][i] not in base for ID in newIDs]):
- end -= 1
-
- newIDs = [ID[start:end] for ID in newIDs]
- prefixes = [names[i].replace(newIDs[i], "") for i in range(len(names))]
- else:
- prefixes, newIDs = [], []
- for name in names:
- match = re.search(base_list, name)
- newID = (
- re.search(get_base_list(name, base), name).group()
- if match is None
- else match.group()
- )
- newIDs.append(newID)
- prefixes.append(name.replace(newID, ""))
-
- adata.obs_names = newIDs
- if len(prefixes[0]) > 0 and len(np.unique(prefixes)) > 1:
- adata.obs["sample_batch"] = (
- pd.Categorical(prefixes)
- if len(np.unique(prefixes)) < adata.n_obs
- else prefixes
- )
+ warnings.warn(
+ "`scvelo.read_load.clean_obs_names` is deprecated since scVelo v0.2.4 and will "
+ "be removed in a future version. Please use `scvelo.core.clean_obs_names` "
+ "instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
- adata.obs_names_make_unique()
- return adata if copy else None
+ return _clean_obs_names(data=data, base=base, ID_length=ID_length, copy=copy)
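# --- Editor's example (not part of the diff) ---------------------------------
# A hedged sketch of the relocated function, assuming scvelo.core.clean_obs_names
# keeps the behavior documented above (barcode extracted into obs_names, the
# remaining sample prefix into obs["sample_batch"]):
import numpy as np
from anndata import AnnData

from scvelo.core import clean_obs_names

adata = AnnData(np.zeros((2, 2)))
adata.obs_names = ["sample1_AGTCAGTCAGTC-1", "sample2_TTTCAGTCAGTC-1"]
clean_obs_names(adata)
# adata.obs_names are now the 12-mer barcodes; the per-cell prefixes are
# stored in adata.obs["sample_batch"]
# ------------------------------------------------------------------------------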
def merge(adata, ldata, copy=True):
- """Merges two annotated data matrices.
-
- Arguments
- ---------
- adata: :class:`~anndata.AnnData`
- Annotated data matrix (reference data set).
- ldata: :class:`~anndata.AnnData`
- Annotated data matrix (to be merged into adata).
-
- Returns
- -------
- Returns a :class:`~anndata.AnnData` object
- """
- adata.var_names_make_unique()
- ldata.var_names_make_unique()
-
- if (
- "spliced" in ldata.layers.keys()
- and "initial_size_spliced" not in ldata.obs.keys()
- ):
- set_initial_size(ldata)
- elif (
- "spliced" in adata.layers.keys()
- and "initial_size_spliced" not in adata.obs.keys()
- ):
- set_initial_size(adata)
-
- common_obs = pd.unique(adata.obs_names.intersection(ldata.obs_names))
- common_vars = pd.unique(adata.var_names.intersection(ldata.var_names))
-
- if len(common_obs) == 0:
- clean_obs_names(adata)
- clean_obs_names(ldata)
- common_obs = adata.obs_names.intersection(ldata.obs_names)
-
- if copy:
- _adata = adata[common_obs].copy()
- _ldata = ldata[common_obs].copy()
- else:
- adata._inplace_subset_obs(common_obs)
- _adata, _ldata = adata, ldata[common_obs].copy()
-
- _adata.var_names_make_unique()
- _ldata.var_names_make_unique()
-
- same_vars = len(_adata.var_names) == len(_ldata.var_names) and np.all(
- _adata.var_names == _ldata.var_names
+ warnings.warn(
+ "`scvelo.read_load.merge` is deprecated since scVelo v0.2.4 and will be "
+ "removed in a future version. Please use `scvelo.core.merge` instead.",
+ DeprecationWarning,
+ stacklevel=2,
)
- join_vars = len(common_vars) > 0
-
- if join_vars and not same_vars:
- _adata._inplace_subset_var(common_vars)
- _ldata._inplace_subset_var(common_vars)
-
- for attr in _ldata.obs.keys():
- if attr not in _adata.obs.keys():
- _adata.obs[attr] = _ldata.obs[attr]
- for attr in _ldata.obsm.keys():
- if attr not in _adata.obsm.keys():
- _adata.obsm[attr] = _ldata.obsm[attr]
- for attr in _ldata.uns.keys():
- if attr not in _adata.uns.keys():
- _adata.uns[attr] = _ldata.uns[attr]
- if join_vars:
- for attr in _ldata.layers.keys():
- if attr not in _adata.layers.keys():
- _adata.layers[attr] = _ldata.layers[attr]
-
- if _adata.shape[1] == _ldata.shape[1]:
- same_vars = len(_adata.var_names) == len(_ldata.var_names) and np.all(
- _adata.var_names == _ldata.var_names
- )
- if same_vars:
- for attr in _ldata.var.keys():
- if attr not in _adata.var.keys():
- _adata.var[attr] = _ldata.var[attr]
- for attr in _ldata.varm.keys():
- if attr not in _adata.varm.keys():
- _adata.varm[attr] = _ldata.varm[attr]
- else:
- raise ValueError("Variable names are not identical.")
- return _adata if copy else None
+ return _merge(adata=adata, ldata=ldata, copy=copy)
def obs_df(adata, keys, layer=None):
- lookup_keys = [k for k in keys if k in adata.var_names]
- if len(lookup_keys) < len(keys):
- logg.warn(
- f"Keys {[k for k in keys if k not in adata.var_names]} "
- f"were not found in `adata.var_names`."
- )
+ warnings.warn(
+ "`scvelo.read_load.obs_df` is deprecated since scVelo v0.2.4 and will be "
+ "removed in a future version. Please use `scvelo.core._anndata.obs_df` "
+ "instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
- df = pd.DataFrame(index=adata.obs_names)
- for l in lookup_keys:
- df[l] = adata.obs_vector(l, layer=layer)
- return df
+ return _obs_df(adata=adata, keys=keys, layer=layer)
def var_df(adata, keys, layer=None):
- lookup_keys = [k for k in keys if k in adata.obs_names]
- if len(lookup_keys) < len(keys):
- logg.warn(
- f"Keys {[k for k in keys if k not in adata.obs_names]} "
- f"were not found in `adata.obs_names`."
- )
+ warnings.warn(
+ "`scvelo.read_load.var_df` is deprecated since scVelo v0.2.4 and will be "
+ "removed in a future version. Please use `scvelo.core._anndata.var_df` "
+ "instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
- df = pd.DataFrame(index=adata.var_names)
- for l in lookup_keys:
- df[l] = adata.var_vector(l, layer=layer)
- return df
+ return _var_df(adata=adata, keys=keys, layer=layer)
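# --- Editor's note (not part of the diff) -------------------------------------
# obs_df pulls gene-wise vectors (keys matched against var_names) into a
# cell-indexed DataFrame; var_df is the transposed case (keys matched against
# obs_names). A hedged usage sketch, assuming the core variants keep these
# semantics; `adata`, the gene name, and the cell name are hypothetical:
from scvelo.core._anndata import obs_df, var_df

# df_cells = obs_df(adata, keys=["Actb"], layer="Ms")  # cells x selected genes
# df_genes = var_df(adata, keys=["cell_42"])           # genes x selected cells
# ------------------------------------------------------------------------------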
def get_df(
@@ -258,161 +116,23 @@ def get_df(
dropna="all",
precision=None,
):
- """Get dataframe for a specified adata key.
-
- Return values for specified key
- (in obs, var, obsm, varm, obsp, varp, uns, or layers) as a dataframe.
-
- Arguments
- ------
- adata
- AnnData object or a numpy array to get values from.
- keys
- Keys from `.var_names`, `.obs_names`, `.var`, `.obs`,
- `.obsm`, `.varm`, `.obsp`, `.varp`, `.uns`, or `.layers`.
- layer
- Layer of `adata` to use as expression values.
- index
- List to set as index.
- columns
- List to set as columns names.
- sort_values
- Wether to sort values by first column (sort_values=True) or a specified column.
- dropna
- Drop columns/rows that contain NaNs in all ('all') or in any entry ('any').
- precision
- Set precision for pandas dataframe.
-
- Returns
- -------
- A dataframe.
- """
- if precision is not None:
- pd.set_option("precision", precision)
-
- if isinstance(data, AnnData):
- keys, keys_split = (
- keys.split("*") if isinstance(keys, str) and "*" in keys else (keys, None)
- )
- keys, key_add = (
- keys.split("/") if isinstance(keys, str) and "/" in keys else (keys, None)
- )
- keys = [keys] if isinstance(keys, str) else keys
- key = keys[0]
-
- s_keys = ["obs", "var", "obsm", "varm", "uns", "layers"]
- d_keys = [
- data.obs.keys(),
- data.var.keys(),
- data.obsm.keys(),
- data.varm.keys(),
- data.uns.keys(),
- data.layers.keys(),
- ]
-
- if hasattr(data, "obsp") and hasattr(data, "varp"):
- s_keys.extend(["obsp", "varp"])
- d_keys.extend([data.obsp.keys(), data.varp.keys()])
-
- if keys is None:
- df = data.to_df()
- elif key in data.var_names:
- df = obs_df(data, keys, layer=layer)
- elif key in data.obs_names:
- df = var_df(data, keys, layer=layer)
- else:
- if keys_split is not None:
- keys = [
- k
- for k in list(data.obs.keys()) + list(data.var.keys())
- if key in k and keys_split in k
- ]
- key = keys[0]
- s_key = [s for (s, d_key) in zip(s_keys, d_keys) if key in d_key]
- if len(s_key) == 0:
- raise ValueError(f"'{key}' not found in any of {', '.join(s_keys)}.")
- if len(s_key) > 1:
- logg.warn(f"'{key}' found multiple times in {', '.join(s_key)}.")
-
- s_key = s_key[-1]
- df = getattr(data, s_key)[keys if len(keys) > 1 else key]
- if key_add is not None:
- df = df[key_add]
- if index is None:
- index = (
- data.var_names
- if s_key == "varm"
- else data.obs_names
- if s_key in {"obsm", "layers"}
- else None
- )
- if index is None and s_key == "uns" and hasattr(df, "shape"):
- key_cats = np.array(
- [
- key
- for key in data.obs.keys()
- if is_categorical_dtype(data.obs[key])
- ]
- )
- num_cats = [
- len(data.obs[key].cat.categories) == df.shape[0]
- for key in key_cats
- ]
- if np.sum(num_cats) == 1:
- index = data.obs[key_cats[num_cats][0]].cat.categories
- if (
- columns is None
- and len(df.shape) > 1
- and df.shape[0] == df.shape[1]
- ):
- columns = index
- elif isinstance(index, str) and index in data.obs.keys():
- index = pd.Categorical(data.obs[index]).categories
- if columns is None and s_key == "layers":
- columns = data.var_names
- elif isinstance(columns, str) and columns in data.obs.keys():
- columns = pd.Categorical(data.obs[columns]).categories
- elif isinstance(data, pd.DataFrame):
- if isinstance(keys, str) and "*" in keys:
- keys, keys_split = keys.split("*")
- keys = [k for k in data.columns if keys in k and keys_split in k]
- df = data[keys] if keys is not None else data
- else:
- df = data
-
- if issparse(df):
- df = np.array(df.A)
- if columns is None and hasattr(df, "names"):
- columns = df.names
-
- df = pd.DataFrame(df, index=index, columns=columns)
-
- if dropna:
- df.replace("", np.nan, inplace=True)
- how = dropna if isinstance(dropna, str) else "any" if dropna is True else "all"
- df.dropna(how=how, axis=0, inplace=True)
- df.dropna(how=how, axis=1, inplace=True)
-
- if sort_values:
- sort_by = (
- sort_values
- if isinstance(sort_values, str) and sort_values in df.columns
- else df.columns[0]
- )
- df = df.sort_values(by=sort_by, ascending=False)
-
- if hasattr(data, "var_names"):
- if df.index[0] in data.var_names:
- df.var_names = df.index
- elif df.columns[0] in data.var_names:
- df.var_names = df.columns
- if hasattr(data, "obs_names"):
- if df.index[0] in data.obs_names:
- df.obs_names = df.index
- elif df.columns[0] in data.obs_names:
- df.obs_names = df.columns
+ warnings.warn(
+ "`scvelo.read_load.get_df` is deprecated since scVelo v0.2.4 and will be "
+ "removed in a future version. Please use `scvelo.core.get_df` instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
- return df
+ return _get_df(
+ data=data,
+ keys=keys,
+ layer=layer,
+ index=index,
+ columns=columns,
+ sort_values=sort_values,
+ dropna=dropna,
+ precision=precision,
+ )
DataFrame = get_df
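# --- Editor's example (not part of the diff) ---------------------------------
# A hedged sketch of the relocated scvelo.core.get_df, assuming the signature
# shown above; the AnnData object and the "score" column are made up here:
import numpy as np
from anndata import AnnData

from scvelo.core import get_df

adata = AnnData(np.random.rand(5, 3))
adata.obs["score"] = np.arange(5.0)
df = get_df(adata, keys="score")  # one-column DataFrame indexed by obs_names
# ------------------------------------------------------------------------------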
@@ -435,6 +155,7 @@ def load_biomart():
df2.index = df2.pop("ensembl")
df = pd.concat([df, df2])
+ df = df.drop_duplicates()
return df
@@ -443,7 +164,7 @@ def convert_to_gene_names(ensembl_names=None):
df = load_biomart()
if ensembl_names is not None:
if isinstance(ensembl_names, str):
- ensembl_names = ensembl_names
+ ensembl_names = [ensembl_names]
valid_names = [name for name in ensembl_names if name in df.index]
if len(valid_names) > 0:
df = df.loc[valid_names]
diff --git a/scvelo/settings.py b/scvelo/settings.py
index e56621f4..28097145 100644
--- a/scvelo/settings.py
+++ b/scvelo/settings.py
@@ -1,3 +1,11 @@
+import builtins
+import warnings
+
+from cycler import cycler
+from packaging.version import parse
+
+from matplotlib import cbook, cm, colors, rcParams
+
"""Settings
"""
@@ -80,10 +88,6 @@
# Functions
# --------------------------------------------------------------------------------
-from matplotlib import rcParams, cm, colors, cbook
-from cycler import cycler
-import warnings
-
warnings.filterwarnings("ignore", category=cbook.mplDeprecation)
@@ -292,15 +296,8 @@ def set_figure_params(
Only concerns the notebook/IPython environment; see
`IPython.core.display.set_matplotlib_formats` for more details.
"""
- try:
- import IPython
-
- if isinstance(ipython_format, str):
- ipython_format = [ipython_format]
- IPython.display.set_matplotlib_formats(*ipython_format)
- except:
- pass
-
+ if ipython_format is not None:
+ _set_ipython(ipython_format)
global _rcParams_style
_rcParams_style = style
global _vector_friendly
@@ -332,6 +329,29 @@ def set_rcParams_defaults():
rcParams.update(rcParamsDefault)
+def _set_ipython(ipython_format="png2x"):
+ if getattr(builtins, "__IPYTHON__", None):
+ try:
+ import IPython
+
+ if isinstance(ipython_format, str):
+ ipython_format = [ipython_format]
+ if parse(IPython.__version__) < parse("7.23"):
+ IPython.display.set_matplotlib_formats(*ipython_format)
+ else:
+ from matplotlib_inline.backend_inline import set_matplotlib_formats
+
+ set_matplotlib_formats(*ipython_format)
+ except ImportError:
+ pass
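# --- Editor's note (not part of the diff) -------------------------------------
# IPython deprecated `IPython.display.set_matplotlib_formats` in 7.23 in favor
# of matplotlib-inline, which the version gate above accounts for. The same
# selection in isolation (matplotlib-inline is assumed to ship alongside
# IPython >= 7.23):
import IPython
from packaging.version import parse

if parse(IPython.__version__) < parse("7.23"):
    set_fmt = IPython.display.set_matplotlib_formats
else:
    from matplotlib_inline.backend_inline import set_matplotlib_formats as set_fmt

set_fmt("png2x")  # e.g. retina-quality inline figures
# ------------------------------------------------------------------------------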
+
+
+def _set_start_time():
+ from time import time
+
+ return time()
+
+
# ------------------------------------------------------------------------------
# Private global variables & functions
# ------------------------------------------------------------------------------
@@ -343,13 +363,6 @@ def set_rcParams_defaults():
_low_resolution_warning = True
"""Print warning when saving a figure with low resolution."""
-
-def _set_start_time():
- from time import time
-
- return time()
-
-
_start = _set_start_time()
"""Time when the settings module is first imported."""
diff --git a/scvelo/tl.py b/scvelo/tl.py
index 1c69ddff..5d71ecb5 100644
--- a/scvelo/tl.py
+++ b/scvelo/tl.py
@@ -1 +1 @@
-from scvelo.tools import *
+from scvelo.tools import * # noqa
diff --git a/scvelo/tools/__init__.py b/scvelo/tools/__init__.py
index 21d7091a..f46a862a 100644
--- a/scvelo/tools/__init__.py
+++ b/scvelo/tools/__init__.py
@@ -1,14 +1,51 @@
-from .velocity import velocity, velocity_genes
-from .velocity_graph import velocity_graph
+from scanpy.tools import diffmap, dpt, louvain, tsne, umap
+
+from .dynamical_model import (
+ align_dynamics,
+ differential_kinetic_test,
+ DynamicsRecovery,
+ latent_time,
+ rank_dynamical_genes,
+ recover_dynamics,
+ recover_latent_time,
+)
+from .paga import paga
+from .rank_velocity_genes import rank_velocity_genes, velocity_clusters
+from .score_genes_cell_cycle import score_genes_cell_cycle
+from .terminal_states import eigs, terminal_states
from .transition_matrix import transition_matrix
-from .velocity_embedding import velocity_embedding
+from .velocity import velocity, velocity_genes
from .velocity_confidence import velocity_confidence, velocity_confidence_transition
-from .terminal_states import eigs, terminal_states
-from .rank_velocity_genes import velocity_clusters, rank_velocity_genes
+from .velocity_embedding import velocity_embedding
+from .velocity_graph import velocity_graph
from .velocity_pseudotime import velocity_map, velocity_pseudotime
-from .dynamical_model import DynamicsRecovery, recover_dynamics, align_dynamics
-from .dynamical_model import recover_latent_time, latent_time
-from .dynamical_model import differential_kinetic_test, rank_dynamical_genes
-from scanpy.tools import tsne, umap, diffmap, dpt, louvain
-from .score_genes_cell_cycle import score_genes_cell_cycle
-from .paga import paga
+
+__all__ = [
+ "align_dynamics",
+ "differential_kinetic_test",
+ "diffmap",
+ "dpt",
+ "DynamicsRecovery",
+ "eigs",
+ "latent_time",
+ "louvain",
+ "paga",
+ "rank_dynamical_genes",
+ "rank_velocity_genes",
+ "recover_dynamics",
+ "recover_latent_time",
+ "score_genes_cell_cycle",
+ "terminal_states",
+ "transition_matrix",
+ "tsne",
+ "umap",
+ "velocity",
+ "velocity_clusters",
+ "velocity_confidence",
+ "velocity_confidence_transition",
+ "velocity_embedding",
+ "velocity_genes",
+ "velocity_graph",
+ "velocity_map",
+ "velocity_pseudotime",
+]
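# --- Editor's example (not part of the diff) ---------------------------------
# With __all__ defined, `from scvelo.tools import *` (as in scvelo/tl.py) only
# re-exports the listed names. A self-contained sketch of that contract using
# a synthetic module:
import sys
import types

demo = types.ModuleType("demo")
demo.__all__ = ["foo"]
demo.foo = lambda: "exported"
demo.bar = lambda: "hidden"
sys.modules["demo"] = demo

ns = {}
exec("from demo import *", ns)
assert "foo" in ns and "bar" not in ns  # __all__ controls what `*` exports
# ------------------------------------------------------------------------------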
diff --git a/scvelo/tools/dynamical_model.py b/scvelo/tools/dynamical_model.py
index b44aa7b7..cf24fe06 100644
--- a/scvelo/tools/dynamical_model.py
+++ b/scvelo/tools/dynamical_model.py
@@ -1,22 +1,18 @@
-from .. import settings
-from .. import logging as logg
-from ..preprocessing.moments import get_connectivities
-from .utils import make_unique_list, test_bimodality
-from .dynamical_model_utils import BaseDynamics, linreg, convolve, tau_inv, unspliced
-
-from typing import Any, Union, Callable, Optional, Sequence
-from threading import Thread
-from multiprocessing import Manager
-
-from joblib import Parallel, delayed
-from scipy.sparse import issparse, spmatrix
-
import os
+
import numpy as np
import pandas as pd
+from scipy.optimize import minimize
+
import matplotlib.pyplot as pl
from matplotlib import rcParams
-from scipy.optimize import minimize
+
+from scvelo import logging as logg
+from scvelo import settings
+from scvelo.core import get_n_jobs, parallelize
+from scvelo.preprocessing.moments import get_connectivities
+from .dynamical_model_utils import BaseDynamics, convolve, linreg, tau_inv, unspliced
+from .utils import make_unique_list, test_bimodality
class DynamicsRecovery(BaseDynamics):
@@ -67,11 +63,11 @@ def initialize(self):
# initialize switching from u quantiles and alpha from s quantiles
try:
- tstat_u, pval_u, means_u = test_bimodality(u_w, kde=True)
- tstat_s, pval_s, means_s = test_bimodality(s_w, kde=True)
- except:
+ _, pval_u, means_u = test_bimodality(u_w, kde=True)
+ _, pval_s, means_s = test_bimodality(s_w, kde=True)
+ except Exception:
logg.warn("skipping bimodality check for", self.gene)
- tstat_u, tstat_s, pval_u, pval_s = 0, 0, 1, 1
+ _, _, pval_u, pval_s = 0, 0, 1, 1
means_u, means_s = [0, 0], [0, 0]
self.pval_steady = max(pval_u, pval_s)
@@ -376,8 +372,10 @@ def recover_dynamics(
---------
data: :class:`~anndata.AnnData`
Annotated data matrix.
- var_names: `str`, list of `str` (default: `'velocity_genes`)
- Names of variables/genes to use for the fitting.
+ var_names: `str`, list of `str` (default: `'velocity_genes'`)
+ Names of variables/genes to use for the fitting. If `var_names='velocity_genes'`
+ but there is no column `'velocity_genes'` in `adata.var`, velocity genes are
+ estimated using the steady state model.
n_top_genes: `int` or `None` (default: `None`)
Number of top velocity genes to use for the dynamical model.
max_iter:`int` (default: `10`)
@@ -415,12 +413,27 @@ def recover_dynamics(
n_jobs: `int` or `None` (default: `None`)
Number of parallel jobs.
backend: `str` (default: "loky")
- Backend used for multiprocessing. See :class:`joblib.Parallel` for valid options.
+ Backend used for multiprocessing. See :class:`joblib.Parallel` for valid
+ options.
Returns
-------
- Returns or updates `adata`
- """
+ fit_alpha: `.var`
+ inferred transcription rates
+ fit_beta: `.var`
+ inferred splicing rates
+ fit_gamma: `.var`
+ inferred degradation rates
+ fit_t_: `.var`
+ inferred switching time points
+ fit_scaling: `.var`
+ internal variance scaling factor for un/spliced counts
+ fit_likelihood: `.var`
+ likelihood of model fit
+ fit_alignment_scaling: `.var`
+ scaling factor to align gene-wise latent times to a universal latent time
+ """ # noqa E501
+
adata = data.copy() if copy else data
n_jobs = get_n_jobs(n_jobs=n_jobs)
@@ -485,7 +498,7 @@ def recover_dynamics(
conn = get_connectivities(adata) if fit_connected_states else None
- res = _parallelize(
+ res = parallelize(
_fit_recovery,
var_names,
n_jobs,
@@ -549,14 +562,17 @@ def recover_dynamics(
if L: # is False if only one invalid / irrecoverable gene was given in var_names
cur_len = adata.varm["loss"].shape[1] if "loss" in adata.varm.keys() else 2
- max_len = max(np.max([len(l) for l in L]), cur_len) if L else cur_len
+ max_len = max(np.max([len(loss) for loss in L]), cur_len) if L else cur_len
loss = np.ones((adata.n_vars, max_len)) * np.nan
if "loss" in adata.varm.keys():
loss[:, :cur_len] = adata.varm["loss"]
loss[idx] = np.vstack(
- [np.concatenate([l, np.ones(max_len - len(l)) * np.nan]) for l in L]
+ [
+ np.concatenate([loss, np.ones(max_len - len(loss)) * np.nan])
+ for loss in L
+ ]
)
adata.varm["loss"] = loss
@@ -617,9 +633,9 @@ def align_dynamics(
Whether to remove outliers.
copy: `bool` (default: `False`)
Return a copy instead of writing to `adata`.
- Returns
+
+ Returns
-------
- Returns or updates `adata` with the attributes
alpha, beta, gamma, t_, alignment_scaling: `.var`
aligned parameters
fit_t, fit_tau, fit_tau_: `.layer`
@@ -748,17 +764,18 @@ def latent_time(
If not set, an overall transcriptional timescale of 20 hours is used as a prior.
copy: `bool` (default: `False`)
Return a copy instead of writing to `adata`.
- Returns
+
+ Returns
-------
- Returns or updates `adata` with the attributes
latent_time: `.obs`
latent time from learned dynamics for each cell
- """
+ """ # noqa E501
+
adata = data.copy() if copy else data
- from .utils import vcorrcoef, scale
- from .dynamical_model_utils import root_time, compute_shared_time
+ from .dynamical_model_utils import compute_shared_time, root_time
from .terminal_states import terminal_states
+ from .utils import scale, vcorrcoef
from .velocity_graph import velocity_graph
from .velocity_pseudotime import velocity_pseudotime
@@ -793,7 +810,7 @@ def latent_time(
idx_roots[pd.isnull(idx_roots)] = 0
if np.any([isinstance(ix, str) for ix in idx_roots]):
idx_roots = np.array([isinstance(ix, str) for ix in idx_roots], dtype=int)
- idx_roots = idx_roots.astype(np.float) > 1 - 1e-3
+ idx_roots = idx_roots.astype(float) > 1 - 1e-3
if np.sum(idx_roots) > 0:
roots = roots[idx_roots]
else:
@@ -811,7 +828,7 @@ def latent_time(
idx_fates[pd.isnull(idx_fates)] = 0
if np.any([isinstance(ix, str) for ix in idx_fates]):
idx_fates = np.array([isinstance(ix, str) for ix in idx_fates], dtype=int)
- idx_fates = idx_fates.astype(np.float) > 1 - 1e-3
+ idx_fates = idx_fates.astype(float) > 1 - 1e-3
if np.sum(idx_fates) > 0:
fates = fates[idx_fates]
else:
@@ -919,8 +936,12 @@ def differential_kinetic_test(
Returns
-------
- Returns or updates `adata`
- """
+ fit_pvals_kinetics: `.varm`
+ P-values of competing kinetics for each group and gene.
+ fit_diff_kinetics: `.var`
+ Groups that have differential kinetics for each gene.
+ """ # noqa E501
+
adata = data.copy() if copy else data
if "Ms" not in adata.layers.keys() or "Mu" not in adata.layers.keys():
@@ -981,7 +1002,7 @@ def differential_kinetic_test(
if "fit_diff_kinetics" in adata.var.keys():
diff_kinetics = np.array(adata.var["fit_diff_kinetics"])
else:
- diff_kinetics = np.empty(adata.n_vars, dtype="|U16")
+ diff_kinetics = np.empty(adata.n_vars, dtype="object")
idx = []
progress = logg.ProgressReporter(len(var_names))
@@ -1014,7 +1035,7 @@ def differential_kinetic_test(
"added \n"
f" '{add_key}_diff_kinetics', "
f"clusters displaying differential kinetics (adata.var)\n"
- f" '{add_key}_pval_kinetics', "
+ f" '{add_key}_pvals_kinetics', "
f"p-values of differential kinetics (adata.var)"
)
@@ -1043,11 +1064,11 @@ def rank_dynamical_genes(data, n_genes=100, groupby=None, copy=False):
Returns
-------
- Returns or updates `data` with the attributes
rank_dynamical_genes : `.uns`
Structured array to be indexed by group id storing the gene
names. Ordered according to scores.
"""
+
from .dynamical_model_utils import get_divergence
adata = data.copy() if copy else data
@@ -1152,166 +1173,5 @@ def _fit_recovery(
return idx, dms
-# -*- coding: utf-8 -*-
-"""Module used to parallelize model fitting."""
-
-_msg_shown = False
-
-
-def get_n_jobs(n_jobs):
- if n_jobs is None or (n_jobs < 0 and os.cpu_count() + 1 + n_jobs <= 0):
- return 1
- elif n_jobs > os.cpu_count():
- return os.cpu_count()
- elif n_jobs < 0:
- return os.cpu_count() + 1 + n_jobs
- else:
- return n_jobs
-
-
-def _parallelize(
- callback: Callable[[Any], Any],
- collection: Union[spmatrix, Sequence[Any]],
- n_jobs: Optional[int] = None,
- n_split: Optional[int] = None,
- unit: str = "",
- as_array: bool = True,
- use_ixs: bool = False,
- backend: str = "loky",
- extractor: Optional[Callable[[Any], Any]] = None,
- show_progress_bar: bool = True,
-) -> Union[np.ndarray, Any]:
- """
- Parallelize function call over a collection of elements.
-
- Parameters
- ----------
- callback
- Function to parallelize.
- collection
- Sequence of items which to chunkify.
- n_jobs
- Number of parallel jobs.
- n_split
- Split :paramref:`collection` into :paramref:`n_split` chunks.
- If `None`, split into :paramref:`n_jobs` chunks.
- unit
- Unit of the progress bar.
- as_array
- Whether to convert the results not :class:`numpy.ndarray`.
- use_ixs
- Whether to pass indices to the callback.
- backend
- Which backend to use for multiprocessing. See :class:`joblib.Parallel` for valid options.
- extractor
- Function to apply to the result after all jobs have finished.
- show_progress_bar
- Whether to show a progress bar.
-
- Returns
- -------
- :class:`numpy.ndarray`
- Result depending on :paramref:`extractor` and :paramref:`as_array`.
- """
-
- if show_progress_bar:
- try:
- try:
- from tqdm.notebook import tqdm
- except ImportError:
- from tqdm import tqdm_notebook as tqdm
- import ipywidgets # noqa
- except ImportError:
- global _msg_shown
- tqdm = None
-
- if not _msg_shown:
- logg.warn(
- "Unable to create progress bar. "
- "Consider installing `tqdm` as `pip install tqdm` "
- "and `ipywidgets` as `pip install ipywidgets`,\n"
- "or disable the progress bar using `show_progress_bar=False`."
- )
- _msg_shown = True
- else:
- tqdm = None
-
- def update(pbar, queue, n_total):
- n_finished = 0
- while n_finished < n_total:
- try:
- res = queue.get()
- except EOFError as e:
- if not n_finished != n_total:
- raise RuntimeError(
- f"Finished only `{n_finished} out of `{n_total}` tasks.`"
- ) from e
- break
- assert res in (None, (1, None), 1) # (None, 1) means only 1 job
- if res == (1, None):
- n_finished += 1
- if pbar is not None:
- pbar.update()
- elif res is None:
- n_finished += 1
- elif pbar is not None:
- pbar.update()
-
- if pbar is not None:
- pbar.close()
-
- def wrapper(*args, **kwargs):
- if pass_queue and show_progress_bar:
- pbar = None if tqdm is None else tqdm(total=col_len, unit=unit)
- queue = Manager().Queue()
- thread = Thread(target=update, args=(pbar, queue, len(collections)))
- thread.start()
- else:
- pbar, queue, thread = None, None, None
-
- res = Parallel(n_jobs=n_jobs, backend=backend)(
- delayed(callback)(
- *((i, cs) if use_ixs else (cs,)),
- *args,
- **kwargs,
- queue=queue,
- )
- for i, cs in enumerate(collections)
- )
-
- res = np.array(res) if as_array else res
- if thread is not None:
- thread.join()
-
- return res if extractor is None else extractor(res)
-
- col_len = collection.shape[0] if issparse(collection) else len(collection)
-
- if n_split is None:
- n_split = get_n_jobs(n_jobs=n_jobs)
-
- if issparse(collection):
- if n_split == collection.shape[0]:
- collections = [collection[[ix], :] for ix in range(collection.shape[0])]
- else:
- step = collection.shape[0] // n_split
-
- ixs = [
- np.arange(i * step, min((i + 1) * step, collection.shape[0]))
- for i in range(n_split)
- ]
- ixs[-1] = np.append(
- ixs[-1], np.arange(ixs[-1][-1] + 1, collection.shape[0])
- )
-
- collections = [collection[ix, :] for ix in filter(len, ixs)]
- else:
- collections = list(filter(len, np.array_split(collection, n_split)))
-
- pass_queue = not hasattr(callback, "py_func") # we'd be inside a numba function
-
- return wrapper
-
-
def _flatten(iterable):
return [i for it in iterable for i in it]
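# --- Editor's note (not part of the diff) -------------------------------------
# The fitting code above now relies on scvelo.core.get_n_jobs/parallelize
# instead of the private copies deleted here. A hedged usage sketch, assuming
# the core versions keep the removed signatures (parallelize returns a wrapper
# that is then called with the callback's keyword arguments):
from scvelo.core import get_n_jobs, parallelize

n_jobs = get_n_jobs(n_jobs=None)  # None/negative values resolve to a valid count
# res = parallelize(callback, collection, n_jobs, unit="genes", as_array=False)(
#     **callback_kwargs
# )
# ------------------------------------------------------------------------------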
diff --git a/scvelo/tools/dynamical_model_utils.py b/scvelo/tools/dynamical_model_utils.py
index e1340a75..8d929d55 100644
--- a/scvelo/tools/dynamical_model_utils.py
+++ b/scvelo/tools/dynamical_model_utils.py
@@ -1,17 +1,19 @@
-from .. import logging as logg
-from ..preprocessing.moments import get_connectivities
-from .utils import make_dense, round
+import warnings
+
+import numpy as np
+import pandas as pd
+from scipy.sparse import issparse
+from scipy.stats.distributions import chi2, norm
import matplotlib as mpl
+import matplotlib.gridspec as gridspec
import matplotlib.pyplot as pl
from matplotlib import rcParams
-import matplotlib.gridspec as gridspec
-from scipy.stats.distributions import chi2, norm
-from scipy.sparse import issparse
-import warnings
-import pandas as pd
-import numpy as np
+from scvelo import logging as logg
+from scvelo.core import clipped_log, invert, SplicingDynamics
+from scvelo.preprocessing.moments import get_connectivities
+from .utils import make_dense, round
exp = np.exp
@@ -20,14 +22,28 @@
def log(x, eps=1e-6): # to avoid invalid values for log.
- return np.log(np.clip(x, eps, 1 - eps))
+
+ warnings.warn(
+ "`clipped_log` is deprecated since scVelo v0.2.4 and will be removed in a "
+ "future version. Please use `clipped_log(x, eps=1e-6)` from `scvelo/core/`"
+ "instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+
+ return clipped_log(x, lb=0, ub=1, eps=eps)
def inv(x):
- with warnings.catch_warnings():
- warnings.simplefilter("ignore")
- x_inv = 1 / x * (x != 0)
- return x_inv
+
+ warnings.warn(
+ "`inv` is deprecated since scVelo v0.2.4 and will be removed in a future "
+ "version. Please use `invert(x)` from `scvelo/core/` instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+
+ return invert(x)
def normalize(X, axis=0, min_confidence=None):
@@ -112,17 +128,23 @@ def unspliced(tau, u0, alpha, beta):
def spliced(tau, s0, u0, alpha, beta, gamma):
- c = (alpha - u0 * beta) * inv(gamma - beta)
+ c = (alpha - u0 * beta) * invert(gamma - beta)
expu, exps = exp(-beta * tau), exp(-gamma * tau)
return s0 * exps + alpha / gamma * (1 - exps) + c * (exps - expu)
def mRNA(tau, u0, s0, alpha, beta, gamma):
- expu, exps = exp(-beta * tau), exp(-gamma * tau)
- expus = (alpha - u0 * beta) * inv(gamma - beta) * (exps - expu)
- u = u0 * expu + alpha / beta * (1 - expu)
- s = s0 * exps + alpha / gamma * (1 - exps) + expus
- return u, s
+
+ warnings.warn(
+ "`mRNA` is deprecated since scVelo v0.2.4 and will be removed in a future "
+ "version. Please use `SplicingDynamics` from `scvelo/core/` instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+
+ return SplicingDynamics(alpha=alpha, beta=beta, gamma=gamma).get_solution(
+ tau, stacked=False
+ )
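# --- Editor's example (not part of the diff) ---------------------------------
# A hedged sketch of the new core API, assuming SplicingDynamics solves
# du/dt = alpha - beta*u, ds/dt = beta*u - gamma*s analytically, as the
# replaced mRNA() did:
import numpy as np

from scvelo.core import SplicingDynamics

tau = np.linspace(0, 10, 50)
dyn = SplicingDynamics(alpha=5, beta=0.5, gamma=0.3)
u, s = dyn.get_solution(tau, stacked=False)  # one array per species
x = dyn.get_solution(tau)  # stacked (len(tau), 2) array, as used above
# ------------------------------------------------------------------------------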
def adjust_increments(tau, tau_=None):
@@ -166,13 +188,15 @@ def tau_inv(u, s=None, u0=None, s0=None, alpha=None, beta=None, gamma=None):
any_invus = np.any(inv_us) and s is not None
if any_invus: # tau_inv(u, s)
- beta_ = beta * inv(gamma - beta)
+ beta_ = beta * invert(gamma - beta)
xinf = alpha / gamma - beta_ * (alpha / beta)
- tau = -1 / gamma * log((s - beta_ * u - xinf) / (s0 - beta_ * u0 - xinf))
+ tau = (
+ -1 / gamma * clipped_log((s - beta_ * u - xinf) / (s0 - beta_ * u0 - xinf))
+ )
if any_invu: # tau_inv(u)
uinf = alpha / beta
- tau_u = -1 / beta * log((u - uinf) / (u0 - uinf))
+ tau_u = -1 / beta * clipped_log((u - uinf) / (u0 - uinf))
tau = tau_u * inv_u + tau * inv_us if any_invus else tau_u
return tau
@@ -189,9 +213,10 @@ def assign_tau(
num = np.clip(int(len(u) / 5), 200, 500)
tpoints = np.linspace(0, t_, num=num)
tpoints_ = np.linspace(0, t0, num=num)[1:]
-
- xt = np.vstack(mRNA(tpoints, 0, 0, alpha, beta, gamma)).T
- xt_ = np.vstack(mRNA(tpoints_, u0_, s0_, 0, beta, gamma)).T
+ xt = SplicingDynamics(alpha=alpha, beta=beta, gamma=gamma).get_solution(tpoints)
+ xt_ = SplicingDynamics(
+ alpha=0, beta=beta, gamma=gamma, initial_state=[u0_, s0_]
+ ).get_solution(tpoints_)
# assign time points (oth. projection onto 'on' and 'off' curve)
tau = tpoints[
@@ -253,7 +278,9 @@ def compute_divergence(
"""
# set tau, tau_
if u0_ is None or s0_ is None:
- u0_, s0_ = mRNA(t_, 0, 0, alpha, beta, gamma)
+ u0_, s0_ = SplicingDynamics(alpha=alpha, beta=beta, gamma=gamma).get_solution(
+ t_, stacked=False
+ )
if tau is None or tau_ is None or t_ is None:
tau, tau_, t_ = assign_tau(
u, s, alpha, beta, gamma, t_, u0_, s0_, assignment_mode
@@ -263,8 +290,12 @@ def compute_divergence(
# adjust increments of tau, tau_ to avoid meaningless jumps
if constraint_time_increments:
- ut, st = mRNA(tau, 0, 0, alpha, beta, gamma)
- ut_, st_ = mRNA(tau_, u0_, s0_, 0, beta, gamma)
+ ut, st = SplicingDynamics(alpha=alpha, beta=beta, gamma=gamma).get_solution(
+ tau, stacked=False
+ )
+ ut_, st_ = SplicingDynamics(
+ alpha=0, beta=beta, gamma=gamma, initial_state=[u0_, s0_]
+ ).get_solution(tau_, stacked=False)
distu, distu_ = (u - ut) / std_u, (u - ut_) / std_u
dists, dists_ = (s - st) / std_s, (s - st_) / std_s
@@ -288,8 +319,12 @@ def compute_divergence(
tau_[off] = adjust_increments(tau_[off])
# compute induction/repression state distances
- ut, st = mRNA(tau, 0, 0, alpha, beta, gamma)
- ut_, st_ = mRNA(tau_, u0_, s0_, 0, beta, gamma)
+ ut, st = SplicingDynamics(alpha=alpha, beta=beta, gamma=gamma).get_solution(
+ tau, stacked=False
+ )
+ ut_, st_ = SplicingDynamics(
+ alpha=0, beta=beta, gamma=gamma, initial_state=[u0_, s0_]
+ ).get_solution(tau_, stacked=False)
if ut.ndim > 1 and ut.shape[1] == 1:
ut = np.ravel(ut)
@@ -542,7 +577,9 @@ def compute_divergence(
o = (res < 2) * (t < t_)
o_ = (res < 2) * (t >= t_)
tau, alpha, u0, s0 = vectorize(t, t_, alpha, beta, gamma)
- ut, st = mRNA(tau, u0, s0, alpha, beta, gamma)
+ ut, st = SplicingDynamics(
+ alpha=alpha, beta=beta, gamma=gamma, initial_state=[u0, s0]
+ ).get_solution(tau, stacked=False)
ut_, st_ = ut, st
ut = ut * o + ut_ * o_
@@ -578,7 +615,9 @@ def compute_divergence(
o = (res < 2) * (t < t_)
o_ = (res < 2) * (t >= t_)
tau, alpha, u0, s0 = vectorize(t, t_, alpha, beta, gamma)
- ut, st = mRNA(tau, u0, s0, alpha, beta, gamma)
+ ut, st = SplicingDynamics(alpha=alpha, beta=beta, gamma=gamma).get_solution(
+ tau, stacked=False
+ )
ut_, st_ = ut, st
alpha = alpha * o
@@ -644,7 +683,9 @@ def curve_dists(
num=None,
):
if u0_ is None or s0_ is None:
- u0_, s0_ = mRNA(t_, 0, 0, alpha, beta, gamma)
+ u0_, s0_ = SplicingDynamics(alpha=alpha, beta=beta, gamma=gamma).get_solution(
+ t_, stacked=False
+ )
x_obs = np.vstack([u, s]).T
std_x = np.vstack([std_u / scaling, std_s]).T
@@ -654,8 +695,12 @@ def curve_dists(
tpoints = np.linspace(0, t_, num=num)
tpoints_ = np.linspace(0, t0, num=num)[1:]
- curve_t = np.vstack(mRNA(tpoints, 0, 0, alpha, beta, gamma)).T
- curve_t_ = np.vstack(mRNA(tpoints_, u0_, s0_, 0, beta, gamma)).T
+ curve_t = SplicingDynamics(alpha=alpha, beta=beta, gamma=gamma).get_solution(
+ tpoints
+ )
+ curve_t_ = SplicingDynamics(
+ alpha=0, beta=beta, gamma=gamma, initial_state=[u0_, s0_]
+ ).get_solution(tpoints_)
# match each curve point to nearest observation
dist, dist_ = np.zeros(len(curve_t)), np.zeros(len(curve_t_))
@@ -728,7 +773,7 @@ def __init__(
self.recoverable = True
try:
self.initialize_weights()
- except:
+ except Exception:
self.recoverable = False
logg.warn(f"Model for {self.gene} could not be instantiated.")
@@ -796,7 +841,11 @@ def load_pars(self, adata, gene):
self.pval_steady = adata.var["fit_pval_steady"][idx]
self.alpha_ = 0
- self.u0_, self.s0_ = mRNA(self.t_, 0, 0, self.alpha, self.beta, self.gamma)
+ self.u0_, self.s0_ = SplicingDynamics(
+ alpha=self.alpha,
+ beta=self.beta,
+ gamma=self.gamma,
+ ).get_solution(self.t_, stacked=False)
self.pars = [self.alpha, self.beta, self.gamma, self.t_, self.scaling]
self.pars = np.array(self.pars)[:, None]
@@ -988,7 +1037,9 @@ def get_dists(
)
tau, alpha, u0, s0 = vectorize(t, t_, alpha, beta, gamma)
- ut, st = mRNA(tau, u0, s0, alpha, beta, gamma)
+ ut, st = SplicingDynamics(
+ alpha=alpha, beta=beta, gamma=gamma, initial_state=[u0, s0]
+ ).get_solution(tau, stacked=False)
udiff = np.array(ut - u) / self.std_u * scaling
sdiff = np.array(st - s) / self.std_s
@@ -1072,9 +1123,13 @@ def get_curve_likelihood(self):
kwargs = dict(std_u=self.std_u, std_s=self.std_s, scaling=scaling)
dist, dist_ = curve_dists(u, s, alpha, beta, gamma, t_, **kwargs)
- l = -0.5 / len(dist) * np.sum(dist) / varx - 0.5 * np.log(2 * np.pi * varx)
- l_ = -0.5 / len(dist_) * np.sum(dist_) / varx - 0.5 * np.log(2 * np.pi * varx)
- likelihood = np.exp(np.max([l, l_]))
+ log_likelihood = -0.5 / len(dist) * np.sum(dist) / varx - 0.5 * np.log(
+ 2 * np.pi * varx
+ )
+ log_likelihood_ = -0.5 / len(dist_) * np.sum(dist_) / varx - 0.5 * np.log(
+ 2 * np.pi * varx
+ )
+ likelihood = np.exp(np.max([log_likelihood, log_likelihood_]))
return likelihood
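# --- Editor's note (not part of the diff) -------------------------------------
# The renamed variables above compute the average Gaussian log-density of the
# curve distances, log L = -(1/(2n)) * sum(d_i)/varx - (1/2) * log(2*pi*varx),
# once per branch (induction/repression). The same expression in isolation:
import numpy as np

dist, varx = np.array([0.1, 0.4, 0.2]), 0.5
log_l = -0.5 / len(dist) * np.sum(dist) / varx - 0.5 * np.log(2 * np.pi * varx)
# ------------------------------------------------------------------------------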
def get_variance(self, **kwargs):
@@ -1113,7 +1168,7 @@ def plot_phase(
show_assignments=None,
**kwargs,
):
- from ..plotting.scatter import scatter
+ from scvelo.plotting.scatter import scatter
if np.all([x is None for x in [alpha, beta, gamma, scaling, t_]]):
refit_time = False
@@ -1164,7 +1219,7 @@ def plot_profile_contour(
return_color_scale=False,
**kwargs,
):
- from ..plotting.utils import update_axes
+ from scvelo.plotting.utils import update_axes
x_var = getattr(self, xkey)
y_var = getattr(self, ykey)
@@ -1175,14 +1230,13 @@ def plot_profile_contour(
assignment_mode = self.assignment_mode
self.assignment_mode = None
- fp = lambda x, y: self.get_likelihood(
- **{xkey: x, ykey: y}, refit_time=refit_time
- )
-
+ # TODO: Check if list comprehension can be used
zp = np.zeros((len(x), len(x)))
for i, xi in enumerate(x):
for j, yi in enumerate(y):
- zp[i, j] = fp(xi, yi)
+ zp[i, j] = self.get_likelihood(
+ **{xkey: xi, ykey: yi}, refit_time=refit_time
+ )
log_zp = np.log1p(zp.T)
if vmin is None:
@@ -1238,7 +1292,7 @@ def plot_profile_hist(
vmax=None,
show=True,
):
- from ..plotting.utils import update_axes
+ from scvelo.plotting.utils import update_axes
x_var = getattr(self, xkey)
x = np.linspace(-sight, sight, num=num) * x_var + x_var
@@ -1246,10 +1300,10 @@ def plot_profile_hist(
assignment_mode = self.assignment_mode
self.assignment_mode = None
- fp = lambda x: self.get_likelihood(**{xkey: x}, refit_time=True)
+ # TODO: Check if list comprehension can be used
zp = np.zeros((len(x)))
for i, xi in enumerate(x):
- zp[i] = fp(xi)
+ zp[i] = self.get_likelihood(**{xkey: xi}, refit_time=True)
log_zp = np.log1p(zp.T)
if vmin is None:
@@ -1373,8 +1427,7 @@ def plot_state_likelihoods(
ax=None,
**kwargs,
):
- from ..plotting.utils import update_axes
- from ..plotting.utils import rgb_custom_colormap
+ from scvelo.plotting.utils import rgb_custom_colormap, update_axes
if color_map is None:
color_map = rgb_custom_colormap(
@@ -1589,6 +1642,7 @@ def get_pval_diff_kinetics(self, orth_beta=None, min_cells=10, **kwargs):
-------
p-value
"""
+
if (
"weights_cluster" in kwargs
and np.sum(kwargs["weights_cluster"]) < min_cells
diff --git a/scvelo/tools/dynamical_model_utils_deprecated.py b/scvelo/tools/dynamical_model_utils_deprecated.py
deleted file mode 100644
index 083d7354..00000000
--- a/scvelo/tools/dynamical_model_utils_deprecated.py
+++ /dev/null
@@ -1,1040 +0,0 @@
-# DEPRECATED
-
-from .. import settings
-from .. import logging as logg
-from ..preprocessing.moments import get_connectivities
-from .utils import make_dense, make_unique_list, test_bimodality
-
-import warnings
-import matplotlib.pyplot as pl
-from matplotlib import rcParams
-
-import numpy as np
-
-exp = np.exp
-
-
-def log(x, eps=1e-6): # to avoid invalid values for log.
- return np.log(np.clip(x, eps, 1 - eps))
-
-
-def inv(x):
- with warnings.catch_warnings():
- warnings.simplefilter("ignore")
- x_inv = 1 / x * (x != 0)
- return x_inv
-
-
-def unspliced(tau, u0, alpha, beta):
- expu = exp(-beta * tau)
- return u0 * expu + alpha / beta * (1 - expu)
-
-
-def spliced(tau, s0, u0, alpha, beta, gamma):
- c = (alpha - u0 * beta) * inv(gamma - beta)
- expu, exps = exp(-beta * tau), exp(-gamma * tau)
- return s0 * exps + alpha / gamma * (1 - exps) + c * (exps - expu)
-
-
-def mRNA(tau, u0, s0, alpha, beta, gamma):
- expu, exps = exp(-beta * tau), exp(-gamma * tau)
- u = u0 * expu + alpha / beta * (1 - expu)
- s = (
- s0 * exps
- + alpha / gamma * (1 - exps)
- + (alpha - u0 * beta) * inv(gamma - beta) * (exps - expu)
- )
- return u, s
-
-
-def vectorize(t, t_, alpha, beta, gamma=None, alpha_=0, u0=0, s0=0, sorted=False):
- with warnings.catch_warnings():
- warnings.simplefilter("ignore")
- o = np.array(t < t_, dtype=int)
- tau = t * o + (t - t_) * (1 - o)
-
- u0_ = unspliced(t_, u0, alpha, beta)
- s0_ = spliced(t_, s0, u0, alpha, beta, gamma if gamma is not None else beta / 2)
-
- # vectorize u0, s0 and alpha
- u0 = u0 * o + u0_ * (1 - o)
- s0 = s0 * o + s0_ * (1 - o)
- alpha = alpha * o + alpha_ * (1 - o)
-
- if sorted:
- idx = np.argsort(t)
- tau, alpha, u0, s0 = tau[idx], alpha[idx], u0[idx], s0[idx]
- return tau, alpha, u0, s0
-
-
-def tau_inv(u, s=None, u0=None, s0=None, alpha=None, beta=None, gamma=None):
- with warnings.catch_warnings():
- warnings.simplefilter("ignore")
- inv_u = (gamma >= beta) if gamma is not None else True
- inv_us = np.invert(inv_u)
- any_invu = np.any(inv_u) or s is None
- any_invus = np.any(inv_us) and s is not None
-
- if any_invus: # tau_inv(u, s)
- beta_ = beta * inv(gamma - beta)
- xinf = alpha / gamma - beta_ * (alpha / beta)
- tau = -1 / gamma * log((s - beta_ * u - xinf) / (s0 - beta_ * u0 - xinf))
-
- if any_invu: # tau_inv(u)
- uinf = alpha / beta
- tau_u = -1 / beta * log((u - uinf) / (u0 - uinf))
- tau = tau_u * inv_u + tau * inv_us if any_invus else tau_u
- return tau
-
-
-def find_swichting_time(u, s, tau, o, alpha, beta, gamma, plot=False):
- off, on = o == 0, o == 1
- t0_ = np.max(tau[on]) if on.sum() > 0 and np.max(tau[on]) > 0 else np.max(tau)
-
- if off.sum() > 0:
- u_, s_, tau_ = u[off], s[off], tau[off]
-
- beta_ = beta * inv(gamma - beta)
- ceta_ = alpha / gamma - beta_ * alpha / beta
-
- x = -ceta_ * exp(-gamma * tau_)
- y = s_ - beta_ * u_
-
- exp_t0_ = (y * x).sum() / (x ** 2).sum()
- if -1 < exp_t0_ < 0:
- t0_ = -1 / gamma * log(exp_t0_ + 1)
- if plot:
- pl.scatter(x, y)
- return t0_
-
-
-def fit_alpha(u, s, tau, o, beta, gamma, fit_scaling=False):
- off, on = o == 0, o == 1
- if on.sum() > 0 or off.sum() > 0 or tau[on].min() == 0 or tau[off].min() == 0:
- alpha = None
- else:
- tau_on, tau_off = tau[on], tau[off]
-
- # 'on' state
- expu, exps = exp(-beta * tau_on), exp(-gamma * tau_on)
-
- # 'off' state
- t0_ = np.max(tau_on)
- expu_, exps_ = exp(-beta * tau_off), exp(-gamma * tau_off)
- expu0_, exps0_ = exp(-beta * t0_), exp(-gamma * t0_)
-
- # from unspliced dynamics
- c_beta = 1 / beta * (1 - expu)
- c_beta_ = 1 / beta * (1 - expu0_) * expu_
-
- # from spliced dynamics
- c_gamma = (1 - exps) / gamma + (exps - expu) * inv(gamma - beta)
- c_gamma_ = (
- (1 - exps0_) / gamma + (exps0_ - expu0_) * inv(gamma - beta)
- ) * exps_ - (1 - expu0_) * (exps_ - expu_) * inv(gamma - beta)
-
- # concatenating together
- c = np.concatenate([c_beta, c_gamma, c_beta_, c_gamma_]).T
- x = np.concatenate([u[on], s[on], u[off], s[off]]).T
-
- alpha = (c * x).sum() / (c ** 2).sum()
-
- if fit_scaling: # alternatively compute alpha and scaling simultaneously
- c = np.concatenate([c_gamma, c_gamma_]).T
- x = np.concatenate([s[on], s[off]]).T
- alpha = (c * x).sum() / (c ** 2).sum()
-
- c = np.concatenate([c_beta, c_beta_]).T
- x = np.concatenate([u[on], u[off]]).T
- scaling = (c * x).sum() / (c ** 2).sum() / alpha # ~ alpha * z / alpha
- return alpha, scaling
-
- return alpha
-
-
-def fit_scaling(u, t, t_, alpha, beta):
- tau, alpha, u0, _ = vectorize(t, t_, alpha, beta)
- ut = unspliced(tau, u0, alpha, beta)
- return (u * ut).sum() / (ut ** 2).sum()
-
-
-def tau_s(s, s0, u0, alpha, beta, gamma, u=None, tau=None, eps=1e-2):
- if tau is None:
- tau = tau_inv(u, u0=u0, alpha=alpha, beta=beta) if u is not None else 1
- tau_prev, loss, n_iter, max_iter, mixed_states = 1e6, 1e6, 0, 10, np.any(alpha == 0)
- b0 = (alpha - beta * u0) * inv(gamma - beta)
- g0 = s0 - alpha / gamma + b0
-
- with warnings.catch_warnings():
- warnings.simplefilter("ignore")
-
- while np.abs(tau - tau_prev).max() > eps and loss > eps and n_iter < max_iter:
- tau_prev, n_iter = tau, n_iter + 1
-
- expu, exps = b0 * exp(-beta * tau), g0 * exp(-gamma * tau)
- f = exps - expu + alpha / gamma # >0
- ft = -gamma * exps + beta * expu # >0 if on else <0
- ftt = gamma ** 2 * exps - beta ** 2 * expu
-
- a, b, c = ftt / 2, ft, f - s
- term = b ** 2 - 4 * a * c
- update = (-b + np.sqrt(term)) / (2 * a)
- if mixed_states:
- update = np.nan_to_num(update) * (alpha > 0) + (-c / b) * (alpha <= 0)
- tau = (
- np.nan_to_num(tau_prev + update) * (s != 0)
- if np.any(term > 0)
- else tau_prev / 10
- )
- loss = np.abs(
- alpha / gamma + g0 * exp(-gamma * tau) - b0 * exp(-beta * tau) - s
- ).max()
-
- return np.clip(tau, 0, None)
-
-
-def assign_timepoints_projection(
- u, s, alpha, beta, gamma, t0_=None, u0_=None, s0_=None, n_timepoints=300
-):
- if t0_ is None:
- t0_ = tau_inv(u=u0_, u0=0, alpha=alpha, beta=beta)
- if u0_ is None or s0_ is None:
- u0_, s0_ = (
- unspliced(t0_, 0, alpha, beta),
- spliced(t0_, 0, 0, alpha, beta, gamma),
- )
-
- tpoints = np.linspace(0, t0_, num=n_timepoints)
- tpoints_ = np.linspace(
- 0, tau_inv(np.min(u[s > 0]), u0=u0_, alpha=0, beta=beta), num=n_timepoints
- )[1:]
-
- xt = np.vstack(
- [unspliced(tpoints, 0, alpha, beta), spliced(tpoints, 0, 0, alpha, beta, gamma)]
- ).T
- xt_ = np.vstack(
- [unspliced(tpoints_, u0_, 0, beta), spliced(tpoints_, s0_, u0_, 0, beta, gamma)]
- ).T
- x_obs = np.vstack([u, s]).T
-
- # assign time points (oth. projection onto 'on' and 'off' curve)
- tau, o, diff = np.zeros(len(u)), np.zeros(len(u), dtype=int), np.zeros(len(u))
- tau_alt, diff_alt = np.zeros(len(u)), np.zeros(len(u))
- for i, xi in enumerate(x_obs):
- diffs, diffs_ = (
- np.linalg.norm((xt - xi), axis=1),
- np.linalg.norm((xt_ - xi), axis=1),
- )
- idx, idx_ = np.argmin(diffs), np.argmin(diffs_)
-
- o[i] = np.argmin([diffs_[idx_], diffs[idx]])
- tau[i] = [tpoints_[idx_], tpoints[idx]][o[i]]
- diff[i] = [diffs_[idx_], diffs[idx]][o[i]]
-
- tau_alt[i] = [tpoints_[idx_], tpoints[idx]][1 - o[i]]
- diff_alt[i] = [diffs_[idx_], diffs[idx]][1 - o[i]]
-
- t = tau * o + (t0_ + tau) * (1 - o)
-
- return t, tau, o
-
-
-"""State-independent derivatives"""
-
-
-def dtau(u, s, alpha, beta, gamma, u0, s0, du0=[0, 0, 0], ds0=[0, 0, 0, 0]):
- a, b, g, gb, b0 = alpha, beta, gamma, gamma - beta, beta * inv(gamma - beta)
-
- cu = s - a / g - b0 * (u - a / b)
- c0 = s0 - a / g - b0 * (u0 - a / b)
- cu += cu == 0
- c0 += c0 == 0
- cu_, c0_ = 1 / cu, 1 / c0
-
- dtau_a = b0 / g * (c0_ - cu_) + 1 / g * c0_ * (ds0[0] - b0 * du0[0])
- dtau_b = 1 / gb ** 2 * ((u - a / g) * cu_ - (u0 - a / g) * c0_)
-
- dtau_c = -a / g * (1 / g ** 2 - 1 / gb ** 2) * (cu_ - c0_) - b0 / g / gb * (
- u * cu_ - u0 * c0_
- ) # + 1/g**2 * np.log(cu/c0)
-
- return dtau_a, dtau_b, dtau_c
-
-
-def du(tau, alpha, beta, u0=0, du0=[0, 0, 0], dtau=[0, 0, 0]):
- # du0 is the derivative du0 / d(alpha, beta, tau)
- expu, cb = exp(-beta * tau), alpha / beta
- du_a = (
- du0[0] * expu + 1.0 / beta * (1 - expu) + (alpha - beta * u0) * dtau[0] * expu
- )
- du_b = (
- du0[1] * expu
- - cb / beta * (1 - expu)
- + (cb - u0) * tau * expu
- + (alpha - beta * u0) * dtau[1] * expu
- )
- return du_a, du_b
-
-
-def ds(
- tau, alpha, beta, gamma, u0=0, s0=0, du0=[0, 0, 0], ds0=[0, 0, 0, 0], dtau=[0, 0, 0]
-):
- # ds0 is the derivative ds0 / d(alpha, beta, gamma, tau)
- expu, exps = exp(-beta * tau), exp(-gamma * tau)
- expus = exps - expu
-
- cbu = (alpha - beta * u0) * inv(gamma - beta)
- ccu = (alpha - gamma * u0) * inv(gamma - beta)
- ccs = alpha / gamma - s0 - cbu
-
- ds_a = (
- ds0[0] * exps
- + 1.0 / gamma * (1 - exps)
- + 1 * inv(gamma - beta) * (1 - beta * du0[0]) * expus
- + (ccs * gamma * exps + cbu * beta * expu) * dtau[0]
- )
- ds_b = (
- ds0[1] * exps
- + cbu * tau * expu
- + 1 * inv(gamma - beta) * (ccu - beta * du0[1]) * expus
- + (ccs * gamma * exps + cbu * beta * expu) * dtau[1]
- )
- ds_c = (
- ds0[2] * exps
- + ccs * tau * exps
- - alpha / gamma ** 2 * (1 - exps)
- - cbu * inv(gamma - beta) * expus
- + (ccs * gamma * exps + cbu * beta * expu) * dtau[2]
- )
-
- return ds_a, ds_b, ds_c
-
-
-def derivatives(
- u, s, t, t0_, alpha, beta, gamma, scaling=1, alpha_=0, u0=0, s0=0, weights=None
-):
- o = np.array(t < t0_, dtype=int)
-
- du0 = np.array(du(t0_, alpha, beta, u0))[:, None] * (1 - o)[None, :]
- ds0 = np.array(ds(t0_, alpha, beta, gamma, u0, s0))[:, None] * (1 - o)[None, :]
-
- tau, alpha, u0, s0 = vectorize(t, t0_, alpha, beta, gamma, alpha_, u0, s0)
- dt = np.array(dtau(u, s, alpha, beta, gamma, u0, s0, du0, ds0))
-
- # state-dependent derivatives:
- du_a, du_b = du(tau, alpha, beta, u0, du0, dt)
- du_a, du_b = du_a * scaling, du_b * scaling
-
- ds_a, ds_b, ds_c = ds(tau, alpha, beta, gamma, u0, s0, du0, ds0, dt)
-
- # evaluate derivative of likelihood:
- ut, st = mRNA(tau, u0, s0, alpha, beta, gamma)
-
- # udiff = np.array(ut * scaling - u)
- udiff = np.array(ut - u / scaling)
- sdiff = np.array(st - s)
-
- if weights is not None:
- udiff = np.multiply(udiff, weights)
- sdiff = np.multiply(sdiff, weights)
-
- dl_a = (du_a * (1 - o)).dot(udiff) + (ds_a * (1 - o)).dot(sdiff)
- dl_a_ = (du_a * o).dot(udiff) + (ds_a * o).dot(sdiff)
-
- dl_b = du_b.dot(udiff) + ds_b.dot(sdiff)
- dl_c = ds_c.dot(sdiff)
-
- dl_tau, dl_t0_ = None, None
- return dl_a, dl_b, dl_c, dl_a_, dl_tau, dl_t0_
-
-
-class BaseDynamics:
- def __init__(self, adata=None, u=None, s=None):
- self.s, self.u = s, u
-
- zeros, zeros3 = np.zeros(adata.n_obs), np.zeros((3, 1))
- self.u0, self.s0, self.u0_, self.s0_, self.t_, self.scaling = (
- None,
- None,
- None,
- None,
- None,
- None,
- )
- self.t, self.tau, self.o, self.weights = zeros, zeros, zeros, zeros
-
- self.alpha, self.beta, self.gamma, self.alpha_, self.pars = (
- None,
- None,
- None,
- None,
- None,
- )
- self.dpars, self.m_dpars, self.v_dpars, self.loss = zeros3, zeros3, zeros3, []
-
- def uniform_weighting(self, n_regions=5, perc=95): # deprecated
- from numpy import union1d as union
- from numpy import intersect1d as intersect
-
- u, s = self.u, self.s
- u_b = np.linspace(0, np.percentile(u, perc), n_regions)
- s_b = np.linspace(0, np.percentile(s, perc), n_regions)
-
- regions, weights = {}, np.ones(len(u))
- for i in range(n_regions):
- if i == 0:
- region = intersect(np.where(u < u_b[i + 1]), np.where(s < s_b[i + 1]))
- elif i < n_regions - 1:
- lower_cut = union(np.where(u > u_b[i]), np.where(s > s_b[i]))
- upper_cut = intersect(
- np.where(u < u_b[i + 1]), np.where(s < s_b[i + 1])
- )
- region = intersect(lower_cut, upper_cut)
- else:
- region = union(
- np.where(u > u_b[i]), np.where(s > s_b[i])
- ) # lower_cut for last region
- regions[i] = region
- if len(region) > 0:
- weights[region] = n_regions / len(region)
- # set weights accordingly such that each region has an equal overall contribution.
- self.weights = weights * len(u) / np.sum(weights)
- self.u_b, self.s_b = u_b, s_b
-
- def plot_regions(self):
- u, s, ut, st = self.u, self.s, self.ut, self.st
- u_b, s_b = self.u_b, self.s_b
-
- pl.figure(dpi=100)
- pl.scatter(s, u, color="grey")
- pl.xlim(0)
- pl.ylim(0)
- pl.xlabel("spliced")
- pl.ylabel("unspliced")
-
- for i in range(len(s_b)):
- pl.plot([s_b[i], s_b[i], 0], [0, u_b[i], u_b[i]])
-
- def plot_derivatives(self):
- u, s = self.u, self.s
- alpha, beta, gamma = self.alpha, self.beta, self.gamma
- t, tau, o, t_ = self.t, self.tau, self.o, self.t_
-
- du0 = np.array(du(t_, alpha, beta))[:, None] * (1 - o)[None, :]
- ds0 = np.array(ds(t_, alpha, beta, gamma))[:, None] * (1 - o)[None, :]
-
- tau, alpha, u0, s0 = vectorize(t, t_, alpha, beta, gamma)
- dt = np.array(dtau(u, s, alpha, beta, gamma, u0, s0))
-
- du_a, du_b = du(tau, alpha, beta, u0=u0, du0=du0, dtau=dt)
- ds_a, ds_b, ds_c = ds(
- tau, alpha, beta, gamma, u0=u0, s0=s0, du0=du0, ds0=ds0, dtau=dt
- )
-
- idx = np.argsort(t)
- t = np.sort(t)
-
- pl.plot(t, du_a[idx], label=r"$\partial u / \partial\alpha$")
- pl.plot(t, 0.2 * du_b[idx], label=r"$\partial u / \partial \beta$")
- pl.plot(t, ds_a[idx], label=r"$\partial s / \partial \alpha$")
- pl.plot(t, ds_b[idx], label=r"$\partial s / \partial \beta$")
- pl.plot(t, 0.2 * ds_c[idx], label=r"$\partial s / \partial \gamma$")
-
- pl.legend()
- pl.xlabel("t")
-
-
-class DynamicsRecovery(BaseDynamics):
- def __init__(
- self,
- adata=None,
- gene=None,
- u=None,
- s=None,
- use_raw=False,
- load_pars=None,
- fit_scaling=False,
- fit_time=True,
- fit_switching=True,
- fit_steady_states=True,
- fit_alpha=True,
- fit_connected_states=True,
- ):
- super(DynamicsRecovery, self).__init__(adata.n_obs)
-
- _layers = adata[:, gene].layers
- self.gene = gene
- self.use_raw = use_raw = use_raw or "Ms" not in _layers.keys()
-
- # extract actual data
- if u is None or s is None:
- u = (
- make_dense(_layers["unspliced"])
- if use_raw
- else make_dense(_layers["Mu"])
- )
- s = make_dense(_layers["spliced"]) if use_raw else make_dense(_layers["Ms"])
- self.s, self.u = s, u
-
- # set weights for fitting (exclude dropouts and extreme outliers)
- nonzero = np.ravel(s > 0) & np.ravel(u > 0)
- s_filter = np.ravel(s < np.percentile(s[nonzero], 98))
- u_filter = np.ravel(u < np.percentile(u[nonzero], 98))
-
- self.weights = s_filter & u_filter & nonzero
- self.fit_scaling = fit_scaling
- self.fit_time = fit_time
- self.fit_alpha = fit_alpha
- self.fit_switching = fit_switching
- self.fit_steady_states = fit_steady_states
- self.connectivities = (
- get_connectivities(adata)
- if fit_connected_states is True
- else fit_connected_states
- )
-
- if load_pars and "fit_alpha" in adata.var.keys():
- self.load_pars(adata, gene)
- else:
- self.initialize()
-
- def initialize(self):
- # set weights
- u, s, w = self.u * 1.0, self.s * 1.0, self.weights
- u_w, s_w, perc = u[w], s[w], 98
-
- # initialize scaling
- self.std_u, self.std_s = np.std(u_w), np.std(s_w)
- scaling = (
- self.std_u / self.std_s
- if isinstance(self.fit_scaling, bool)
- else self.fit_scaling
- )
- u, u_w = u / scaling, u_w / scaling
-
- # initialize beta and gamma from extreme quantiles of s
- if True:
- weights_s = s_w >= np.percentile(s_w, perc, axis=0)
- else:
- us_norm = s_w / np.clip(np.max(s_w, axis=0), 1e-3, None) + u_w / np.clip(
- np.max(u_w, axis=0), 1e-3, None
- )
- weights_s = us_norm >= np.percentile(us_norm, perc, axis=0)
-
- beta, gamma = 1, linreg(convolve(u_w, weights_s), convolve(s_w, weights_s))
-
- u_inf, s_inf = u_w[weights_s].mean(), s_w[weights_s].mean()
- u0_, s0_ = u_inf, s_inf
- alpha = np.mean(
- [s_inf * gamma, u_inf * beta]
- ) # np.mean([s0_ * gamma, u0_ * beta])
-
- # initialize switching from u quantiles and alpha from s quantiles
- tstat_u, pval_u, means_u = test_bimodality(u_w, kde=True)
- tstat_s, pval_s, means_s = test_bimodality(s_w, kde=True)
- self.pval_steady = max(pval_u, pval_s)
- self.u_steady = means_u[1]
- self.s_steady = means_s[1]
-
- if self.pval_steady < 0.1:
- u_inf = np.mean([u_inf, self.u_steady])
- s_inf = np.mean([s_inf, self.s_steady])
- alpha = s_inf * gamma
- beta = alpha / u_inf
-
- weights_u = u_w >= np.percentile(u_w, perc, axis=0)
- u0_, s0_ = u_w[weights_u].mean(), s_w[weights_u].mean()
-
- # alpha, beta, gamma = np.array([alpha, beta, gamma]) * scaling
- t_ = tau_inv(u0_, s0_, 0, 0, alpha, beta, gamma)
-
- # update object with initialized vars
- alpha_, u0, s0 = 0, 0, 0
- self.alpha, self.beta, self.gamma, self.alpha_, self.scaling = (
- alpha,
- beta,
- gamma,
- alpha_,
- scaling,
- )
- self.u0, self.s0, self.u0_, self.s0_, self.t_ = u0, s0, u0_, s0_, t_
- self.pars = np.array([alpha, beta, gamma, self.t_, self.scaling])[:, None]
-
- # initialize time point assignment
- self.t, self.tau, self.o = self.get_time_assignment()
- self.loss = [self.get_loss()]
-
- self.update_scaling(sight=0.5)
- self.update_scaling(sight=0.1)
-
- def load_pars(self, adata, gene):
- idx = adata.var_names.get_loc(gene) if isinstance(gene, str) else gene
- self.alpha = adata.var["fit_alpha"][idx]
- self.beta = adata.var["fit_beta"][idx]
- self.gamma = adata.var["fit_gamma"][idx]
- self.scaling = adata.var["fit_scaling"][idx]
- self.t_ = adata.var["fit_t_"][idx]
- self.t = adata.layers["fit_t"][:, idx]
- self.o = self.t < self.t_
- self.tau = self.t * self.o + (self.t - self.t_) * (1 - self.o)
- self.pars = np.array(
- [self.alpha, self.beta, self.gamma, self.t_, self.scaling]
- )[:, None]
-
- self.u0, self.s0, self.alpha_ = 0, 0, 0
- self.u0_ = unspliced(self.t_, self.u0, self.alpha, self.beta)
- self.s0_ = spliced(self.t_, self.u0, self.s0, self.alpha, self.beta, self.gamma)
-
- self.update_state_dependent()
-
- def fit(
- self,
- max_iter=100,
- r=None,
- method=None,
- clip_loss=None,
- assignment_mode=None,
- min_loss=True,
- ):
- updated, idx_update = True, np.clip(int(max_iter / 10), 1, None)
-
- for i in range(max_iter):
- self.update_vars(r=r, method=method, clip_loss=clip_loss)
- if updated or (i % idx_update == 1) or i == max_iter - 1:
- updated = self.update_state_dependent()
- if i > 10 and (i % idx_update == 1):
- loss_prev, loss = np.max(self.loss[-10:]), self.loss[-1]
- if loss_prev - loss < loss_prev * 1e-3:
- updated = self.shuffle_pars()
- if not updated:
- break
-
- if self.fit_switching:
- self.update_switching()
- if min_loss:
- alpha, beta, gamma, t_, scaling = self.pars[:, np.argmin(self.loss)]
- up = self.update_loss(
- None, t_, alpha, beta, gamma, scaling, reassign_time=True
- )
- self.t, self.tau, self.o = self.get_time_assignment(
- assignment_mode=assignment_mode
- )
-
- def update_state_dependent(self):
- updated = False
- if self.fit_alpha:
- updated = self.update_alpha() | updated
- if self.fit_switching:
- updated = self.update_switching() | updated
- return updated
-
- def update_scaling(self, sight=0.5): # fit scaling and update if improved
- z_vals = self.scaling + np.linspace(-1, 1, num=5) * self.scaling * sight
- for z in z_vals:
- u0_ = self.u0_ * self.scaling / z
- beta = self.beta / self.scaling * z
- self.update_loss(
- scaling=z, beta=beta, u0_=u0_, s0_=self.s0_, reassign_time=True
- )
-
- def update_alpha(self): # fit alpha (generalized lin.reg), update if improved
- updated = False
- alpha = self.get_optimal_alpha()
- gamma = self.gamma
-
- alpha_vals = alpha + np.linspace(-1, 1, num=5) * alpha / 30
- gamma_vals = gamma + np.linspace(-1, 1, num=4) * gamma / 30
-
- for alpha in alpha_vals:
- for gamma in gamma_vals:
- updated = (
- self.update_loss(alpha=alpha, gamma=gamma, reassign_time=True)
- | updated
- )
- return updated
-
- def update_switching(
- self,
- ): # find optimal switching (generalized lin.reg) & assign timepoints/states (explicit)
- updated = False
- # t_ = self.t_
- t_ = self.get_optimal_switch()
- t_vals = t_ + np.linspace(-1, 1, num=3) * t_ / 5
- for t_ in t_vals:
- updated = self.update_loss(t_=t_, reassign_time=True) | updated
-
- if True: # self.pval_steady > .1:
- z_vals = 1 + np.linspace(-1, 1, num=4) / 5
- for z in z_vals:
- beta, gamma = self.beta * z, self.gamma * z
- t, tau, o = self.get_time_assignment(beta=beta, gamma=gamma)
- t_ = np.max(t * o)
- if t_ > 0:
- update = self.update_loss(
- t_=np.max(t * o), beta=beta, gamma=gamma, reassign_time=True
- )
- updated |= update
- if update:
- self.update_loss(
- t_=self.get_optimal_switch(), reassign_time=True
- )
- return updated
-
- def update_vars(self, r=None, method=None, clip_loss=None):
- if r is None:
- r = 1e-2 if method == "adam" else 1e-5
- if clip_loss is None:
- clip_loss = method != "adam"
- # if self.weights is None:
- # self.uniform_weighting(n_regions=5, perc=95)
- t, t_, alpha, beta, gamma, scaling = (
- self.t,
- self.t_,
- self.alpha,
- self.beta,
- self.gamma,
- self.scaling,
- )
- dalpha, dbeta, dgamma, dalpha_, dtau, dt_ = derivatives(
- self.u, self.s, t, t_, alpha, beta, gamma, scaling
- )
-
- if method == "adam":
- b1, b2, eps = 0.9, 0.999, 1e-8
-
- # update 1st and 2nd order gradient moments
- dpars = np.array([dalpha, dbeta, dgamma])
- m_dpars = b1 * self.m_dpars[:, -1] + (1 - b1) * dpars
- v_dpars = b2 * self.v_dpars[:, -1] + (1 - b2) * dpars ** 2
-
- self.dpars = np.c_[self.dpars, dpars]
- self.m_dpars = np.c_[self.m_dpars, m_dpars]
- self.v_dpars = np.c_[self.v_dpars, v_dpars]
-
- # correct for bias
- t = len(self.m_dpars[0])
- m_dpars /= 1 - b1 ** t
- v_dpars /= 1 - b2 ** t
-
- # Adam parameter update
- # Parameters are restricted to be positive
- n_alpha = alpha - r * m_dpars[0] / (np.sqrt(v_dpars[0]) + eps)
- alpha = n_alpha if n_alpha > 0 else alpha
- n_beta = beta - r * m_dpars[1] / (np.sqrt(v_dpars[1]) + eps)
- beta = n_beta if n_beta > 0 else beta
- n_gamma = gamma - r * m_dpars[2] / (np.sqrt(v_dpars[2]) + eps)
- gamma = n_gamma if n_gamma > 0 else gamma
-
- else:
- # Parameters are restricted to be positive
- n_alpha = alpha - r * dalpha
- alpha = n_alpha if n_alpha > 0 else alpha
- n_beta = beta - r * dbeta
- beta = n_beta if n_beta > 0 else beta
- n_gamma = gamma - r * dgamma
- gamma = n_gamma if n_gamma > 0 else gamma
-
- # tau -= r * dtau
- # t_ -= r * dt_
- # t_ = np.max(self.tau * self.o)
- # t = tau * self.o + (tau + t_) * (1 - self.o)
-
- updated_vars = self.update_loss(
- alpha=alpha,
- beta=beta,
- gamma=gamma,
- clip_loss=clip_loss,
- reassign_time=False,
- )
-
- def update_loss(
- self,
- t=None,
- t_=None,
- alpha=None,
- beta=None,
- gamma=None,
- scaling=None,
- u0_=None,
- s0_=None,
- reassign_time=False,
- clip_loss=True,
- report=False,
- ):
- vals = [t_, alpha, beta, gamma, scaling]
- vals_prev = [self.t_, self.alpha, self.beta, self.gamma, self.scaling]
- vals_name = ["t_", "alpha", "beta", "gamma", "scaling"]
- new_vals, new_vals_prev, new_vals_name = [], [], []
- loss_prev = self.loss[-1] if len(self.loss) > 0 else 1e6
-
- for val, val_prev, val_name in zip(vals, vals_prev, vals_name):
- if val is not None:
- new_vals.append(val)
- new_vals_prev.append(val_prev)
- new_vals_name.append(val_name)
- if t_ is None:
- t_ = (
- tau_inv(
- self.u0_ if u0_ is None else u0_,
- self.s0_ if s0_ is None else s0_,
- 0,
- 0,
- self.alpha if alpha is None else alpha,
- self.beta if beta is None else beta,
- self.gamma if gamma is None else gamma,
- )
- if u0_ is not None
- else self.t_
- )
-
- t, t_, alpha, beta, gamma, scaling = self.get_vals(
- t, t_, alpha, beta, gamma, scaling
- )
-
- if reassign_time:
- # t_ = self.get_optimal_switch(alpha, beta, gamma) if t_ is None else t_
- t, tau, o = self.get_time_assignment(t_, alpha, beta, gamma, scaling)
-
- loss = self.get_loss(t, t_, alpha, beta, gamma, scaling)
- perform_update = not clip_loss or loss < loss_prev
-
- if perform_update:
- if (
- len(self.loss) > 0 and loss_prev - loss > loss_prev * 0.01 and report
- ): # improvement by at least 1%
- print(
- "Update:",
- " ".join(map(str, new_vals_name)),
- " ".join(map(str, np.round(new_vals_prev, 2))),
- "-->",
- " ".join(map(str, np.round(new_vals, 2))),
- )
-
- print(" loss:", np.round(loss_prev, 2), "-->", np.round(loss, 2))
-
- if "t_" in new_vals_name or reassign_time:
- if reassign_time:
- self.t = t
- self.t_ = t_
- self.o = o = np.array(self.t < t_, dtype=bool)
- self.tau = self.t * o + (self.t - t_) * (1 - o)
-
- if u0_ is not None:
- self.u0_ = u0_
- self.s0_ = s0_
-
- if "alpha" in new_vals_name:
- self.alpha = alpha
- if "beta" in new_vals_name:
- self.beta = beta
- if "gamma" in new_vals_name:
- self.gamma = gamma
- if "scaling" in new_vals_name:
- self.scaling = scaling
- # self.rescale_invariant()
-
- self.pars = np.c_[
- self.pars,
- np.array([self.alpha, self.beta, self.gamma, self.t_, self.scaling])[
- :, None
- ],
- ]
- self.loss.append(loss if perform_update else loss_prev)
-
- return perform_update
-
- def rescale_invariant(self, z=None):
- z = self.scaling / self.std_u * self.std_s if z is None else z
- self.alpha, self.beta, self.gamma = (
- np.array([self.alpha, self.beta, self.gamma]) * z
- )
- self.t, self.tau, self.t_ = self.t / z, self.tau / z, self.t_ / z
-
- def shuffle_pars(self, alpha_sight=[-0.5, 0.5], gamma_sight=[-0.5, 0.5], num=5):
- alpha_vals = (
- np.linspace(alpha_sight[0], alpha_sight[1], num=num) * self.alpha
- + self.alpha
- )
- gamma_vals = (
- np.linspace(gamma_sight[0], gamma_sight[1], num=num) * self.gamma
- + self.gamma
- )
-
- x, y = alpha_vals, gamma_vals
- f = lambda x, y: self.get_loss(alpha=x, gamma=y, reassign_time=self.fit_time)
- z = np.zeros((len(x), len(x)))
-
- for i, xi in enumerate(x):
- for j, yi in enumerate(y):
- z[i, j] = f(xi, yi)
- ix, iy = np.unravel_index(z.argmin(), z.shape)
- return self.update_loss(alpha=x[ix], gamma=y[ix], reassign_time=self.fit_time)
-
-
-def read_pars(adata, pars_names=["alpha", "beta", "gamma", "t_", "scaling"], key="fit"):
- pars = []
- for name in pars_names:
- pkey = f"{key}_{name}"
- par = (
- adata.var[pkey].values
- if pkey in adata.var.keys()
- else np.zeros(adata.n_vars) * np.nan
- )
- pars.append(par)
- return pars
-
-
-def write_pars(
- adata, pars, pars_names=["alpha", "beta", "gamma", "t_", "scaling"], add_key="fit"
-):
- for i, name in enumerate(pars_names):
- adata.var[f"{add_key}_{name}"] = pars[i]
-
-
-def recover_dynamics_deprecated(
- data,
- var_names="velocity_genes",
- max_iter=10,
- learning_rate=None,
- t_max=None,
- use_raw=False,
- fit_scaling=True,
- fit_time=True,
- fit_switching=True,
- fit_steady_states=True,
- fit_alpha=True,
- fit_connected_states=True,
- min_loss=True,
- assignment_mode=None,
- load_pars=None,
- add_key="fit",
- method="adam",
- return_model=True,
- plot_results=False,
- copy=False,
- **kwargs,
-):
- """Estimates velocities in a gene-specific manner
-
- Arguments
- ---------
- data: :class:`~anndata.AnnData`
- Annotated data matrix.
-
- Returns
- -------
- Returns or updates `adata`
- """
- adata = data.copy() if copy else data
- logg.info("recovering dynamics", r=True)
-
- if isinstance(var_names, str) and var_names not in adata.var_names:
- var_names = (
- adata.var_names[adata.var[var_names] == True]
- if "genes" in var_names and var_names in adata.var.keys()
- else adata.var_names
- if "all" in var_names or "genes" in var_names
- else var_names
- )
- var_names = make_unique_list(var_names, allow_array=True)
- var_names = [name for name in var_names if name in adata.var_names]
- if len(var_names) == 0:
- raise ValueError("Variable name not found in var keys.")
-
- if fit_connected_states:
- fit_connected_states = get_connectivities(adata)
-
- alpha, beta, gamma, t_, scaling = read_pars(adata)
- idx = []
- L, P, T = (
- [],
- [],
- adata.layers["fit_t"]
- if "fit_t" in adata.layers.keys()
- else np.zeros(adata.shape) * np.nan,
- )
-
- progress = logg.ProgressReporter(len(var_names))
- for i, gene in enumerate(var_names):
- dm = DynamicsRecovery(
- adata,
- gene,
- use_raw=use_raw,
- load_pars=load_pars,
- fit_time=fit_time,
- fit_alpha=fit_alpha,
- fit_switching=fit_switching,
- fit_scaling=fit_scaling,
- fit_steady_states=fit_steady_states,
- fit_connected_states=fit_connected_states,
- )
- if max_iter > 1:
- dm.fit(
- max_iter,
- learning_rate,
- assignment_mode=assignment_mode,
- min_loss=min_loss,
- method=method,
- **kwargs,
- )
-
- ix = adata.var_names.get_loc(gene)
- idx.append(ix)
-
- alpha[ix], beta[ix], gamma[ix], t_[ix], scaling[ix] = dm.pars[:, -1]
- T[:, ix] = dm.t
- L.append(dm.loss)
- if plot_results and i < 4:
- P.append(dm.pars)
-
- progress.update()
- progress.finish()
-
- with warnings.catch_warnings():
- warnings.simplefilter("ignore")
- T_max = np.nanpercentile(T, 95, axis=0) - np.nanpercentile(T, 5, axis=0)
- m = t_max / T_max if t_max is not None else np.ones(adata.n_vars)
- alpha, beta, gamma, T, t_ = alpha / m, beta / m, gamma / m, T * m, t_ * m
-
- write_pars(adata, [alpha, beta, gamma, t_, scaling])
- adata.layers["fit_t"] = T
-
- cur_len = adata.varm["loss"].shape[1] if "loss" in adata.varm.keys() else 2
- max_len = max(np.max([len(l) for l in L]), cur_len)
- loss = np.ones((adata.n_vars, max_len)) * np.nan
-
- if "loss" in adata.varm.keys():
- loss[:, :cur_len] = adata.varm["loss"]
-
- loss[idx] = np.vstack(
- [np.concatenate([l, np.ones(max_len - len(l)) * np.nan]) for l in L]
- )
- adata.varm["loss"] = loss
-
- logg.info(" finished", time=True, end=" " if settings.verbosity > 2 else "\n")
- logg.hint(
- "added \n"
- f" '{add_key}_pars', fitted parameters for splicing dynamics (adata.var)"
- )
-
- if plot_results: # Plot Parameter Stats
- n_rows, n_cols = len(var_names[:4]), 6
- figsize = [2 * n_cols, 1.5 * n_rows] # rcParams['figure.figsize']
- fontsize = rcParams["font.size"]
- fig, axes = pl.subplots(nrows=n_rows, ncols=6, figsize=figsize)
- pl.subplots_adjust(wspace=0.7, hspace=0.5)
- for i, gene in enumerate(var_names[:4]):
- P[i] *= np.array(
- [1 / m[idx[i]], 1 / m[idx[i]], 1 / m[idx[i]], m[idx[i]], 1]
- )[:, None]
- ax = axes[i] if n_rows > 1 else axes
- for j, pij in enumerate(P[i]):
- ax[j].plot(pij)
- ax[len(P[i])].plot(L[i])
- if i == 0:
- for j, name in enumerate(
- ["alpha", "beta", "gamma", "t_", "scaling", "loss"]
- ):
- ax[j].set_title(name, fontsize=fontsize)
-
- return dm if return_model else adata if copy else None
diff --git a/scvelo/tools/optimization.py b/scvelo/tools/optimization.py
index b178fb97..86bcf72e 100644
--- a/scvelo/tools/optimization.py
+++ b/scvelo/tools/optimization.py
@@ -1,8 +1,11 @@
-from .utils import sum_obs, prod_sum_obs, make_dense
+import warnings
+
+import numpy as np
from scipy.optimize import minimize
from scipy.sparse import csr_matrix, issparse
-import numpy as np
-import warnings
+
+from scvelo.core import prod_sum, sum
+from .utils import make_dense
def get_weight(x, y=None, perc=95):
@@ -22,6 +25,14 @@ def get_weight(x, y=None, perc=95):
def leastsq_NxN(x, y, fit_offset=False, perc=None, constraint_positive_offset=True):
"""Solves least squares X*b=Y for b."""
+
+ warnings.warn(
+ "`leastsq_NxN` is deprecated since scVelo v0.2.4 and will be removed in a "
+ "future version. Please use `LinearRegression` from `scvelo/core/` instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+
if perc is not None:
if not fit_offset and isinstance(perc, (list, tuple)):
perc = perc[1]
@@ -32,13 +43,13 @@ def leastsq_NxN(x, y, fit_offset=False, perc=None, constraint_positive_offset=Tr
with warnings.catch_warnings():
warnings.simplefilter("ignore")
- xx_ = prod_sum_obs(x, x)
- xy_ = prod_sum_obs(x, y)
+ xx_ = prod_sum(x, x, axis=0)
+ xy_ = prod_sum(x, y, axis=0)
if fit_offset:
- n_obs = x.shape[0] if weights is None else sum_obs(weights)
- x_ = sum_obs(x) / n_obs
- y_ = sum_obs(y) / n_obs
+ n_obs = x.shape[0] if weights is None else sum(weights, axis=0)
+ x_ = sum(x, axis=0) / n_obs
+ y_ = sum(y, axis=0) / n_obs
gamma = (xy_ / n_obs - x_ * y_) / (xx_ / n_obs - x_ ** 2)
offset = y_ - gamma * x_
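
The hunk above deprecates `leastsq_NxN` in favor of `scvelo.core.LinearRegression`. A minimal migration sketch, using only the `fit`/`coef_`/`intercept_` API that this patch itself exercises in `velocity.py` (toy data, not scVelo output):

```python
import numpy as np

from scvelo.core import LinearRegression

# Toy data: one column with slope ~2 and offset ~1.
x = np.linspace(0, 10, 100)[:, None]
y = 2 * x + 1 + np.random.normal(0, 0.1, x.shape)

# Before: offset, gamma = leastsq_NxN(x, y, fit_offset=True, perc=95)
lr = LinearRegression(fit_intercept=True, percentile=95)
lr.fit(x, y)
print(lr.intercept_, lr.coef_)  # approximately 1 and 2
```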
diff --git a/scvelo/tools/paga.py b/scvelo/tools/paga.py
index 81e52530..dbb7edb0 100644
--- a/scvelo/tools/paga.py
+++ b/scvelo/tools/paga.py
@@ -1,16 +1,17 @@
-# This is adapted from https://github.com/theislab/paga
-from .. import settings
-from .. import logging as logg
-from .utils import strings_to_categoricals
-from .velocity_graph import vals_to_csr
-from .velocity_pseudotime import velocity_pseudotime
-from .rank_velocity_genes import velocity_clusters
import numpy as np
import pandas as pd
from scipy.sparse import csr_matrix
-from pandas.api.types import is_categorical
+
from scanpy.tools._paga import PAGA
+# This module is adapted from https://github.com/theislab/paga
+from scvelo import logging as logg
+from scvelo import settings
+from .rank_velocity_genes import velocity_clusters
+from .utils import strings_to_categoricals
+from .velocity_graph import vals_to_csr
+from .velocity_pseudotime import velocity_pseudotime
+
def get_igraph_from_adjacency(adjacency, directed=None):
"""Get igraph graph from adjacency matrix."""
@@ -25,7 +26,7 @@ def get_igraph_from_adjacency(adjacency, directed=None):
g.add_edges(list(zip(sources, targets)))
try:
g.es["weight"] = weights
- except:
+ except Exception:
pass
if g.vcount() != adjacency.shape[0]:
logg.warn(
@@ -55,7 +56,9 @@ def get_sparse_from_igraph(graph, weight_attr=None):
def set_row_csr(csr, rows, value=0):
"""Set all nonzero elements to the given value. Useful to set to 0 mostly."""
for row in rows:
- csr.data[csr.indptr[row] : csr.indptr[row + 1]] = value
+ start = csr.indptr[row]
+ end = csr.indptr[row + 1]
+ csr.data[start:end] = value
if value == 0:
csr.eliminate_zeros()
@@ -218,27 +221,23 @@ def paga(
copy : `bool`, optional (default: `False`)
Copy `adata` before computation and return a copy.
Otherwise, perform computation inplace and return `None`.
+
Returns
-------
- **connectivities** : (adata.uns['connectivities'])
+ connectivities: `.uns`
The full adjacency matrix of the abstracted graph, weights correspond to
confidence in the connectivities of partitions.
- **connectivities_tree** : (adata.uns['connectivities_tree'])
+ connectivities_tree: `.uns`
The adjacency matrix of the tree-like subgraph that best explains the topology.
- **transitions_confidence** : (adata.uns['transitions_confidence'])
+ transitions_confidence: `.uns`
The adjacency matrix of the abstracted directed graph, weights correspond to
confidence in the transitions between partitions.
"""
+
if "neighbors" not in adata.uns:
raise ValueError(
"You need to run `pp.neighbors` first to compute a neighborhood graph."
)
- try:
- import igraph
- except ImportError:
- raise ImportError(
- "To run paga, you need to install `pip install python-igraph`"
- )
adata = adata.copy() if copy else adata
strings_to_categoricals(adata)
@@ -260,7 +259,7 @@ def paga(
priors = [p for p in [use_time_prior, root_key, end_key] if p in adata.obs.keys()]
logg.info(
- f"running PAGA",
+ "running PAGA",
f"using priors: {priors}" if len(priors) > 0 else "",
r=True,
)
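
The `set_row_csr` rewrite above makes the CSR row indexing explicit: for a CSR matrix, the nonzeros of row `r` occupy `data[indptr[r]:indptr[r + 1]]`. A self-contained sketch of that indexing:

```python
import numpy as np
from scipy.sparse import csr_matrix

m = csr_matrix(np.array([[1, 0, 2], [0, 3, 0], [4, 5, 6]]))

row = 1
start, end = m.indptr[row], m.indptr[row + 1]
m.data[start:end] = 0  # zero out row 1 without touching other rows
m.eliminate_zeros()    # drop explicit zeros, as set_row_csr does for value == 0
print(m.toarray())     # [[1 0 2], [0 0 0], [4 5 6]]
```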
diff --git a/scvelo/tools/rank_velocity_genes.py b/scvelo/tools/rank_velocity_genes.py
index 26741e86..324fd9cd 100644
--- a/scvelo/tools/rank_velocity_genes.py
+++ b/scvelo/tools/rank_velocity_genes.py
@@ -1,11 +1,11 @@
-from .. import settings
-from .. import logging as logg
+import numpy as np
+from scipy.sparse import issparse
+
+from scvelo import logging as logg
+from scvelo import settings
from .utils import strings_to_categoricals, vcorrcoef
from .velocity_pseudotime import velocity_pseudotime
-from scipy.sparse import issparse
-import numpy as np
-
def get_mean_var(X, ignore_zeros=False, perc=None):
data = X.data if issparse(X) else X
@@ -102,10 +102,10 @@ def velocity_clusters(
Returns
-------
- Returns or updates `data` with the attributes
velocity_clusters : `.obs`
Clusters obtained from applying louvain modularity on velocity expression.
- """
+ """ # noqa E501
+
adata = data.copy() if copy else data
logg.info("computing velocity clusters", r=True)
@@ -133,7 +133,7 @@ def velocity_clusters(
if "fit_likelihood" in adata.var.keys() and min_likelihood is not None:
tmp_filter &= adata.var["fit_likelihood"] > min_likelihood
- from .. import AnnData
+ from anndata import AnnData
vdata = AnnData(adata.layers[vkey][:, tmp_filter])
vdata.obs = adata.obs.copy()
@@ -217,7 +217,9 @@ def rank_velocity_genes(
.. code:: python
scv.tl.rank_velocity_genes(adata, groupby='clusters')
- scv.pl.scatter(adata, basis=adata.uns['rank_velocity_genes']['names']['Beta'][:3])
+ scv.pl.scatter(
+ adata, basis=adata.uns['rank_velocity_genes']['names']['Beta'][:3]
+ )
pd.DataFrame(adata.uns['rank_velocity_genes']['names']).head()
.. image:: https://user-images.githubusercontent.com/31883718/69626017-11c47980-1048-11ea-89f4-df3769df5ad5.png
@@ -255,13 +257,13 @@ def rank_velocity_genes(
Returns
-------
- Returns or updates `data` with the attributes
rank_velocity_genes : `.uns`
Structured array to be indexed by group id storing the gene
names. Ordered according to scores.
velocity_score : `.var`
Storing the score for each gene for each group. Ordered according to scores.
- """
+ """ # noqa E501
+
adata = data.copy() if copy else data
if groupby is None or groupby == "velocity_clusters":
@@ -314,9 +316,9 @@ def rank_velocity_genes(
tmp_filter &= dispersions > min_dispersion
if "fit_likelihood" in adata.var.keys():
- l = adata.var["fit_likelihood"]
+ fit_likelihood = adata.var["fit_likelihood"]
min_likelihood = 0.1 if min_likelihood is None else min_likelihood
- tmp_filter &= l > min_likelihood
+ tmp_filter &= fit_likelihood > min_likelihood
X = adata[:, tmp_filter].layers[vkey]
groups, groups_masks = select_groups(adata, key=groupby)
@@ -354,7 +356,7 @@ def rank_velocity_genes(
all_names = rankings_gene_names.T.flatten()
all_scores = rankings_gene_scores.T.flatten()
- vscore = np.zeros(adata.n_vars, dtype=np.int)
+ vscore = np.zeros(adata.n_vars, dtype=int)
for i, name in enumerate(adata.var_names):
if name in all_names:
vscore[i] = all_scores[np.where(name == all_names)[0][0]]
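
The `dtype=np.int` fix above is needed because `np.int` was only an alias for Python's builtin `int` (deprecated in NumPy 1.20, removed in 1.24); the builtin is the portable spelling:

```python
import numpy as np

# dtype=int maps to the platform C long (np.int_), preserving the old
# behavior of dtype=np.int without relying on the removed alias.
vscore = np.zeros(5, dtype=int)
print(vscore.dtype)  # int64 on most 64-bit platforms
```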
diff --git a/scvelo/tools/run.py b/scvelo/tools/run.py
index d4b3d8d0..b807b047 100644
--- a/scvelo/tools/run.py
+++ b/scvelo/tools/run.py
@@ -1,5 +1,5 @@
-from ..preprocessing import filter_and_normalize, moments
-from . import velocity, velocity_graph, velocity_embedding
+from scvelo.preprocessing import filter_and_normalize, moments
+from . import velocity, velocity_embedding, velocity_graph
def run_all(
@@ -38,7 +38,8 @@ def run_all(
def convert_to_adata(vlm, basis=None):
from collections import OrderedDict
- from .. import AnnData
+
+ from anndata import AnnData
X = (
vlm.S_norm.T
@@ -90,10 +91,11 @@ def convert_to_adata(vlm, basis=None):
def convert_to_loom(adata, basis=None):
- from scipy.sparse import issparse
- import numpy as np
import velocyto
+ import numpy as np
+ from scipy.sparse import issparse
+
class VelocytoLoom(velocyto.VelocytoLoom):
def __init__(self, adata, basis=None):
kwargs = {"dtype": np.float64, "order": "C"}
@@ -118,7 +120,7 @@ def __init__(self, adata, basis=None):
self.initial_cell_size = self.S.sum(0)
self.initial_Ucell_size = self.U.sum(0)
- from ..preprocessing.utils import not_yet_normalized
+ from scvelo.preprocessing.utils import not_yet_normalized
if not not_yet_normalized(adata.layers["spliced"]):
self.S_sz = self.S
@@ -308,9 +310,9 @@ def run_all(
def test():
- from ..datasets import simulation
+ from scvelo.datasets import simulation
+ from scvelo.logging import print_version
from .velocity_graph import velocity_graph
- from ..logging import print_version
print_version()
adata = simulation(n_obs=300, n_vars=30)
diff --git a/scvelo/tools/score_genes_cell_cycle.py b/scvelo/tools/score_genes_cell_cycle.py
index 3a26d5ce..9f5b0933 100644
--- a/scvelo/tools/score_genes_cell_cycle.py
+++ b/scvelo/tools/score_genes_cell_cycle.py
@@ -1,8 +1,7 @@
-from .. import logging as logg
-
-import pandas as pd
import numpy as np
+import pandas as pd
+from scvelo import logging as logg
# fmt: off
s_genes_list = \
@@ -32,10 +31,12 @@ def get_phase_marker_genes(adata):
adata
The annotated data matrix.
phase
+
Returns
-------
List of S and/or G2M phase marker genes.
"""
+
name, gene_names = adata.var_names[0], adata.var_names
up, low = name.isupper(), name.islower()
s_genes_list_ = [
@@ -70,16 +71,17 @@ def score_genes_cell_cycle(adata, s_genes=None, g2m_genes=None, copy=False, **kw
**kwargs
Are passed to :func:`~scanpy.tl.score_genes`. `ctrl_size` is not
possible, as it's set as `min(len(s_genes), len(g2m_genes))`.
+
Returns
-------
- Depending on `copy`, returns or updates `adata` with the following fields.
- **S_score** : `adata.obs`, dtype `object`
+ S_score: `adata.obs`, dtype `object`
The score for S phase for each cell.
- **G2M_score** : `adata.obs`, dtype `object`
+ G2M_score: `adata.obs`, dtype `object`
The score for G2M phase for each cell.
- **phase** : `adata.obs`, dtype `object`
+ phase: `adata.obs`, dtype `object`
The cell cycle phase (`S`, `G2M` or `G1`) for each cell.
"""
+
logg.info("calculating cell cycle phase")
from scanpy.tools._score_genes import score_genes
diff --git a/scvelo/tools/terminal_states.py b/scvelo/tools/terminal_states.py
index 99a9fc23..5f740dcf 100644
--- a/scvelo/tools/terminal_states.py
+++ b/scvelo/tools/terminal_states.py
@@ -1,13 +1,13 @@
-from .. import settings
-from .. import logging as logg
-from ..preprocessing.moments import get_connectivities
-from ..preprocessing.neighbors import verify_neighbors
-from .velocity_graph import VelocityGraph
-from .transition_matrix import transition_matrix
-from .utils import scale, groups_to_bool, strings_to_categoricals, get_plasticity_score
-
-from scipy.sparse import linalg, csr_matrix, issparse
import numpy as np
+from scipy.sparse import csr_matrix, issparse, linalg
+
+from scvelo import logging as logg
+from scvelo import settings
+from scvelo.preprocessing.moments import get_connectivities
+from scvelo.preprocessing.neighbors import verify_neighbors
+from .transition_matrix import transition_matrix
+from .utils import get_plasticity_score, groups_to_bool, scale, strings_to_categoricals
+from .velocity_graph import VelocityGraph
def cell_fate(
@@ -37,12 +37,12 @@ def cell_fate(
Returns
-------
- Returns or updates `adata` with the attributes
cell_fate: `.obs`
most likely cell fate for each individual cell
cell_fate_confidence: `.obs`
confidence of transitioning to the assigned fate
"""
+
adata = data.copy() if copy else data
logg.info("computing cell fates", r=True)
@@ -56,8 +56,7 @@ def cell_fate(
_adata.uns["velocity_graph_neg"] = vgraph.graph_neg
T = transition_matrix(_adata, self_transitions=self_transitions)
- I = np.eye(_adata.n_obs)
- fate = np.linalg.inv(I - T)
+ fate = np.linalg.inv(np.eye(_adata.n_obs) - T)
if issparse(T):
fate = fate.A
cell_fates = np.array(_adata.obs[groupby][fate.argmax(1)])
@@ -105,12 +104,12 @@ def cell_origin(
Returns
-------
- Returns or updates `adata` with the attributes
cell_origin: `.obs`
most likely cell origin for each individual cell
cell_origin_confidence: `.obs`
confidence of coming from assigned origin
"""
+
adata = data.copy() if copy else data
logg.info("computing cell fates", r=True)
@@ -124,8 +123,7 @@ def cell_origin(
_adata.uns["velocity_graph_neg"] = vgraph.graph_neg
T = transition_matrix(_adata, self_transitions=self_transitions, backward=True)
- I = np.eye(_adata.n_obs)
- fate = np.linalg.inv(I - T)
+ fate = np.linalg.inv(np.eye(_adata.n_obs) - T)
if issparse(T):
fate = fate.A
cell_fates = np.array(_adata.obs[groupby][fate.argmax(1)])
@@ -167,7 +165,7 @@ def eigs(T, k=10, eps=1e-3, perc=None, random_state=None, v0=None):
eigvecs = np.clip(eigvecs, 0, ubs)
eigvecs /= eigvecs.max(0)
- except:
+ except Exception:
eigvals, eigvecs = np.empty(0), np.zeros(shape=(T.shape[0], 0))
return eigvals, eigvecs
@@ -222,7 +220,7 @@ def terminal_states(
which is given by left eigenvectors corresponding to an eigenvalue of 1, i.e.
.. math::
- μ^{\\textrm{end}}=μ^{\\textrm{end}} \\pi, \quad
+ μ^{\\textrm{end}}=μ^{\\textrm{end}} \\pi, \\quad
μ^{\\textrm{root}}=μ^{\\textrm{root}} \\pi^{\\small \\textrm{T}}.
.. code:: python
@@ -263,12 +261,12 @@ def terminal_states(
Returns
-------
- Returns or updates `data` with the attributes
root_cells: `.obs`
sparse matrix with transition probabilities.
end_points: `.obs`
sparse matrix with transition probabilities.
- """
+ """ # noqa E501
+
adata = data.copy() if copy else data
verify_neighbors(adata)
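
Both `cell_fate` and `cell_origin` above reduce to `np.linalg.inv(np.eye(n_obs) - T)`. When the rows of `T` sum to less than one, this inverse equals the Neumann series `I + T + T^2 + ...`, so entry `(i, j)` accumulates expected visits to cell `j` on a walk started in cell `i`, and `fate.argmax(1)` picks the most visited target. A toy check of that identity (the transition matrix here is made up for illustration):

```python
import numpy as np

# Sub-stochastic toy transition matrix (rows sum to < 1).
T = np.array([
    [0.0, 0.8, 0.1],
    [0.0, 0.0, 0.9],
    [0.0, 0.1, 0.0],
])

fate = np.linalg.inv(np.eye(3) - T)

# The truncated Neumann series converges to the same matrix.
series = sum(np.linalg.matrix_power(T, k) for k in range(200))
print(np.allclose(fate, series))  # True
print(fate.argmax(1))             # most visited target per starting cell
```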
diff --git a/scvelo/tools/transition_matrix.py b/scvelo/tools/transition_matrix.py
index 2803ad7d..5279c86d 100644
--- a/scvelo/tools/transition_matrix.py
+++ b/scvelo/tools/transition_matrix.py
@@ -1,12 +1,12 @@
-from ..preprocessing.neighbors import get_connectivities, get_neighs
-from .utils import normalize
+import warnings
import numpy as np
import pandas as pd
-from scipy.spatial.distance import pdist, squareform
from scipy.sparse import csr_matrix, SparseEfficiencyWarning
+from scipy.spatial.distance import pdist, squareform
-import warnings
+from scvelo.preprocessing.neighbors import get_connectivities, get_neighs
+from .utils import normalize
warnings.simplefilter("ignore", SparseEfficiencyWarning)
@@ -36,7 +36,8 @@ def transition_matrix(
from the velocity graph :math:`\\pi_{ij}`, with row-normalization :math:`z_i` and
kernel width :math:`\\sigma` (scale parameter :math:`\\lambda = \\sigma^{-1}`).
- Alternatively, use :func:`cellrank.tl.transition_matrix` to account for uncertainty in the velocity estimates.
+ Alternatively, use :func:`cellrank.tl.transition_matrix` to account for uncertainty
+ in the velocity estimates.
Arguments
---------
@@ -72,6 +73,7 @@ def transition_matrix(
-------
Returns sparse matrix with transition probabilities.
"""
+
if f"{vkey}_graph" not in adata.uns:
raise ValueError(
"You need to run `tl.velocity_graph` first to compute cosine correlations."
@@ -185,11 +187,13 @@ def get_cell_transitions(
Set to `int` for reproducibility, otherwise `None` for a random seed.
**kwargs:
To be passed to tl.transition_matrix.
+
Returns
-------
Returns embedding coordinates (if basis is specified),
otherwise return indices of simulated cell transitions.
"""
+
np.random.seed(random_state)
if isinstance(starting_cell, str) and starting_cell in adata.obs_names:
starting_cell = adata.obs_names.get_loc(starting_cell)
@@ -207,6 +211,8 @@ def get_cell_transitions(
if n_neighbors is not None and n_neighbors < len(p):
idx = np.argsort(t.data)[::-1][:n_neighbors]
indices, p = indices[idx], p[idx]
+ if len(p) == 0:
+ indices, p = [X[-1]], [1]
p /= np.sum(p)
ix = np.random.choice(indices, p=p)
X.append(ix)
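
The `len(p) == 0` guard added above keeps the simulated walk in `get_cell_transitions` from crashing when a cell has no outgoing transition probabilities left after filtering; the walk then simply stays in place. A sketch of one sampling step under that convention (`walk_step` is a hypothetical helper, not scVelo API):

```python
import numpy as np

def walk_step(indices, p, current_cell):
    # Fall back to a self-transition when no candidates remain,
    # mirroring `indices, p = [X[-1]], [1]` in the patch.
    if len(p) == 0:
        indices, p = [current_cell], [1]
    p = np.asarray(p, dtype=float)
    p /= p.sum()
    return np.random.choice(indices, p=p)

print(walk_step([], [], current_cell=7))  # 7 -- the walk stays put
print(walk_step([1, 2], [0.2, 0.6], 0))   # samples 1 or 2
```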
diff --git a/scvelo/tools/utils.py b/scvelo/tools/utils.py
index d06993d5..2465a102 100644
--- a/scvelo/tools/utils.py
+++ b/scvelo/tools/utils.py
@@ -1,8 +1,12 @@
+import warnings
+
+import numpy as np
+import pandas as pd
from scipy.sparse import csr_matrix, issparse
+
import matplotlib.pyplot as pl
-import pandas as pd
-import numpy as np
-import warnings
+
+from scvelo.core import l2_norm, prod_sum, sum
warnings.simplefilter("ignore")
@@ -12,9 +16,9 @@ def round(k, dec=2, as_str=None):
return [round(ki, dec) for ki in k]
if "e" in f"{k}":
k_str = f"{k}".split("e")
- result = f"{np.round(np.float(k_str[0]), dec)}1e{k_str[1]}"
- return f"{result}" if as_str else np.float(result)
- result = np.round(np.float(k), dec)
+ result = f"{np.round(float(k_str[0]), dec)}1e{k_str[1]}"
+ return f"{result}" if as_str else float(result)
+ result = np.round(float(k), dec)
return f"{result}" if as_str else result
@@ -31,72 +35,109 @@ def make_dense(X):
def sum_obs(A):
"""summation over axis 0 (obs) equivalent to np.sum(A, 0)"""
- if issparse(A):
- return A.sum(0).A1
- else:
- return np.einsum("ij -> j", A) if A.ndim > 1 else np.sum(A)
+
+ warnings.warn(
+ "`sum_obs` is deprecated since scVelo v0.2.4 and will be removed in a future "
+ "version. Please use `sum(A, axis=0)` from `scvelo/core/` instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+
+ return sum(A, axis=0)
def sum_var(A):
"""summation over axis 1 (var) equivalent to np.sum(A, 1)"""
- if issparse(A):
- return A.sum(1).A1
- else:
- return np.sum(A, axis=1) if A.ndim > 1 else np.sum(A)
+
+ warnings.warn(
+ "`sum_var` is deprecated since scVelo v0.2.4 and will be removed in a future "
+ "version. Please use `sum(A, axis=1)` from `scvelo/core/` instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+
+ return sum(A, axis=1)
def prod_sum_obs(A, B):
"""dot product and sum over axis 0 (obs) equivalent to np.sum(A * B, 0)"""
- if issparse(A):
- return A.multiply(B).sum(0).A1
- else:
- return np.einsum("ij, ij -> j", A, B) if A.ndim > 1 else (A * B).sum()
+
+ warnings.warn(
+ "`prod_sum_obs` is deprecated since scVelo v0.2.4 and will be removed in a "
+ "future version. Please use `prod_sum(A, B, axis=0)` from `scvelo/core/` "
+ "instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+
+ return prod_sum(A, B, axis=0)
def prod_sum_var(A, B):
"""dot product and sum over axis 1 (var) equivalent to np.sum(A * B, 1)"""
- if issparse(A):
- return A.multiply(B).sum(1).A1
- else:
- return np.einsum("ij, ij -> i", A, B) if A.ndim > 1 else (A * B).sum()
+
+ warnings.warn(
+ "`prod_sum_var` is deprecated since scVelo v0.2.4 and will be removed in a "
+ "future version. Please use `prod_sum(A, B, axis=1)` from `scvelo/core/` "
+ "instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+
+ return prod_sum(A, B, axis=1)
def norm(A):
"""computes the L2-norm along axis 1
(e.g. genes or embedding dimensions) equivalent to np.linalg.norm(A, axis=1)
"""
- if issparse(A):
- return np.sqrt(A.multiply(A).sum(1).A1)
- else:
- return np.sqrt(np.einsum("ij, ij -> i", A, A) if A.ndim > 1 else np.sum(A * A))
+
+ warnings.warn(
+ "`norm` is deprecated since scVelo v0.2.4 and will be removed in a future "
+ "version. Please use `l2_norm(A, axis=1)` from `scvelo/core/` instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+
+ return l2_norm(A, axis=1)
def vector_norm(x):
"""computes the L2-norm along axis 1
(e.g. genes or embedding dimensions) equivalent to np.linalg.norm(A, axis=1)
"""
- return np.sqrt(np.einsum("i, i -> ", x, x))
+
+ warnings.warn(
+ "`vector_norm` is deprecated since scVelo v0.2.4 and will be removed in a "
+ "future version. Please use `l2_norm(x)` from `scvelo/core/` instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+
+ return l2_norm(x)
def R_squared(residual, total):
with warnings.catch_warnings():
warnings.simplefilter("ignore")
- r2 = np.ones(residual.shape[1]) - prod_sum_obs(
- residual, residual
- ) / prod_sum_obs(total, total)
+ r2 = np.ones(residual.shape[1]) - prod_sum(
+ residual, residual, axis=0
+ ) / prod_sum(total, total, axis=0)
r2[np.isnan(r2)] = 0
return r2
def cosine_correlation(dX, Vi):
dx = dX - dX.mean(-1)[:, None]
- Vi_norm = vector_norm(Vi)
+ Vi_norm = l2_norm(Vi, axis=0)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
if Vi_norm == 0:
result = np.zeros(dx.shape[0])
else:
- result = np.einsum("ij, j", dx, Vi) / (norm(dx) * Vi_norm)[None, :]
+ result = (
+ np.einsum("ij, j", dx, Vi) / (l2_norm(dx, axis=1) * Vi_norm)[None, :]
+ )
return result
@@ -119,12 +160,12 @@ def scale(X, min=0, max=1):
def get_indices(dist, n_neighbors=None, mode_neighbors="distances"):
- from ..preprocessing.neighbors import compute_connectivities_umap
+ from scvelo.preprocessing.neighbors import compute_connectivities_umap
D = dist.copy()
D.data += 1e-6
- n_counts = sum_var(D > 0)
+ n_counts = sum(D > 0, axis=1)
n_neighbors = (
n_counts.min() if n_neighbors is None else min(n_counts.min(), n_neighbors)
)
@@ -221,8 +262,8 @@ def randomized_velocity(adata, vkey="velocity", add_key="velocity_random"):
)
adata.layers[add_key] = V_rnd
- from .velocity_graph import velocity_graph
from .velocity_embedding import velocity_embedding
+ from .velocity_graph import velocity_graph
velocity_graph(adata, vkey=add_key)
velocity_embedding(adata, vkey=add_key, autoscale=False)
@@ -248,8 +289,8 @@ def str_to_int(item):
def strings_to_categoricals(adata):
"""Transform string annotations to categoricals."""
- from pandas.api.types import is_string_dtype, is_integer_dtype, is_bool_dtype
from pandas import Categorical
+ from pandas.api.types import is_bool_dtype, is_integer_dtype, is_string_dtype
def is_valid_dtype(values):
return (
@@ -341,15 +382,15 @@ def cutoff_small_velocities(
adata.layers[key_added] = csr_matrix(W).multiply(adata.layers[vkey]).tocsr()
- from .velocity_graph import velocity_graph
from .velocity_embedding import velocity_embedding
+ from .velocity_graph import velocity_graph
velocity_graph(adata, vkey=key_added, approx=True)
velocity_embedding(adata, vkey=key_added)
def make_unique_list(key, allow_array=False):
- from pandas import unique, Index
+ from pandas import Index, unique
if isinstance(key, Index):
key = key.tolist()
@@ -375,7 +416,8 @@ def test_bimodality(x, bins=30, kde=True, plot=False):
)
idx = int(bins / 2) - 2
- idx += np.argmin(kde_grid[idx : idx + 4])
+ end = idx + 4
+ idx += np.argmin(kde_grid[idx:end])
peak_0 = kde_grid[:idx].argmax()
peak_1 = kde_grid[idx:].argmax()
@@ -477,7 +519,7 @@ def indices_to_bool(indices, n):
def convolve(adata, x):
- from ..preprocessing.neighbors import get_connectivities
+ from scvelo.preprocessing.neighbors import get_connectivities
conn = get_connectivities(adata)
if isinstance(x, str) and x in adata.layers.keys():
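
The utils changes above turn the old axis-specific helpers into deprecation shims around `scvelo.core`. The full mapping, as used throughout this patch:

```python
import numpy as np

from scvelo.core import l2_norm, prod_sum
from scvelo.core import sum as core_sum  # aliased here to avoid shadowing the builtin

A = np.arange(6, dtype=float).reshape(2, 3)
B = np.ones_like(A)

# sum_obs(A)         -> sum(A, axis=0)
# sum_var(A)         -> sum(A, axis=1)
# prod_sum_obs(A, B) -> prod_sum(A, B, axis=0)
# prod_sum_var(A, B) -> prod_sum(A, B, axis=1)
# norm(A)            -> l2_norm(A, axis=1)
# vector_norm(x)     -> l2_norm(x)
print(core_sum(A, axis=0))     # column sums, like np.sum(A, 0)
print(prod_sum(A, B, axis=1))  # row-wise sum of A * B
print(l2_norm(A, axis=1))      # row-wise Euclidean norms
```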
diff --git a/scvelo/tools/velocity.py b/scvelo/tools/velocity.py
index 3e26826e..eadc53fb 100644
--- a/scvelo/tools/velocity.py
+++ b/scvelo/tools/velocity.py
@@ -1,11 +1,13 @@
-from .. import settings
-from .. import logging as logg
-from ..preprocessing.moments import moments, second_order_moments, get_connectivities
-from .optimization import leastsq_NxN, leastsq_generalized, maximum_likelihood
-from .utils import R_squared, groups_to_bool, make_dense, strings_to_categoricals
+import warnings
import numpy as np
-import warnings
+
+from scvelo import logging as logg
+from scvelo import settings
+from scvelo.core import LinearRegression
+from scvelo.preprocessing.moments import moments, second_order_moments
+from .optimization import leastsq_generalized, maximum_likelihood
+from .utils import groups_to_bool, make_dense, R_squared, strings_to_categoricals
warnings.simplefilter(action="ignore", category=FutureWarning)
@@ -57,7 +59,11 @@ def compute_deterministic(self, fit_offset=False, perc=None):
subset = self._groups_for_fit
Ms = self._Ms if subset is None else self._Ms[subset]
Mu = self._Mu if subset is None else self._Mu[subset]
- self._offset, self._gamma = leastsq_NxN(Ms, Mu, fit_offset, perc)
+
+ lr = LinearRegression(fit_intercept=fit_offset, percentile=perc)
+ lr.fit(Ms, Mu)
+ self._offset = lr.intercept_
+ self._gamma = lr.coef_
if self._constrain_ratio is not None:
if np.size(self._constrain_ratio) < 2:
@@ -72,7 +78,11 @@ def compute_deterministic(self, fit_offset=False, perc=None):
# velocity genes
if self._r2_adjusted:
- _offset, _gamma = leastsq_NxN(Ms, Mu, fit_offset)
+ lr = LinearRegression(fit_intercept=fit_offset)
+ lr.fit(Ms, Mu)
+ _offset = lr.intercept_
+ _gamma = lr.coef_
+
_residual = self._Mu - _gamma * self._Ms
if fit_offset:
_residual -= _offset
@@ -121,7 +131,10 @@ def compute_stochastic(
var_ss = 2 * _Mss - _Ms
cov_us = 2 * _Mus + _Mu
- _offset2, _gamma2 = leastsq_NxN(var_ss, cov_us, fit_offset2)
+ lr = LinearRegression(fit_intercept=fit_offset2)
+ lr.fit(var_ss, cov_us)
+ _offset2 = lr.intercept_
+ _gamma2 = lr.coef_
# initialize covariance matrix
res_std = _residual.std(0)
@@ -286,14 +299,12 @@ def velocity(
Returns
-------
- Returns or updates `adata` with the attributes
velocity: `.layers`
velocity vectors for each individual cell
- variance_velocity: `.layers`
- velocity vectors for the cell variances
- velocity_offset, velocity_beta, velocity_gamma, velocity_r2: `.var`
+ velocity_genes, velocity_beta, velocity_gamma, velocity_r2: `.var`
parameters
- """
+ """ # noqa E501
+
adata = data.copy() if copy else data
if not use_raw and "Ms" not in adata.layers.keys():
moments(adata)
@@ -310,7 +321,7 @@ def velocity(
)
if mode in {"dynamical", "dynamical_residuals"}:
- from .dynamical_model_utils import get_reads, get_vars, get_divergence
+ from .dynamical_model_utils import get_divergence, get_reads, get_vars
gene_subset = ~np.isnan(adata.var["fit_alpha"].values)
vdata = adata[:, gene_subset]
@@ -492,6 +503,7 @@ def velocity_genes(
velocity_genes: `.var`
genes to be used for further velocity analysis (velocity graph and embedding)
"""
+
adata = data.copy() if copy else data
if f"{vkey}_genes" not in adata.var.keys():
velocity(adata, vkey)
@@ -512,8 +524,8 @@ def velocity_genes(
if np.sum(vgenes) < 2:
logg.warn(
- f"You seem to have very low signal in splicing dynamics.\n"
- f"Consider reducing the thresholds and be cautious with interpretations.\n"
+ "You seem to have very low signal in splicing dynamics.\n"
+ "Consider reducing the thresholds and be cautious with interpretations.\n"
)
adata.var[f"{vkey}_genes"] = vgenes
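
The `compute_deterministic` changes above fit the splicing rate ratio with `LinearRegression`. The underlying model: at steady state, `du/dt = alpha - beta*u = 0` and `ds/dt = beta*u - gamma*s = 0`, so `Mu ~ (gamma/beta) * Ms`; with `beta` normalized to 1, regressing `Mu` on `Ms` recovers `gamma`, and the velocity is the residual `Mu - gamma * Ms`. A sketch on simulated moments (the exact percentile semantics are assumed to match the old `perc` argument):

```python
import numpy as np

from scvelo.core import LinearRegression

gamma_true = 0.3
Ms = np.random.uniform(0, 10, (500, 1))                     # spliced moments
Mu = gamma_true * Ms + np.random.normal(0, 0.05, Ms.shape)  # unspliced moments

lr = LinearRegression(fit_intercept=False, percentile=95)
lr.fit(Ms, Mu)
velocity = Mu - lr.coef_ * Ms  # deterministic velocity residual
print(lr.coef_)                # ~0.3
```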
diff --git a/scvelo/tools/velocity_confidence.py b/scvelo/tools/velocity_confidence.py
index 0e60d54c..85db17a6 100644
--- a/scvelo/tools/velocity_confidence.py
+++ b/scvelo/tools/velocity_confidence.py
@@ -1,10 +1,11 @@
-from .. import logging as logg
-from ..preprocessing.neighbors import get_neighs
-from .utils import prod_sum_var, norm, get_indices, random_subsample
-from .transition_matrix import transition_matrix
-
import numpy as np
+from scvelo import logging as logg
+from scvelo.core import l2_norm, prod_sum
+from scvelo.preprocessing.neighbors import get_neighs
+from .transition_matrix import transition_matrix
+from .utils import get_indices, random_subsample
+
def velocity_confidence(data, vkey="velocity", copy=False):
"""Computes confidences of velocities.
@@ -29,12 +30,12 @@ def velocity_confidence(data, vkey="velocity", copy=False):
Returns
-------
- Returns or updates `adata` with the attributes
velocity_length: `.obs`
Length of the velocity vectors for each individual cell
velocity_confidence: `.obs`
Confidence for each cell
- """
+ """ # noqa E501
+
adata = data.copy() if copy else data
if vkey not in adata.layers.keys():
raise ValueError("You need to run `tl.velocity` first.")
@@ -50,7 +51,7 @@ def velocity_confidence(data, vkey="velocity", copy=False):
V = V[:, tmp_filter]
V -= V.mean(1)[:, None]
- V_norm = norm(V)
+ V_norm = l2_norm(V, axis=1)
R = np.zeros(adata.n_obs)
indices = get_indices(dist=get_neighs(adata, "distances"))[0]
@@ -58,7 +59,8 @@ def velocity_confidence(data, vkey="velocity", copy=False):
Vi_neighs = V[indices[i]]
Vi_neighs -= Vi_neighs.mean(1)[:, None]
R[i] = np.mean(
- np.einsum("ij, j", Vi_neighs, V[i]) / (norm(Vi_neighs) * V_norm[i])[None, :]
+ np.einsum("ij, j", Vi_neighs, V[i])
+ / (l2_norm(Vi_neighs, axis=1) * V_norm[i])[None, :]
)
adata.obs[f"{vkey}_length"] = V_norm.round(2)
@@ -89,10 +91,10 @@ def velocity_confidence_transition(data, vkey="velocity", scale=10, copy=False):
Returns
-------
- Returns or updates `adata` with the attributes
velocity_confidence_transition: `.obs`
Confidence of transition for each cell
"""
+
adata = data.copy() if copy else data
if vkey not in adata.layers.keys():
raise ValueError("You need to run `tl.velocity` first.")
@@ -114,10 +116,10 @@ def velocity_confidence_transition(data, vkey="velocity", scale=10, copy=False):
dX -= dX.mean(1)[:, None]
V -= V.mean(1)[:, None]
- norms = norm(dX) * norm(V)
+ norms = l2_norm(dX, axis=1) * l2_norm(V, axis=1)
norms += norms == 0
- adata.obs[f"{vkey}_confidence_transition"] = prod_sum_var(dX, V) / norms
+ adata.obs[f"{vkey}_confidence_transition"] = prod_sum(dX, V, axis=1) / norms
logg.hint(f"added '{vkey}_confidence_transition' (adata.obs)")
@@ -130,8 +132,8 @@ def score_robustness(
adata = data.copy() if copy else data
if adata_subset is None:
- from ..preprocessing.moments import moments
- from ..preprocessing.neighbors import neighbors
+ from scvelo.preprocessing.moments import moments
+ from scvelo.preprocessing.neighbors import neighbors
from .velocity import velocity
logg.switch_verbosity("off")
@@ -147,8 +149,10 @@ def score_robustness(
V = adata[subset].layers[vkey]
V_subset = adata_subset.layers[vkey]
- score = np.nan * (subset == False)
- score[subset] = prod_sum_var(V, V_subset) / (norm(V) * norm(V_subset))
+ score = np.full(subset.shape, np.nan)
+ score[subset] = prod_sum(V, V_subset, axis=1) / (
+ l2_norm(V, axis=1) * l2_norm(V_subset, axis=1)
+ )
adata.obs[f"{vkey}_score_robustness"] = score
return adata_subset if copy else None
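
`velocity_confidence` above averages, for each cell, the cosine correlation between its mean-centered velocity vector and those of its neighbors; the `l2_norm` swap does not change that math. A self-contained sketch of the per-cell quantity (`confidence_for_cell` is illustrative, not scVelo API):

```python
import numpy as np

def confidence_for_cell(v_cell, V_neighs):
    v = v_cell - v_cell.mean()
    N = V_neighs - V_neighs.mean(1)[:, None]
    cos = N @ v / (np.linalg.norm(N, axis=1) * np.linalg.norm(v))
    return cos.mean()

rng = np.random.default_rng(0)
v = rng.normal(size=10)
aligned = v + rng.normal(scale=0.1, size=(5, 10))
random = rng.normal(size=(5, 10))
print(confidence_for_cell(v, aligned))  # close to 1
print(confidence_for_cell(v, random))   # near 0 on average
```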
diff --git a/scvelo/tools/velocity_embedding.py b/scvelo/tools/velocity_embedding.py
index a5521fc4..d62c46ef 100644
--- a/scvelo/tools/velocity_embedding.py
+++ b/scvelo/tools/velocity_embedding.py
@@ -1,11 +1,12 @@
-from .. import settings
-from .. import logging as logg
-from .utils import norm
-from .transition_matrix import transition_matrix
+import warnings
-from scipy.sparse import issparse
import numpy as np
-import warnings
+from scipy.sparse import issparse
+
+from scvelo import logging as logg
+from scvelo import settings
+from scvelo.core import l2_norm
+from .transition_matrix import transition_matrix
def quiver_autoscale(X_emb, V_emb):
@@ -45,13 +46,15 @@ def velocity_embedding(
"""Projects the single cell velocities into any embedding.
Given normalized difference of the embedding positions
- :math:`\\tilde \\delta_{ij} = \\frac{x_j-x_i}{\\left\\lVert x_j-x_i \\right\\rVert}`.
+ :math:`\\tilde \\delta_{ij} = \\frac{x_j-x_i}{\\left\\lVert x_j-x_i
+ \\right\\rVert}`,
the projections are obtained as expected displacements with respect to the
transition matrix :math:`\\tilde \\pi_{ij}` as
.. math::
\\tilde \\nu_i = E_{\\tilde \\pi_{i\\cdot}} [\\tilde \\delta_{i \\cdot}]
- = \\sum_{j \\neq i} \left( \\tilde \\pi_{ij} - \\frac1n \\right) \\tilde \\delta_{ij}.
+ = \\sum_{j \\neq i} \\left( \\tilde \\pi_{ij} - \\frac1n \\right)
+ \\tilde \\delta_{ij}.
Arguments
@@ -87,10 +90,10 @@ def velocity_embedding(
Returns
-------
- Returns or updates `adata` with the attributes
- velocity_basis: `.obsm`
- coordinates of velocity projection on embedding
+ velocity_umap: `.obsm`
+ coordinates of velocity projection on embedding (e.g., basis='umap')
"""
+
adata = data.copy() if copy else data
if basis is None:
@@ -107,9 +110,12 @@ def velocity_embedding(
if direct_pca_projection and "pca" in basis:
logg.warn(
- "Directly projecting velocities into PCA space is for exploratory analysis on principal components.\n"
- " It does not reflect the actual velocity field from high dimensional gene expression space.\n"
- " To visualize velocities, consider applying `direct_pca_projection=False`.\n"
+ "Directly projecting velocities into PCA space is for exploratory analysis "
+ "on principal components.\n"
+ " It does not reflect the actual velocity field from high "
+ "dimensional gene expression space.\n"
+ " To visualize velocities, consider applying "
+ "`direct_pca_projection=False`.\n"
)
logg.info("computing velocity embedding", r=True)
@@ -157,7 +163,7 @@ def velocity_embedding(
indices = T[i].indices
dX = X_emb[indices] - X_emb[i, None] # shape (n_neighbors, 2)
if not retain_scale:
- dX /= norm(dX)[:, None]
+ dX /= l2_norm(dX)[:, None]
dX[np.isnan(dX)] = 0 # zero diff in a steady-state
probs = TA[i, indices] if densify else T[i].data
V_emb[i] = probs.dot(dX) - probs.mean() * dX.sum(0)
@@ -171,7 +177,7 @@ def velocity_embedding(
delta = T.dot(X[:, vgenes]) - X[:, vgenes]
if issparse(delta):
delta = delta.A
- cos_proj = (V * delta).sum(1) / norm(delta)
+ cos_proj = (V * delta).sum(1) / l2_norm(delta)
V_emb *= np.clip(cos_proj[:, None] * 10, 0, 1)
if autoscale:
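
The docstring formula fixed above matches the per-cell projection loop exactly: since each `probs` row sums to one, `probs.mean()` equals `1/n`, so `probs.dot(dX) - probs.mean() * dX.sum(0)` is `sum_j (pi_ij - 1/n) * delta_ij`. A numeric check with made-up values:

```python
import numpy as np

rng = np.random.default_rng(1)
n = 8
X_emb_i = rng.normal(size=2)             # embedding position of cell i
X_emb_neighs = rng.normal(size=(n, 2))   # positions of its neighbors
probs = rng.random(n)
probs /= probs.sum()                     # transition probabilities pi_ij

dX = X_emb_neighs - X_emb_i
dX /= np.linalg.norm(dX, axis=1)[:, None]  # normalized delta_ij

v_loop = probs.dot(dX) - probs.mean() * dX.sum(0)
v_formula = ((probs - 1 / n)[:, None] * dX).sum(0)
print(np.allclose(v_loop, v_formula))  # True
```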
diff --git a/scvelo/tools/velocity_graph.py b/scvelo/tools/velocity_graph.py
index 73ce76f6..a1a5be4b 100644
--- a/scvelo/tools/velocity_graph.py
+++ b/scvelo/tools/velocity_graph.py
@@ -1,13 +1,21 @@
-from .. import settings
-from .. import logging as logg
-from ..preprocessing.neighbors import pca, neighbors, verify_neighbors
-from ..preprocessing.neighbors import get_neighs, get_n_neighs
-from ..preprocessing.moments import get_moments
-from .utils import cosine_correlation, get_indices, get_iterative_indices, norm
-from .velocity import velocity
+import os
-from scipy.sparse import coo_matrix, issparse
import numpy as np
+from scipy.sparse import coo_matrix, issparse
+
+from scvelo import logging as logg
+from scvelo import settings
+from scvelo.core import get_n_jobs, l2_norm, parallelize
+from scvelo.preprocessing.moments import get_moments
+from scvelo.preprocessing.neighbors import (
+ get_n_neighs,
+ get_neighs,
+ neighbors,
+ pca,
+ verify_neighbors,
+)
+from .utils import cosine_correlation, get_indices, get_iterative_indices
+from .velocity import velocity
def vals_to_csr(vals, rows, cols, shape, split_negative=False):
@@ -83,8 +91,10 @@ def __init__(
self.V_raw = np.array(self.V)
self.sqrt_transform = sqrt_transform
- if self.sqrt_transform is None and f"{vkey}_params" in adata.uns.keys():
- self.sqrt_transform = adata.uns[f"{vkey}_params"]["mode"] == "stochastic"
+ uns_key = f"{vkey}_params"
+ if self.sqrt_transform is None:
+ if uns_key in adata.uns.keys() and "mode" in adata.uns[uns_key]:
+ self.sqrt_transform = adata.uns[uns_key]["mode"] == "stochastic"
if self.sqrt_transform:
self.V = np.sqrt(np.abs(self.V)) * np.sign(self.V)
self.V -= np.nanmean(self.V, axis=1)[:, None]
@@ -125,7 +135,7 @@ def __init__(
mode="connectivity"
).indices.reshape((-1, n_neighbors + 1))
else:
- from .. import Neighbors
+ from scvelo import Neighbors
neighs = Neighbors(adata)
neighs.compute_neighbors(
@@ -156,21 +166,51 @@ def __init__(
self.report = report
self.adata = adata
- def compute_cosines(self):
- vals, rows, cols, uncertainties, n_obs = [], [], [], [], self.X.shape[0]
- progress = logg.ProgressReporter(n_obs)
+ def compute_cosines(self, n_jobs=None, backend="loky"):
+ n_jobs = get_n_jobs(n_jobs=n_jobs)
+
+ n_obs = self.X.shape[0]
+
+ # TODO: Use batches and vectorize calculation of dX in self._calculate_cosines
+ res = parallelize(
+ self._compute_cosines,
+ range(self.X.shape[0]),
+ n_jobs=n_jobs,
+ unit="cells",
+ backend=backend,
+ )()
+ uncertainties, vals, rows, cols = map(_flatten, zip(*res))
+
+ vals = np.hstack(vals)
+ vals[np.isnan(vals)] = 0
+
+ self.graph, self.graph_neg = vals_to_csr(
+ vals, rows, cols, shape=(n_obs, n_obs), split_negative=True
+ )
+ if self.compute_uncertainties:
+ uncertainties = np.hstack(uncertainties)
+ uncertainties[np.isnan(uncertainties)] = 0
+ self.uncertainties = vals_to_csr(
+ uncertainties, rows, cols, shape=(n_obs, n_obs), split_negative=False
+ )
+ self.uncertainties.eliminate_zeros()
+
+ confidence = self.graph.max(1).A.flatten()
+ self.self_prob = np.clip(np.percentile(confidence, 98) - confidence, 0, 1)
+
+ def _compute_cosines(self, obs_idx, queue):
+ vals, rows, cols, uncertainties = [], [], [], []
if self.compute_uncertainties:
- m = get_moments(self.adata, np.sign(self.V_raw), second_order=True)
+ moments = get_moments(self.adata, np.sign(self.V_raw), second_order=True)
- for i in range(n_obs):
- if self.V[i].max() != 0 or self.V[i].min() != 0:
+ for obs_id in obs_idx:
+ if self.V[obs_id].max() != 0 or self.V[obs_id].min() != 0:
neighs_idx = get_iterative_indices(
- self.indices, i, self.n_recurse_neighbors, self.max_neighs
+ self.indices, obs_id, self.n_recurse_neighbors, self.max_neighs
)
if self.t0 is not None:
- t0, t1 = self.t0[i], self.t1[i]
+ t0, t1 = self.t0[obs_id], self.t1[obs_id]
if t0 >= 0 and t1 > 0:
t1_idx = np.where(self.t0 == t1)[0]
if len(t1_idx) > len(neighs_idx):
@@ -180,39 +220,32 @@ def compute_cosines(self):
if len(t1_idx) > 0:
neighs_idx = np.unique(np.concatenate([neighs_idx, t1_idx]))
- dX = self.X[neighs_idx] - self.X[i, None] # 60% of runtime
+ dX = self.X[neighs_idx] - self.X[obs_id, None] # 60% of runtime
if self.sqrt_transform:
dX = np.sqrt(np.abs(dX)) * np.sign(dX)
- val = cosine_correlation(dX, self.V[i]) # 40% of runtime
+ val = cosine_correlation(dX, self.V[obs_id]) # 40% of runtime
if self.compute_uncertainties:
- dX /= norm(dX)[:, None]
- uncertainties.extend(np.nansum(dX ** 2 * m[i][None, :], 1))
+ dX /= l2_norm(dX)[:, None]
+ uncertainties.extend(
+ np.nansum(dX ** 2 * moments[obs_id][None, :], 1)
+ )
vals.extend(val)
- rows.extend(np.ones(len(neighs_idx)) * i)
+ rows.extend(np.ones(len(neighs_idx)) * obs_id)
cols.extend(neighs_idx)
- if self.report:
- progress.update()
- if self.report:
- progress.finish()
- vals = np.hstack(vals)
- vals[np.isnan(vals)] = 0
+ if queue is not None:
+ queue.put(1)
- self.graph, self.graph_neg = vals_to_csr(
- vals, rows, cols, shape=(n_obs, n_obs), split_negative=True
- )
- if self.compute_uncertainties:
- uncertainties = np.hstack(uncertainties)
- uncertainties[np.isnan(uncertainties)] = 0
- self.uncertainties = vals_to_csr(
- uncertainties, rows, cols, shape=(n_obs, n_obs), split_negative=False
- )
- self.uncertainties.eliminate_zeros()
+ if queue is not None:
+ queue.put(None)
- confidence = self.graph.max(1).A.flatten()
- self.self_prob = np.clip(np.percentile(confidence, 98) - confidence, 0, 1)
+ return uncertainties, vals, rows, cols
+
+
+def _flatten(iterable):
+ return [i for it in iterable for i in it]
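To make the merge step explicit: each worker returns `(uncertainties, vals, rows, cols)`, `zip(*res)` groups these across workers, `_flatten` concatenates them, and `vals_to_csr` splits positive and negative correlations into two CSR graphs. A toy reproduction with hypothetical worker outputs, where clipping stands in for `split_negative=True`:

```python
import numpy as np
from scipy.sparse import csr_matrix


def _flatten(iterable):
    return [i for it in iterable for i in it]


# two hypothetical worker results, each (uncertainties, vals, rows, cols)
res = [
    ([], [0.9, -0.2], [0, 0], [1, 2]),
    ([], [0.5], [1], [2]),
]
uncertainties, vals, rows, cols = map(_flatten, zip(*res))

vals = np.hstack(vals)
vals[np.isnan(vals)] = 0

graph = csr_matrix((np.clip(vals, 0, None), (rows, cols)), shape=(3, 3))
graph_neg = csr_matrix((np.clip(vals, None, 0), (rows, cols)), shape=(3, 3))
graph.eliminate_zeros()  # drop explicit zeros left by clipping
graph_neg.eliminate_zeros()
```

Each worker also reports progress by putting `1` on the queue per processed cell and `None` on completion, which is presumably what `parallelize` consumes to drive its progress bar.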
@@ -231,6 +264,8 @@ def velocity_graph(
approx=None,
mode_neighbors="distances",
copy=False,
+ n_jobs=None,
+ backend="loky",
):
"""Computes velocity graph based on cosine similarities.
@@ -241,7 +276,8 @@ def velocity_graph(
.. math::
\\pi_{ij} = \\cos\\angle(\\delta_{ij}, \\nu_i)
- = \\frac{\\delta_{ij}^T \\nu_i}{\\left\\lVert\\delta_{ij}\\right\\rVert \\left\\lVert \\nu_i \\right\\rVert}.
+ = \\frac{\\delta_{ij}^T \\nu_i}{\\left\\lVert\\delta_{ij}\\right\\rVert
+ \\left\\lVert \\nu_i \\right\\rVert}.
Arguments
---------
@@ -277,13 +313,18 @@ def velocity_graph(
'connectivities'. The latter yields a symmetric graph.
copy: `bool` (default: `False`)
Return a copy instead of writing to adata.
+ n_jobs: `int` or `None` (default: `None`)
+ Number of parallel jobs.
+ backend: `str` (default: "loky")
+ Backend used for multiprocessing. See :class:`joblib.Parallel` for valid
+ options.
Returns
-------
- Returns or updates `adata` with the attributes
velocity_graph: `.uns`
- sparse matrix with transition probabilities
+ sparse matrix with correlations of cell state transitions with velocities
"""
+
adata = data.copy() if copy else data
verify_neighbors(adata)
if vkey not in adata.layers.keys():
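With the new `n_jobs`/`backend` arguments, a typical call looks like the following (the dataset and core count are illustrative):

```python
import scvelo as scv

adata = scv.datasets.pancreas()
scv.pp.filter_and_normalize(adata)
scv.pp.moments(adata)
scv.tl.velocity(adata)

# distribute the cosine computation over 4 cores via joblib's loky backend
scv.tl.velocity_graph(adata, n_jobs=4, backend="loky")
```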
@@ -315,8 +356,11 @@ def velocity_graph(
f" on full expression space by not specifying basis.\n"
)
- logg.info("computing velocity graph", r=True)
- vgraph.compute_cosines()
+ n_jobs = get_n_jobs(n_jobs=n_jobs)
+ logg.info(
+ f"computing velocity graph (using {n_jobs}/{os.cpu_count()} cores)", r=True
+ )
+ vgraph.compute_cosines(n_jobs=n_jobs, backend=backend)
adata.uns[f"{vkey}_graph"] = vgraph.graph
adata.uns[f"{vkey}_graph_neg"] = vgraph.graph_neg
diff --git a/scvelo/tools/velocity_pseudotime.py b/scvelo/tools/velocity_pseudotime.py
index 9b10d8f8..f92e9b05 100644
--- a/scvelo/tools/velocity_pseudotime.py
+++ b/scvelo/tools/velocity_pseudotime.py
@@ -1,12 +1,13 @@
-from .. import logging as logg
-from .terminal_states import terminal_states
-from .utils import groups_to_bool, scale, strings_to_categoricals
-from ..preprocessing.moments import get_connectivities
-
import numpy as np
-from scipy.sparse import issparse, spdiags, linalg
+from scipy.sparse import issparse, linalg, spdiags
+
from scanpy.tools._dpt import DPT
+from scvelo import logging as logg
+from scvelo.preprocessing.moments import get_connectivities
+from .terminal_states import terminal_states
+from .utils import groups_to_bool, scale, strings_to_categoricals
+
def principal_curve(data, basis="pca", n_comps=4, clusters_list=None, copy=False):
"""Computes the principal curve
@@ -20,12 +21,13 @@ def principal_curve(data, basis="pca", n_comps=4, clusters_list=None, copy=False
Number of principal components to be used.
copy: `bool`, (default: `False`)
Return a copy instead of writing to adata.
+
Returns
-------
- Returns or updates `adata` with the attributes
principal_curve: `.uns`
dictionary containing `projections`, `ixsort` and `arclength`
"""
+
adata = data.copy() if copy else data
import rpy2.robjects as robjects
from rpy2.robjects.packages import importr
@@ -182,12 +184,12 @@ def velocity_pseudotime(
**kwargs:
Further arguments to pass to VPT (e.g. min_group_size, allow_kendall_tau_shift).
- Returns
+ Returns
-------
- Updates `adata` with the attributes
velocity_pseudotime: `.obs`
Velocity pseudotime obtained from velocity graph.
- """
+ """ # noqa E501
+
strings_to_categoricals(adata)
if root_key is None and "root_cells" in adata.obs.keys():
root0 = adata.obs["root_cells"][0]
diff --git a/scvelo/utils.py b/scvelo/utils.py
index c9d5eb79..c265ff31 100644
--- a/scvelo/utils.py
+++ b/scvelo/utils.py
@@ -1,24 +1,62 @@
-from .preprocessing.utils import show_proportions, cleanup
-from .preprocessing.utils import set_initial_size, get_initial_size
+from scvelo.core import (
+ clean_obs_names,
+ cleanup,
+ get_initial_size,
+ merge,
+ set_initial_size,
+ show_proportions,
+)
+from scvelo.plotting.simulation import compute_dynamics
+from scvelo.plotting.utils import (
+ clip,
+ interpret_colorkey,
+ is_categorical,
+ rgb_custom_colormap,
+)
+from scvelo.plotting.velocity_embedding_grid import compute_velocity_on_grid
+from scvelo.preprocessing.moments import get_moments
+from scvelo.preprocessing.neighbors import get_connectivities
+from scvelo.read_load import (
+ convert_to_ensembl,
+ convert_to_gene_names,
+ gene_info,
+ load_biomart,
+)
+from scvelo.tools.optimization import get_weight, leastsq
+from scvelo.tools.rank_velocity_genes import get_mean_var
+from scvelo.tools.run import convert_to_adata, convert_to_loom
+from scvelo.tools.score_genes_cell_cycle import get_phase_marker_genes
+from scvelo.tools.transition_matrix import get_cell_transitions
+from scvelo.tools.transition_matrix import transition_matrix as get_transition_matrix
+from scvelo.tools.utils import * # noqa
+from scvelo.tools.velocity_graph import vals_to_csr
-from .preprocessing.neighbors import get_connectivities
-from .preprocessing.moments import get_moments
-
-from .tools.utils import *
-from .tools.rank_velocity_genes import get_mean_var
-from .tools.run import convert_to_adata, convert_to_loom
-from .tools.optimization import leastsq, get_weight
-from .tools.velocity_graph import vals_to_csr
-from .tools.score_genes_cell_cycle import get_phase_marker_genes
-
-from .tools.transition_matrix import transition_matrix as get_transition_matrix
-from .tools.transition_matrix import get_cell_transitions
-
-from .plotting.utils import is_categorical, clip
-from .plotting.utils import interpret_colorkey, rgb_custom_colormap
-
-from .plotting.velocity_embedding_grid import compute_velocity_on_grid
-from .plotting.simulation import compute_dynamics
-
-from .read_load import clean_obs_names, merge, gene_info
-from .read_load import convert_to_gene_names, convert_to_ensembl, load_biomart
+__all__ = [
+ "cleanup",
+ "clean_obs_names",
+ "clip",
+ "compute_dynamics",
+ "compute_velocity_on_grid",
+ "convert_to_adata",
+ "convert_to_ensembl",
+ "convert_to_gene_names",
+ "convert_to_loom",
+ "gene_info",
+ "get_cell_transitions",
+ "get_connectivities",
+ "get_initial_size",
+ "get_mean_var",
+ "get_moments",
+ "get_phase_marker_genes",
+ "get_transition_matrix",
+ "get_weight",
+ "interpret_colorkey",
+ "is_categorical",
+ "leastsq",
+ "load_biomart",
+ "merge",
+ "rgb_custom_colormap",
+ "set_initial_size",
+ "show_proportions",
+ "vals_to_csr",
+]
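One practical effect of the explicit `__all__`: the star-import surface of `scvelo.utils` is now curated and checkable. A quick sanity check, assuming a developer installation:

```python
import scvelo.utils as scv_utils

# every curated name should resolve on the module
assert all(hasattr(scv_utils, name) for name in scv_utils.__all__)
assert {"get_transition_matrix", "vals_to_csr"} <= set(scv_utils.__all__)
```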
diff --git a/setup.py b/setup.py
index e41d5a02..d75d6280 100644
--- a/setup.py
+++ b/setup.py
@@ -1,19 +1,25 @@
-from setuptools import setup, find_packages
from pathlib import Path
+from setuptools import find_packages, setup
+
+
+def read_requirements(req_path):
+ """Read abstract requirements."""
+ requirements = Path(req_path).read_text("utf-8").splitlines()
+ return [r.strip() for r in requirements if not r.startswith("-")]
+
+
setup(
name="scvelo",
use_scm_version=True,
setup_requires=["setuptools_scm"],
python_requires=">=3.6",
- install_requires=[
- l.strip() for l in Path("requirements.txt").read_text("utf-8").splitlines()
- ],
+ install_requires=read_requirements("requirements.txt"),
extras_require=dict(
louvain=["python-igraph", "louvain"],
hnswlib=["pybind11", "hnswlib"],
- dev=["black==20.8b1", "pre-commit==2.5.1"],
- docs=[r for r in Path("docs/requirements.txt").read_text("utf-8").splitlines()],
+ dev=read_requirements("requirements-dev.txt"),
+ docs=read_requirements("docs/requirements.txt"),
),
packages=find_packages(),
author="Volker Bergen",
@@ -40,6 +46,7 @@
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
+ "Programming Language :: Python :: 3.9",
"Topic :: Scientific/Engineering :: Bio-Informatics",
"Topic :: Scientific/Engineering :: Visualization",
],
diff --git a/tests/test_basic.py b/tests/test_basic.py
index dd950d92..bf837f97 100644
--- a/tests/test_basic.py
+++ b/tests/test_basic.py
@@ -1,9 +1,10 @@
-import scvelo as scv
import numpy as np
+import scvelo as scv
+
def test_einsum():
- from scvelo.tools.utils import prod_sum_obs, prod_sum_var, norm
+ from scvelo.tools.utils import norm, prod_sum_obs, prod_sum_var
Ms, Mu = np.random.rand(5, 4), np.random.rand(5, 4)
assert np.allclose(prod_sum_obs(Ms, Mu), np.sum(Ms * Mu, 0))
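The assertions in `test_einsum` pin down the helper semantics; a minimal re-implementation consistent with the test (the real helpers live in `scvelo.tools.utils`, so these are sketches):

```python
import numpy as np


def prod_sum_obs_sketch(A, B):
    # elementwise product summed over observations (axis 0)
    return np.einsum("ij,ij->j", A, B)


def prod_sum_var_sketch(A, B):
    # elementwise product summed over variables (axis 1)
    return np.einsum("ij,ij->i", A, B)
```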
@@ -58,7 +59,7 @@ def test_pipeline():
Ms, Mu = adata.layers["Ms"][0], adata.layers["Mu"][0]
Vs, Vd = adata.layers["velocity"][0], adata.layers["dynamical_velocity"][0]
- Vpca, Vgraph = adata.obsm["velocity_pca"][0], adata.uns["velocity_graph"].data[:5]
+ Vgraph = adata.uns["velocity_graph"].data[:5]
pars = adata[:, 0].var[["fit_alpha", "fit_gamma"]].values
assert np.allclose(Ms, [0.8269, 1.0772, 0.9396, 1.0846, 1.0398], rtol=1e-2)