diff --git a/.github/actions/setup-deps/action.yaml b/.github/actions/setup-deps/action.yaml
index e5a0794450c..cceae40f99c 100644
--- a/.github/actions/setup-deps/action.yaml
+++ b/.github/actions/setup-deps/action.yaml
@@ -62,6 +62,8 @@ inputs:
default: 'chemfiles-python>=0.9'
clustalw:
default: 'clustalw=2.1'
+ dask:
+ default: 'dask'
distopia:
default: 'distopia>=0.2.0'
h5py:
@@ -134,6 +136,7 @@ runs:
${{ inputs.biopython }}
${{ inputs.chemfiles-python }}
${{ inputs.clustalw }}
+ ${{ inputs.dask }}
${{ inputs.distopia }}
${{ inputs.gsd }}
${{ inputs.h5py }}
diff --git a/package/CHANGELOG b/package/CHANGELOG
index 750ddcf1cba..056282c6214 100644
--- a/package/CHANGELOG
+++ b/package/CHANGELOG
@@ -52,6 +52,8 @@ Fixes
* Fix groups.py doctests using sphinx directives (Issue #3925, PR #4374)
Enhancements
+ * Introduce parallelization API to `AnalysisBase` and to `analysis.rms.RMSD` class
+ (Issue #4158, PR #4304)
* Improve error message for `AtomGroup.unwrap()` when bonds are not present.(Issue #4436, PR #4642)
* Add `analysis.DSSP` module for protein secondary structure assignment, based on [pydssp](https://github.com/ShintaroMinami/PyDSSP)
* Added a tqdm progress bar for `MDAnalysis.analysis.pca.PCA.transform()`
diff --git a/package/MDAnalysis/analysis/__init__.py b/package/MDAnalysis/analysis/__init__.py
index 5983bb8b61c..25450b759b9 100644
--- a/package/MDAnalysis/analysis/__init__.py
+++ b/package/MDAnalysis/analysis/__init__.py
@@ -23,6 +23,7 @@
__all__ = [
'align',
+ 'backends',
'base',
'contacts',
'density',
@@ -45,6 +46,7 @@
'pca',
'psa',
'rdf',
+ 'results',
'rms',
'waterdynamics',
]
diff --git a/package/MDAnalysis/analysis/backends.py b/package/MDAnalysis/analysis/backends.py
new file mode 100644
index 00000000000..38917cb2ae7
--- /dev/null
+++ b/package/MDAnalysis/analysis/backends.py
@@ -0,0 +1,333 @@
+"""Analysis backends --- :mod:`MDAnalysis.analysis.backends`
+============================================================
+
+.. versionadded:: 2.8.0
+
+
+The :mod:`backends` module provides :class:`BackendBase` base class to
+implement custom execution backends for
+:meth:`MDAnalysis.analysis.base.AnalysisBase.run` and its
+subclasses.
+
+.. SeeAlso:: :ref:`parallel-analysis`
+
+.. _backends:
+
+Backends
+--------
+
+Three built-in backend classes are provided:
+
+* *serial*: :class:`BackendSerial`, that is equivalent to using no
+ parallelization and is the default
+
+* *multiprocessing*: :class:`BackendMultiprocessing` that supports
+ parallelization via standard Python :mod:`multiprocessing` module
+ and uses default :mod:`pickle` serialization
+
+* *dask*: :class:`BackendDask`, that uses the same process-based
+ parallelization as :class:`BackendMultiprocessing`, but different
+  serialization algorithm via `dask <https://dask.org/>`_ (see `dask
+  serialization algorithms
+  <https://distributed.dask.org/en/latest/serialization.html>`_ for details)
+
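+For example, a parallelizable analysis class such as
+:class:`MDAnalysis.analysis.rms.RMSD` can be run with a built-in backend by
+passing the backend name (and, optionally, the number of workers) to its
+:meth:`~MDAnalysis.analysis.base.AnalysisBase.run` method (a minimal sketch;
+the test files and the ``n_workers`` value are illustrative):
+
+.. code-block:: python
+
+    import MDAnalysis as mda
+    from MDAnalysis.tests.datafiles import PSF, DCD
+    from MDAnalysis.analysis.rms import RMSD
+
+    u = mda.Universe(PSF, DCD)
+    ref = mda.Universe(PSF, DCD)
+    R = RMSD(u, ref)
+
+    # 'serial' (the default), 'multiprocessing' and 'dask' are the built-in backends
+    R.run(backend='multiprocessing', n_workers=2)
+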
+Classes
+-------
+
+"""
+import warnings
+from typing import Callable
+from MDAnalysis.lib.util import is_installed
+
+
+class BackendBase:
+ """Base class for backend implementation.
+
+    Initializes an instance and performs validity checks, such as
+    ``n_workers`` being a positive integer, and possibly others defined in subclasses.
+
+ Parameters
+ ----------
+ n_workers : int
+ number of workers (usually, processes) over which the work is split
+
+ Examples
+ --------
+ .. code-block:: python
+
+ from MDAnalysis.analysis.backends import BackendBase
+
+ class ThreadsBackend(BackendBase):
+ def apply(self, func, computations):
+ from multiprocessing.dummy import Pool
+
+ with Pool(processes=self.n_workers) as pool:
+ results = pool.map(func, computations)
+ return results
+
+ import MDAnalysis as mda
+ from MDAnalysis.tests.datafiles import PSF, DCD
+ from MDAnalysis.analysis.rms import RMSD
+
+ u = mda.Universe(PSF, DCD)
+ ref = mda.Universe(PSF, DCD)
+
+ R = RMSD(u, ref)
+
+ n_workers = 2
+ backend = ThreadsBackend(n_workers=n_workers)
+ R.run(backend=backend, unsupported_backend=True)
+
+ .. warning::
+ Using `ThreadsBackend` above will lead to erroneous results, since it
+ is an educational example. Do not use it for real analysis.
+
+
+ .. versionadded:: 2.8.0
+
+ """
+
+ def __init__(self, n_workers: int):
+ self.n_workers = n_workers
+ self._validate()
+
+ def _get_checks(self):
+ """Get dictionary with ``condition: error_message`` pairs that ensure the
+ validity of the backend instance
+
+ Returns
+ -------
+ dict
+ dictionary with ``condition: error_message`` pairs that will get
+ checked during ``_validate()`` run
+ """
+ return {
+ isinstance(self.n_workers, int) and self.n_workers > 0:
+ f"n_workers should be positive integer, got {self.n_workers=}",
+ }
+
+ def _get_warnings(self):
+ """Get dictionary with ``condition: warning_message`` pairs that ensure
+ the good usage of the backend instance
+
+ Returns
+ -------
+ dict
+ dictionary with ``condition: warning_message`` pairs that will get
+ checked during ``_validate()`` run
+ """
+ return dict()
+
+ def _validate(self):
+ """Check correctness (e.g. ``dask`` is installed if using ``backend='dask'``)
+ and good usage (e.g. ``n_workers=1`` if backend is serial) of the backend
+
+ Raises
+ ------
+ ValueError
+            if one of the conditions in :meth:`_get_checks` evaluates to ``False``
+ """
+ for check, msg in self._get_checks().items():
+ if not check:
+ raise ValueError(msg)
+ for check, msg in self._get_warnings().items():
+ if not check:
+ warnings.warn(msg)
+
+ def apply(self, func: Callable, computations: list) -> list:
+ """map function `func` to all tasks in the `computations` list
+
+ Main method that will get called when using an instance of
+ ``BackendBase``. It is equivalent to running ``[func(item) for item in
+ computations]`` while using the parallel backend capabilities.
+
+ Parameters
+ ----------
+ func : Callable
+ function to be called on each of the tasks in computations list
+ computations : list
+ computation tasks to apply function to
+
+ Returns
+ -------
+ list
+ list of results of the function
+
+ """
+ raise NotImplementedError
+
+
+class BackendSerial(BackendBase):
+ """A built-in backend that does serial execution of the function, without any
+ parallelization.
+
+ Parameters
+ ----------
+ n_workers : int
+ Is ignored in this class, and if ``n_workers`` > 1, a warning will be
+ given.
+
+
+ .. versionadded:: 2.8.0
+ """
+
+ def _get_warnings(self):
+ """Get dictionary with ``condition: warning_message`` pairs that ensure
+ the good usage of the backend instance. Here, it checks if the number
+ of workers is not 1, otherwise gives warning.
+
+ Returns
+ -------
+ dict
+ dictionary with ``condition: warning_message`` pairs that will get
+ checked during ``_validate()`` run
+ """
+ return {
+ self.n_workers == 1:
+ "n_workers is ignored when executing with backend='serial'"
+ }
+
+ def apply(self, func: Callable, computations: list) -> list:
+ """
+ Serially applies `func` to each task object in ``computations``.
+
+ Parameters
+ ----------
+ func : Callable
+ function to be called on each of the tasks in computations list
+ computations : list
+ computation tasks to apply function to
+
+ Returns
+ -------
+ list
+ list of results of the function
+ """
+ return [func(task) for task in computations]
+
+
+class BackendMultiprocessing(BackendBase):
+ """A built-in backend that executes a given function using the
+    :meth:`multiprocessing.Pool.map <multiprocessing.pool.Pool.map>` method.
+
+ Parameters
+ ----------
+ n_workers : int
+        number of processes in :class:`multiprocessing.Pool
+        <multiprocessing.pool.Pool>` to distribute the workload
+        between. Must be a positive integer.
+
+ Examples
+ --------
+
+ .. code-block:: python
+
+ from MDAnalysis.analysis.backends import BackendMultiprocessing
+ import multiprocessing as mp
+
+ backend_obj = BackendMultiprocessing(n_workers=mp.cpu_count())
+
+
+ .. versionadded:: 2.8.0
+
+ """
+
+ def apply(self, func: Callable, computations: list) -> list:
+ """Applies `func` to each object in ``computations`` using `multiprocessing`'s `Pool.map`.
+
+ Parameters
+ ----------
+ func : Callable
+ function to be called on each of the tasks in computations list
+ computations : list
+ computation tasks to apply function to
+
+ Returns
+ -------
+ list
+ list of results of the function
+ """
+ from multiprocessing import Pool
+
+ with Pool(processes=self.n_workers) as pool:
+ results = pool.map(func, computations)
+ return results
+
+
+class BackendDask(BackendBase):
+ """A built-in backend that executes a given function with *dask*.
+
+    Execution is performed with the :func:`dask.compute` function on a list of
+    :class:`dask.delayed.Delayed` objects (created with
+    :func:`dask.delayed.delayed`), with ``scheduler='processes'`` and
+    ``chunksize=1`` (this ensures uniform distribution of tasks among
+    processes). Requires the `dask package <https://dask.org/>`_
+    to be `installed <https://docs.dask.org/en/stable/install.html>`_.
+
+ Parameters
+ ----------
+ n_workers : int
+        number of processes to distribute the workload
+        between. Must be a positive integer. Workers are actually
+        :class:`multiprocessing.pool.Pool` processes, but they use a different and
+        more flexible `serialization protocol
+        <https://distributed.dask.org/en/latest/serialization.html>`_.
+
+ Examples
+ --------
+
+ .. code-block:: python
+
+ from MDAnalysis.analysis.backends import BackendDask
+ import multiprocessing as mp
+
+ backend_obj = BackendDask(n_workers=mp.cpu_count())
+
+
+ .. versionadded:: 2.8.0
+
+ """
+
+ def apply(self, func: Callable, computations: list) -> list:
+ """Applies `func` to each object in ``computations``.
+
+ Parameters
+ ----------
+ func : Callable
+ function to be called on each of the tasks in computations list
+ computations : list
+ computation tasks to apply function to
+
+ Returns
+ -------
+ list
+ list of results of the function
+ """
+ from dask.delayed import delayed
+ import dask
+
+ computations = [delayed(func)(task) for task in computations]
+ results = dask.compute(computations,
+ scheduler="processes",
+ chunksize=1,
+ num_workers=self.n_workers)[0]
+ return results
+
+ def _get_checks(self):
+ """Get dictionary with ``condition: error_message`` pairs that ensure the
+ validity of the backend instance. Here checks if ``dask`` module is
+ installed in the environment.
+
+ Returns
+ -------
+ dict
+ dictionary with ``condition: error_message`` pairs that will get
+ checked during ``_validate()`` run
+ """
+ base_checks = super()._get_checks()
+ checks = {
+ is_installed("dask"):
+ ("module 'dask' is missing. Please install 'dask': "
+ "https://docs.dask.org/en/stable/install.html")
+ }
+ return base_checks | checks
diff --git a/package/MDAnalysis/analysis/base.py b/package/MDAnalysis/analysis/base.py
index 04a0d25b506..47a7eccd137 100644
--- a/package/MDAnalysis/analysis/base.py
+++ b/package/MDAnalysis/analysis/base.py
@@ -103,16 +103,42 @@
contains a step-by-step example for writing analysis tools with
:class:`AnalysisBase`.
-
.. _`Writing your own trajectory analysis`:
https://userguide.mdanalysis.org/stable/examples/analysis/custom_trajectory_analysis.html
+If your analysis operates independently on each frame, you might consider
+making it **parallelizable** by adding a :meth:`get_supported_backends` method
+and an appropriate aggregation function for each of its results. For example, if
+your :meth:`_single_frame` method stores important values under
+:attr:`self.results.timeseries`, you would write:
+
+.. code-block:: python
+
+ class MyAnalysis(AnalysisBase):
+ _analysis_algorithm_is_parallelizable = True
+
+ @classmethod
+ def get_supported_backends(cls):
+ return ('serial', 'multiprocessing', 'dask',)
+
+
+ def _get_aggregator(self):
+ return ResultsGroup(lookup={'timeseries': ResultsGroup.ndarray_vstack})
+
+See :mod:`MDAnalysis.analysis.results` for more on aggregating results.
+
+.. SeeAlso::
+
+ :ref:`parallel-analysis`
+
+
+
Classes
-------
-The :class:`Results` and :class:`AnalysisBase` classes are the essential
-building blocks for almost all MDAnalysis tools in the
+The :class:`MDAnalysis.analysis.results.Results` and :class:`AnalysisBase` classes
+are the essential building blocks for almost all MDAnalysis tools in the
:mod:`MDAnalysis.analysis` module. They aim to be easily useable and
extendable.
@@ -121,103 +147,22 @@
tools if only the single-frame analysis function needs to be written.
"""
-from collections import UserDict
import inspect
-import logging
import itertools
+import logging
+import warnings
+from functools import partial
+from typing import Iterable, Union
import numpy as np
-from MDAnalysis import coordinates
-from MDAnalysis.core.groups import AtomGroup
-from MDAnalysis.lib.log import ProgressBar
-
-logger = logging.getLogger(__name__)
-
-
-class Results(UserDict):
- r"""Container object for storing results.
-
- :class:`Results` are dictionaries that provide two ways by which values
- can be accessed: by dictionary key ``results["value_key"]`` or by object
- attribute, ``results.value_key``. :class:`Results` stores all results
- obtained from an analysis after calling :meth:`~AnalysisBase.run()`.
-
- The implementation is similar to the :class:`sklearn.utils.Bunch`
- class in `scikit-learn`_.
-
- .. _`scikit-learn`: https://scikit-learn.org/
-
- Raises
- ------
- AttributeError
- If an assigned attribute has the same name as a default attribute.
-
- ValueError
- If a key is not of type ``str`` and therefore is not able to be
- accessed by attribute.
-
- Examples
- --------
- >>> from MDAnalysis.analysis.base import Results
- >>> results = Results(a=1, b=2)
- >>> results['b']
- 2
- >>> results.b
- 2
- >>> results.a = 3
- >>> results['a']
- 3
- >>> results.c = [1, 2, 3, 4]
- >>> results['c']
- [1, 2, 3, 4]
-
-
- .. versionadded:: 2.0.0
- """
-
- def _validate_key(self, key):
- if key in dir(self):
- raise AttributeError(f"'{key}' is a protected dictionary "
- "attribute")
- elif isinstance(key, str) and not key.isidentifier():
- raise ValueError(f"'{key}' is not a valid attribute")
-
- def __init__(self, *args, **kwargs):
- kwargs = dict(*args, **kwargs)
- if "data" in kwargs.keys():
- raise AttributeError(f"'data' is a protected dictionary attribute")
- self.__dict__["data"] = {}
- self.update(kwargs)
-
- def __setitem__(self, key, item):
- self._validate_key(key)
- super().__setitem__(key, item)
-
- def __setattr__(self, attr, val):
- if attr == 'data':
- super().__setattr__(attr, val)
- else:
- self.__setitem__(attr, val)
-
- def __getattr__(self, attr):
- try:
- return self[attr]
- except KeyError as err:
- raise AttributeError("'Results' object has no "
- f"attribute '{attr}'") from err
-
- def __delattr__(self, attr):
- try:
- del self[attr]
- except KeyError as err:
- raise AttributeError("'Results' object has no "
- f"attribute '{attr}'") from err
+from .. import coordinates
+from ..core.groups import AtomGroup
+from ..lib.log import ProgressBar
- def __getstate__(self):
- return self.data
+from .backends import BackendDask, BackendMultiprocessing, BackendSerial, BackendBase
+from .results import Results, ResultsGroup
- def __setstate__(self, state):
- self.data = state
+logger = logging.getLogger(__name__)
class AnalysisBase(object):
@@ -231,8 +176,8 @@ class AnalysisBase(object):
To define a new Analysis, :class:`AnalysisBase` needs to be subclassed
and :meth:`_single_frame` must be defined. It is also possible to define
:meth:`_prepare` and :meth:`_conclude` for pre- and post-processing.
- All results should be stored as attributes of the :class:`Results`
- container.
+ All results should be stored as attributes of the
+ :class:`MDAnalysis.analysis.results.Results` container.
Parameters
----------
@@ -310,15 +255,140 @@ def _conclude(self):
.. versionchanged:: 2.0.0
Added :attr:`results`
+ .. versionchanged:: 2.8.0
+ Added ability to run analysis in parallel using either a
+ built-in backend (`multiprocessing` or `dask`) or a custom
+ `backends.BackendBase` instance with an implemented `apply` method
+ that is used to run the computations.
"""
+ @classmethod
+ def get_supported_backends(cls):
+ """Tuple with backends supported by the core library for a given class.
+        Users can pass one of these values as ``backend=...`` to the
+        :meth:`run` method, or a custom object that has an ``apply`` method
+        (see documentation for :meth:`run`):
+
+ - 'serial': no parallelization
+ - 'multiprocessing': parallelization using `multiprocessing.Pool`
+ - 'dask': parallelization using `dask.delayed.compute()`. Requires
+ installation of `mdanalysis[dask]`
+
+ If you want to add your own backend to an existing class, pass a
+ :class:`backends.BackendBase` subclass (see its documentation to learn
+ how to implement it properly), and specify ``unsupported_backend=True``.
+
+ Returns
+ -------
+ tuple
+ names of built-in backends that can be used in :meth:`run(backend=...)`
+
+
+ .. versionadded:: 2.8.0
+ """
+ return ("serial",)
+
+ # class authors: override _analysis_algorithm_is_parallelizable
+ # in derived classes and only set to True if you have confirmed
+ # that your algorithm works reliably when parallelized with
+ # the split-apply-combine approach (see docs)
+ _analysis_algorithm_is_parallelizable = False
+
+ @property
+ def parallelizable(self):
+        """Boolean mark showing that a given class can be parallelized with
+        the split-apply-combine procedure. Namely, if we can safely distribute
+        :meth:`_single_frame` calls to multiple workers and then combine the results with a
+ proper :meth:`_conclude` call. If set to ``False``, no backends except
+ for ``serial`` are supported.
+
+ .. note:: If you want to check parallelizability of the whole class, without
+ explicitly creating an instance of the class, see
+                  :attr:`_analysis_algorithm_is_parallelizable`. Note that
+                  setting it to another value will break things if the algorithm
+                  behind the analysis is not trivially parallelizable.
+
+
+ Returns
+ -------
+ bool
+ if a given ``AnalysisBase`` subclass instance
+ is parallelizable with split-apply-combine, or not
+
+
+ .. versionadded:: 2.8.0
+ """
+ return self._analysis_algorithm_is_parallelizable
+
def __init__(self, trajectory, verbose=False, **kwargs):
self._trajectory = trajectory
self._verbose = verbose
self.results = Results()
- def _setup_frames(self, trajectory, start=None, stop=None, step=None,
- frames=None):
+ def _define_run_frames(self, trajectory,
+ start=None, stop=None, step=None, frames=None
+ ) -> Union[slice, np.ndarray]:
+ """Defines limits for the whole run, as passed by self.run() arguments
+
+ Parameters
+ ----------
+ trajectory : mda.Reader
+ a trajectory Reader
+ start : int, optional
+ start frame of analysis, by default None
+ stop : int, optional
+ stop frame of analysis, by default None
+ step : int, optional
+ number of frames to skip between each analysed frame, by default None
+ frames : array_like, optional
+ array of integers or booleans to slice trajectory; cannot be
+ combined with ``start``, ``stop``, ``step``; by default None
+
+ Returns
+ -------
+ Union[slice, np.ndarray]
+            Appropriate slicer for the trajectory that would give the correct iteration
+            order via ``trajectory[slicer]``
+
+ Raises
+ ------
+ ValueError
+ if *both* `frames` and at least one of ``start``, ``stop``, or ``step``
+            is provided (i.e. set to a non-``None`` value).
+
+
+ .. versionadded:: 2.8.0
+ """
+ self._trajectory = trajectory
+ if frames is not None:
+ if not all(opt is None for opt in [start, stop, step]):
+ raise ValueError("start/stop/step cannot be combined with frames")
+ slicer = frames
+ else:
+ start, stop, step = trajectory.check_slice_indices(start, stop, step)
+ slicer = slice(start, stop, step)
+ self.start, self.stop, self.step = start, stop, step
+ return slicer
+
+ def _prepare_sliced_trajectory(self, slicer: Union[slice, np.ndarray]):
+ """Prepares sliced trajectory for use in subsequent parallel computations:
+ namely, assigns self._sliced_trajectory and its appropriate attributes,
+ self.n_frames, self.frames and self.times.
+
+ Parameters
+ ----------
+ slicer : Union[slice, np.ndarray]
+ appropriate slicer for the trajectory
+
+
+ .. versionadded:: 2.8.0
+ """
+ self._sliced_trajectory = self._trajectory[slicer]
+ self.n_frames = len(self._sliced_trajectory)
+ self.frames = np.zeros(self.n_frames, dtype=int)
+ self.times = np.zeros(self.n_frames)
+
+ def _setup_frames(self, trajectory, start=None, stop=None, step=None, frames=None):
"""Pass a Reader object and define the desired iteration pattern
through the trajectory
@@ -334,63 +404,303 @@ def _setup_frames(self, trajectory, start=None, stop=None, step=None,
number of frames to skip between each analysed frame
frames : array_like, optional
array of integers or booleans to slice trajectory; cannot be
- combined with `start`, `stop`, `step`
+ combined with ``start``, ``stop``, ``step``
.. versionadded:: 2.2.0
Raises
------
ValueError
- if *both* `frames` and at least one of `start`, `stop`, or `frames`
- is provided (i.e., set to another value than ``None``)
+ if *both* `frames` and at least one of ``start``, ``stop``, or
+            ``step`` is provided (i.e., set to another value than ``None``)
.. versionchanged:: 1.0.0
Added .frames and .times arrays as attributes
-
+
.. versionchanged:: 2.2.0
Added ability to iterate through trajectory by passing a list of
frame indices in the `frames` keyword argument
-
+
+ .. versionchanged:: 2.8.0
+ Split function into two: :meth:`_define_run_frames` and
+ :meth:`_prepare_sliced_trajectory`: first one defines the limits
+ for the whole run and is executed once during :meth:`run` in
+ :meth:`_setup_frames`, second one prepares sliced trajectory for
+ each of the workers and gets executed twice: one time in
+ :meth:`_setup_frames` for the whole trajectory, second time in
+ :meth:`_compute` for each of the computation groups.
"""
- self._trajectory = trajectory
- if frames is not None:
- if not all(opt is None for opt in [start, stop, step]):
- raise ValueError("start/stop/step cannot be combined with "
- "frames")
- slicer = frames
- else:
- start, stop, step = trajectory.check_slice_indices(start, stop,
- step)
- slicer = slice(start, stop, step)
- self._sliced_trajectory = trajectory[slicer]
- self.start = start
- self.stop = stop
- self.step = step
- self.n_frames = len(self._sliced_trajectory)
- self.frames = np.zeros(self.n_frames, dtype=int)
- self.times = np.zeros(self.n_frames)
+ slicer = self._define_run_frames(trajectory, start, stop, step, frames)
+ self._prepare_sliced_trajectory(slicer)
def _single_frame(self):
"""Calculate data from a single frame of trajectory
Don't worry about normalising, just deal with a single frame.
+ Attributes accessible during your calculations:
+
+ - ``self._frame_index``: index of the frame in results array
+ - ``self._ts`` -- Timestep instance
+ - ``self._sliced_trajectory`` -- trajectory that you're iterating over
+ - ``self.results`` -- :class:`MDAnalysis.analysis.results.Results` instance
+ holding run results initialized in :meth:`_prepare`.
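+
+        A minimal sketch (assuming :meth:`_prepare` initialized
+        ``self.results.timeseries`` as an empty list)::
+
+            def _single_frame(self):
+                # store one value per analysed frame
+                self.results.timeseries.append(self._ts.positions.mean(axis=0))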
"""
raise NotImplementedError("Only implemented in child classes")
def _prepare(self):
- """Set things up before the analysis loop begins"""
+ """
+ Set things up before the analysis loop begins.
+
+ Notes
+ -----
+        ``self.results`` is already initialized in :meth:`self.__init__` with an
+        empty instance of :class:`MDAnalysis.analysis.results.Results`.
+        You can still assign your attributes as usual;
+        ``Results`` just keeps track of them to be able to run a proper
+        aggregation after a parallel run, if necessary.
+ """
pass # pylint: disable=unnecessary-pass
def _conclude(self):
"""Finalize the results you've gathered.
Called at the end of the :meth:`run` method to finish everything up.
+
+ Notes
+ -----
+ Aggregation of results from individual workers happens in
+ :meth:`self.run()`, so here you have to implement everything as if you
+ had a non-parallel run. If you want to enable proper aggregation for
+        parallel runs for your analysis class, implement ``self._get_aggregator``
+ and check :mod:`MDAnalysis.analysis.results` for how to use it.
"""
pass # pylint: disable=unnecessary-pass
- def run(self, start=None, stop=None, step=None, frames=None,
- verbose=None, *, progressbar_kwargs=None):
+ def _compute(self, indexed_frames: np.ndarray,
+ verbose: bool = None,
+ *, progressbar_kwargs={}) -> "AnalysisBase":
+ """Perform the calculation on a balanced slice of frames
+        that have been set up beforehand with :meth:`_setup_computation_groups`
+
+ Parameters
+ ----------
+ indexed_frames : np.ndarray
+ np.ndarray of (n, 2) shape, where first column is frame iteration
+ indices and second is frame numbers
+
+ verbose : bool, optional
+ Turn on verbosity
+
+ progressbar_kwargs : dict, optional
+ ProgressBar keywords with custom parameters regarding progress bar
+ position, etc; see :class:`MDAnalysis.lib.log.ProgressBar`
+ for full list.
+
+
+ .. versionadded:: 2.8.0
+ """
+ logger.info("Choosing frames to analyze")
+ # if verbose unchanged, use class default
+ verbose = getattr(self, "_verbose", False) if verbose is None else verbose
+
+ frames = indexed_frames[:, 1]
+
+ logger.info("Starting preparation")
+ self._prepare_sliced_trajectory(slicer=frames)
+ self._prepare()
+ if len(frames) == 0: # if `frames` were empty in `run` or `stop=0`
+ return self
+
+ for idx, ts in enumerate(ProgressBar(
+ self._sliced_trajectory,
+ verbose=verbose,
+ **progressbar_kwargs)):
+ self._frame_index = idx # accessed later by subclasses
+ self._ts = ts
+ self.frames[idx] = ts.frame
+ self.times[idx] = ts.time
+ self._single_frame()
+ logger.info("Finishing up")
+ return self
+
+ def _setup_computation_groups(
+ self, n_parts: int,
+ start: int = None, stop: int = None, step: int = None,
+ frames: Union[slice, np.ndarray] = None
+ ) -> list[np.ndarray]:
+ """
+ Splits the trajectory frames, defined by ``start/stop/step`` or
+ ``frames``, into ``n_parts`` even groups, preserving their indices.
+
+ Parameters
+ ----------
+ n_parts : int
+ number of parts to split the workload into
+ start : int, optional
+ start frame
+ stop : int, optional
+ stop frame
+ step : int, optional
+ step size for analysis (1 means to read every frame)
+ frames : array_like, optional
+ array of integers or booleans to slice trajectory; ``frames`` can
+ only be used *instead* of ``start``, ``stop``, and ``step``. Setting
+ *both* ``frames`` and at least one of ``start``, ``stop``, ``step``
+ to a non-default value will raise a :exc:`ValueError`.
+
+ Raises
+ ------
+ ValueError
+ if *both* ``frames`` and at least one of ``start``, ``stop``, or
+            ``step`` is provided (i.e., set to another value than ``None``)
+
+ Returns
+ -------
+ computation_groups : list[np.ndarray]
+ list of (n, 2) shaped np.ndarrays with frame indices and numbers
+
+
+ .. versionadded:: 2.8.0
+ """
+ if frames is None:
+ start, stop, step = self._trajectory.check_slice_indices(start, stop, step)
+ used_frames = np.arange(start, stop, step)
+ elif not all(opt is None for opt in [start, stop, step]):
+ raise ValueError("start/stop/step cannot be combined with frames")
+ else:
+ used_frames = frames
+
+ if all(isinstance(obj, bool) for obj in used_frames):
+ arange = np.arange(len(used_frames))
+ used_frames = arange[used_frames]
+
+ # similar to list(enumerate(frames))
+ enumerated_frames = np.vstack([np.arange(len(used_frames)), used_frames]).T
+
+ return np.array_split(enumerated_frames, n_parts)
+
+ def _configure_backend(
+ self,
+ backend: Union[str, BackendBase],
+ n_workers: int,
+ unsupported_backend: bool = False
+ ) -> BackendBase:
+ """Matches a passed backend string value with class attributes
+ :attr:`parallelizable` and :meth:`get_supported_backends()`
+ to check if downstream calculations can be performed.
+
+ Parameters
+ ----------
+ backend : Union[str, BackendBase]
+ backend to be used:
+ - ``str`` is matched to a builtin backend (one of "serial",
+ "multiprocessing" and "dask")
+ - ``BackendBase`` subclass is checked for the presence of
+ an :meth:`apply` method
+ n_workers : int
+ positive integer with number of workers (processes, in case of
+ built-in backends) to split the work between
+ unsupported_backend : bool, optional
+ if you want to run your custom backend on a parallelizable class
+ that has not been tested by developers, by default ``False``
+
+ Returns
+ -------
+ BackendBase
+ instance of a ``BackendBase`` class that will be used for computations
+
+ Raises
+ ------
+ ValueError
+ if :attr:`parallelizable` is set to ``False`` but backend is
+ not ``serial``
+ ValueError
+ if :attr:`parallelizable` is ``True`` and custom backend instance is used
+ without specifying ``unsupported_backend=True``
+ ValueError
+ if your trajectory has associated parallelizable transformations
+ but backend is not serial
+ ValueError
+            if ``n_workers`` was specified twice -- in the run() method and during
+ ``__init__`` of a custom backend
+ ValueError
+ if your backend object instance doesn't have an ``apply`` method
+
+
+ .. versionadded:: 2.8.0
+ """
+ builtin_backends = {
+ "serial": BackendSerial,
+ "multiprocessing": BackendMultiprocessing,
+ "dask": BackendDask
+ }
+
+ backend_class = builtin_backends.get(backend, backend)
+ supported_backend_classes = [
+ builtin_backends.get(b)
+ for b in self.get_supported_backends()
+ ]
+
+ # check for serial-only classes
+ if not self.parallelizable and backend_class is not BackendSerial:
+ raise ValueError(f"Can not parallelize class {self.__class__}")
+
+ # make sure user enabled 'unsupported_backend=True' for custom classes
+ if (not unsupported_backend and self.parallelizable
+ and backend_class not in supported_backend_classes):
+ raise ValueError((
+                f"Must specify 'unsupported_backend=True' "
+ f"if you want to use a custom {backend_class=} for {self.__class__}"
+ ))
+
+ # check for the presence of parallelizable transformations
+ if backend_class is not BackendSerial and any(
+ not t.parallelizable
+ for t in self._trajectory.transformations):
+ raise ValueError((
+ "Trajectory should not have "
+ "associated unparallelizable transformations"))
+
+ # conclude mapping from string to backend class if it's a builtin backend
+ if isinstance(backend, str):
+ return backend_class(n_workers=n_workers)
+
+ # make sure we haven't specified n_workers twice
+ if (
+ isinstance(backend, BackendBase)
+ and n_workers is not None
+ and hasattr(backend, 'n_workers')
+ and backend.n_workers != n_workers
+ ):
+ raise ValueError((
+                f"n_workers specified twice: in {backend.n_workers=} "
+ f"and in run({n_workers=}). Remove it from run()"
+ ))
+
+ # or pass along an instance of the class itself
+ # after ensuring it has apply method
+ if not isinstance(backend, BackendBase) or not hasattr(backend, "apply"):
+ raise ValueError((
+ f"{backend=} is invalid: should have 'apply' method "
+ "and be instance of MDAnalysis.analysis.backends.BackendBase"
+ ))
+ return backend
+
+ def run(
+ self,
+ start: int = None,
+ stop: int = None,
+ step: int = None,
+ frames: Iterable = None,
+ verbose: bool = None,
+ n_workers: int = None,
+ n_parts: int = None,
+ backend: Union[str, BackendBase] = None,
+ *,
+ unsupported_backend: bool = False,
+ progressbar_kwargs=None,
+ ):
"""Perform the calculation
Parameters
@@ -402,20 +712,42 @@ def run(self, start=None, stop=None, step=None, frames=None,
step : int, optional
number of frames to skip between each analysed frame
frames : array_like, optional
- array of integers or booleans to slice trajectory; `frames` can
- only be used *instead* of `start`, `stop`, and `step`. Setting
- *both* `frames` and at least one of `start`, `stop`, `step` to a
- non-default value will raise a :exc:`ValueError`.
+ array of integers or booleans to slice trajectory; ``frames`` can
+ only be used *instead* of ``start``, ``stop``, and ``step``. Setting
+ *both* ``frames`` and at least one of ``start``, ``stop``, ``step``
+ to a non-default value will raise a :exc:`ValueError`.
.. versionadded:: 2.2.0
-
verbose : bool, optional
Turn on verbosity
progressbar_kwargs : dict, optional
ProgressBar keywords with custom parameters regarding progress bar
- position, etc; see :class:`MDAnalysis.lib.log.ProgressBar` for full
- list.
+ position, etc; see :class:`MDAnalysis.lib.log.ProgressBar`
+ for full list. Available only for ``backend='serial'``
+ backend : Union[str, BackendBase], optional
+ By default, performs calculations in a serial fashion.
+ Otherwise, user can choose a backend: ``str`` is matched to a
+ builtin backend (one of ``serial``, ``multiprocessing`` and
+            ``dask``), or a :class:`MDAnalysis.analysis.backends.BackendBase`
+ subclass.
+
+ .. versionadded:: 2.8.0
+ n_workers : int
+ positive integer with number of workers (processes, in case of
+ built-in backends) to split the work between
+
+ .. versionadded:: 2.8.0
+ n_parts : int, optional
+ number of parts to split computations across. Can be more than
+ number of workers.
+
+ .. versionadded:: 2.8.0
+ unsupported_backend : bool, optional
+ if you want to run your custom backend on a parallelizable class
+ that has not been tested by developers, by default False
+
+ .. versionadded:: 2.8.0
.. versionchanged:: 2.2.0
@@ -425,34 +757,88 @@ def run(self, start=None, stop=None, step=None, frames=None,
.. versionchanged:: 2.5.0
Add `progressbar_kwargs` parameter,
allowing to modify description, position etc of tqdm progressbars
- """
- logger.info("Choosing frames to analyze")
- # if verbose unchanged, use class default
- verbose = getattr(self, '_verbose',
- False) if verbose is None else verbose
- self._setup_frames(self._trajectory, start=start, stop=stop,
- step=step, frames=frames)
- logger.info("Starting preparation")
- self._prepare()
- logger.info("Starting analysis loop over %d trajectory frames",
- self.n_frames)
- if progressbar_kwargs is None:
- progressbar_kwargs = {}
+ .. versionchanged:: 2.8.0
+ Introduced ``backend``, ``n_workers``, ``n_parts`` and
+ ``unsupported_backend`` keywords, and refactored the method logic to
+ support parallelizable execution.
+ """
+ # default to serial execution
+ backend = "serial" if backend is None else backend
+
+ progressbar_kwargs = {} if progressbar_kwargs is None else progressbar_kwargs
+ if ((progressbar_kwargs or verbose) and
+ not (backend == "serial" or
+ isinstance(backend, BackendSerial))):
+ raise ValueError("Can not display progressbar with non-serial backend")
+
+ # if number of workers not specified, try getting the number from
+ # the backend instance if possible, or set to 1
+ if n_workers is None:
+ n_workers = (
+ backend.n_workers
+ if isinstance(backend, BackendBase) and hasattr(backend, "n_workers")
+ else 1
+ )
+
+        # set n_parts and check that it has a reasonable value
+ n_parts = n_workers if n_parts is None else n_parts
+
+ # do this as early as possible to check client parameters
+ # before any computations occur
+ executor = self._configure_backend(
+ backend=backend,
+ n_workers=n_workers,
+ unsupported_backend=unsupported_backend)
+ if (
+ hasattr(executor, "n_workers") and n_parts < executor.n_workers
+ ): # using executor's value here for non-default executors
+ warnings.warn((
+ f"Analysis not making use of all workers: "
+ f"{executor.n_workers=} is greater than {n_parts=}"))
+
+ # start preparing the run
+ worker_func = partial(
+ self._compute,
+ progressbar_kwargs=progressbar_kwargs,
+ verbose=verbose)
+ self._setup_frames(
+ trajectory=self._trajectory,
+ start=start, stop=stop, step=step, frames=frames)
+ computation_groups = self._setup_computation_groups(
+ start=start, stop=stop, step=step, frames=frames, n_parts=n_parts
+ )
+
+ # get all results from workers in other processes.
+ # we need `AnalysisBase` classes
+ # since they hold `frames`, `times` and `results` attributes
+ remote_objects: list["AnalysisBase"] = executor.apply(
+ worker_func, computation_groups)
+ self.frames = np.hstack([obj.frames for obj in remote_objects])
+ self.times = np.hstack([obj.times for obj in remote_objects])
+
+        # aggregate the `results` attributes obtained from remote workers
+ remote_results = [obj.results for obj in remote_objects]
+ results_aggregator = self._get_aggregator()
+ self.results = results_aggregator.merge(remote_results)
- for i, ts in enumerate(ProgressBar(
- self._sliced_trajectory,
- verbose=verbose,
- **progressbar_kwargs)):
- self._frame_index = i
- self._ts = ts
- self.frames[i] = ts.frame
- self.times[i] = ts.time
- self._single_frame()
- logger.info("Finishing up")
self._conclude()
return self
+ def _get_aggregator(self) -> ResultsGroup:
+ """Returns a default aggregator that takes entire results
+ if there is a single object, and raises ValueError otherwise
+
+ Returns
+ -------
+ ResultsGroup
+ aggregating object
+
+
+ .. versionadded:: 2.8.0
+ """
+ return ResultsGroup(lookup=None)
+
class AnalysisFromFunction(AnalysisBase):
r"""Create an :class:`AnalysisBase` from a function working on AtomGroups
@@ -503,8 +889,18 @@ def rotation_matrix(mobile, ref):
.. versionchanged:: 2.0.0
Former :attr:`results` are now stored as :attr:`results.timeseries`
+
+ .. versionchanged:: 2.8.0
+ Added :meth:`get_supported_backends()`, introducing 'serial', 'multiprocessing'
+ and 'dask' backends.
"""
+ _analysis_algorithm_is_parallelizable = True
+
+ @classmethod
+ def get_supported_backends(cls):
+ return ("serial", "multiprocessing", "dask")
+
def __init__(self, function, trajectory=None, *args, **kwargs):
if (trajectory is not None) and (not isinstance(
trajectory, coordinates.base.ProtoReader)):
@@ -531,9 +927,11 @@ def __init__(self, function, trajectory=None, *args, **kwargs):
def _prepare(self):
self.results.timeseries = []
+ def _get_aggregator(self):
+ return ResultsGroup({"timeseries": ResultsGroup.flatten_sequence})
+
def _single_frame(self):
- self.results.timeseries.append(self.function(*self.args,
- **self.kwargs))
+ self.results.timeseries.append(self.function(*self.args, **self.kwargs))
def _conclude(self):
self.results.frames = self.frames
@@ -594,8 +992,11 @@ def RotationMatrix(mobile, ref):
class WrapperClass(AnalysisFromFunction):
def __init__(self, trajectory=None, *args, **kwargs):
- super(WrapperClass, self).__init__(function, trajectory,
- *args, **kwargs)
+ super(WrapperClass, self).__init__(function, trajectory, *args, **kwargs)
+
+ @classmethod
+ def get_supported_backends(cls):
+ return ("serial", "dask")
return WrapperClass
@@ -633,9 +1034,11 @@ def _filter_baseanalysis_kwargs(function, kwargs):
base_argspec = inspect.getargspec(AnalysisBase.__init__)
n_base_defaults = len(base_argspec.defaults)
- base_kwargs = {name: val
- for name, val in zip(base_argspec.args[-n_base_defaults:],
- base_argspec.defaults)}
+ base_kwargs = {
+ name: val
+ for name, val in zip(base_argspec.args[-n_base_defaults:],
+ base_argspec.defaults)
+ }
try:
# pylint: disable=deprecated-method
@@ -648,7 +1051,8 @@ def _filter_baseanalysis_kwargs(function, kwargs):
if base_kw in argspec.args:
raise ValueError(
"argument name '{}' clashes with AnalysisBase argument."
- "Now allowed are: {}".format(base_kw, base_kwargs.keys()))
+ "Now allowed are: {}".format(base_kw, base_kwargs.keys())
+ )
base_args = {}
for argname, default in base_kwargs.items():
diff --git a/package/MDAnalysis/analysis/results.py b/package/MDAnalysis/analysis/results.py
new file mode 100644
index 00000000000..8aa2062d2bc
--- /dev/null
+++ b/package/MDAnalysis/analysis/results.py
@@ -0,0 +1,321 @@
+"""Analysis results and their aggregation --- :mod:`MDAnalysis.analysis.results`
+================================================================================
+
+Module introduces two classes, :class:`Results` and :class:`ResultsGroup`,
+used for storing and aggregating data in
+:meth:`MDAnalysis.analysis.base.AnalysisBase.run()`, respectively.
+
+
+Classes
+-------
+
+The :class:`Results` class is an extension of a built-in dictionary
+type, that holds all assigned attributes in :attr:`self.data` and
+allows for access either via dict-like syntax, or via class-like syntax:
+
+.. code-block:: python
+
+ from MDAnalysis.analysis.results import Results
+ r = Results()
+ r.array = [1, 2, 3, 4]
+ assert r['array'] == r.array == [1, 2, 3, 4]
+
+
+The :class:`ResultsGroup` can merge multiple :class:`Results` objects.
+It is mainly used by :class:`MDAnalysis.analysis.base.AnalysisBase` class,
+that uses :meth:`ResultsGroup.merge()` method to aggregate results from
+multiple workers, initialized during a parallel run:
+
+.. code-block:: python
+
+ from MDAnalysis.analysis.results import Results, ResultsGroup
+ import numpy as np
+
+ r1, r2 = Results(), Results()
+ r1.masses = [1, 2, 3, 4, 5]
+ r2.masses = [0, 0, 0, 0]
+ r1.vectors = np.arange(10).reshape(5, 2)
+ r2.vectors = np.arange(8).reshape(4, 2)
+
+ group = ResultsGroup(
+ lookup = {
+ 'masses': ResultsGroup.flatten_sequence,
+ 'vectors': ResultsGroup.ndarray_vstack
+ }
+ )
+
+ r = group.merge([r1, r2])
+ assert r.masses == list((*r1.masses, *r2.masses))
+ assert (r.vectors == np.vstack([r1.vectors, r2.vectors])).all()
+"""
+from collections import UserDict
+import numpy as np
+from typing import Callable, Sequence
+
+
+class Results(UserDict):
+ r"""Container object for storing results.
+
+ :class:`Results` are dictionaries that provide two ways by which values
+ can be accessed: by dictionary key ``results["value_key"]`` or by object
+ attribute, ``results.value_key``. :class:`Results` stores all results
+ obtained from an analysis after calling :meth:`~AnalysisBase.run()`.
+
+ The implementation is similar to the :class:`sklearn.utils.Bunch`
+ class in `scikit-learn`_.
+
+ .. _`scikit-learn`: https://scikit-learn.org/
+ .. _`sklearn.utils.Bunch`: https://scikit-learn.org/stable/modules/generated/sklearn.utils.Bunch.html
+
+ Raises
+ ------
+ AttributeError
+ If an assigned attribute has the same name as a default attribute.
+
+ ValueError
+ If a key is not of type ``str`` and therefore is not able to be
+ accessed by attribute.
+
+ Examples
+ --------
+ >>> from MDAnalysis.analysis.base import Results
+ >>> results = Results(a=1, b=2)
+ >>> results['b']
+ 2
+ >>> results.b
+ 2
+ >>> results.a = 3
+ >>> results['a']
+ 3
+ >>> results.c = [1, 2, 3, 4]
+ >>> results['c']
+ [1, 2, 3, 4]
+
+
+ .. versionadded:: 2.0.0
+
+ .. versionchanged:: 2.8.0
+ Moved :class:`Results` to :mod:`MDAnalysis.analysis.results`
+ """
+
+ def _validate_key(self, key):
+ if key in dir(self):
+ raise AttributeError(f"'{key}' is a protected dictionary attribute")
+ elif isinstance(key, str) and not key.isidentifier():
+ raise ValueError(f"'{key}' is not a valid attribute")
+
+ def __init__(self, *args, **kwargs):
+ kwargs = dict(*args, **kwargs)
+ if "data" in kwargs.keys():
+ raise AttributeError(f"'data' is a protected dictionary attribute")
+ self.__dict__["data"] = {}
+ self.update(kwargs)
+
+ def __setitem__(self, key, item):
+ self._validate_key(key)
+ super().__setitem__(key, item)
+
+ def __setattr__(self, attr, val):
+ if attr == "data":
+ super().__setattr__(attr, val)
+ else:
+ self.__setitem__(attr, val)
+
+ def __getattr__(self, attr):
+ try:
+ return self[attr]
+ except KeyError as err:
+ raise AttributeError(f"'Results' object has no attribute '{attr}'") from err
+
+ def __delattr__(self, attr):
+ try:
+ del self[attr]
+ except KeyError as err:
+ raise AttributeError(f"'Results' object has no attribute '{attr}'") from err
+
+ def __getstate__(self):
+ return self.data
+
+ def __setstate__(self, state):
+ self.data = state
+
+
+class ResultsGroup:
+ """
+ Management and aggregation of results stored in :class:`Results` instances.
+
+    A :class:`ResultsGroup` is an optional description for :class:`Results` "dictionaries"
+    that are used in analysis classes based on :class:`AnalysisBase`. For each *key* in a
+    :class:`Results` it describes how multiple pieces of the data held under the key are
+    to be aggregated. This approach is necessary when parts of a trajectory are analyzed
+    independently (e.g., in parallel) and then need to be merged (with :meth:`merge`) to
+ obtain a complete data set.
+
+ Parameters
+ ----------
+ lookup : dict[str, Callable], optional
+ aggregation functions lookup dict, by default None
+
+ Examples
+ --------
+
+ .. code-block:: python
+
+ from MDAnalysis.analysis.results import ResultsGroup, Results
+ group = ResultsGroup(lookup={'mass': ResultsGroup.float_mean})
+ obj1 = Results(mass=1)
+ obj2 = Results(mass=3)
+ assert {'mass': 2.0} == group.merge([obj1, obj2])
+
+
+ .. code-block:: python
+
+ # you can also set `None` for those attributes that you want to skip
+ lookup = {'mass': ResultsGroup.float_mean, 'trajectory': None}
+ group = ResultsGroup(lookup)
+ objects = [Results(mass=1, skip=None), Results(mass=3, skip=object)]
+ assert group.merge(objects, require_all_aggregators=False) == {'mass': 2.0}
+
+ .. versionadded:: 2.8.0
+ """
+
+ def __init__(self, lookup: dict[str, Callable] = None):
+ self._lookup = lookup
+
+ def merge(self, objects: Sequence[Results], require_all_aggregators: bool = True) -> Results:
+ """Merge multiple Results into a single Results instance.
+
+ Merge multiple :class:`Results` instances into a single one, using the
+ `lookup` dictionary to determine the appropriate aggregator functions for
+        each named results attribute. If the sequence of objects contains only a single
+        element, it is returned as is, without using any aggregators.
+
+ Parameters
+ ----------
+ objects : Sequence[Results]
+ Multiple :class:`Results` instances with the same data attributes.
+ require_all_aggregators : bool, optional
+ if True, raise an exception when no aggregation function for a
+            if True, raise an exception when no aggregation function for a
+            particular attribute is found. Setting it to ``False`` allows skipping
+            aggregation for attributes that aren't needed in the final object --
+
+ Returns
+ -------
+ Results
+ merged :class:`Results`
+
+ Raises
+ ------
+ ValueError
+ if no aggregation function for a key is found and ``require_all_aggregators=True``
+ """
+ if len(objects) == 1:
+ merged_results = objects[0]
+ return merged_results
+
+ merged_results = Results()
+ for key in objects[0].keys():
+ agg_function = self._lookup.get(key, None)
+ if agg_function is not None:
+ results_of_t = [obj[key] for obj in objects]
+ merged_results[key] = agg_function(results_of_t)
+ elif require_all_aggregators:
+ raise ValueError(f"No aggregation function for {key=}")
+ return merged_results
+
+ @staticmethod
+ def flatten_sequence(arrs: list[list]):
+ """Flatten a list of lists into a list
+
+ Parameters
+ ----------
+ arrs : list[list]
+ list of lists
+
+ Returns
+ -------
+ list
+ flattened list
+ """
+ return [item for sublist in arrs for item in sublist]
+
+ @staticmethod
+ def ndarray_sum(arrs: list[np.ndarray]):
+ """sums an ndarray along ``axis=0``
+
+ Parameters
+ ----------
+ arrs : list[np.ndarray]
+ list of input arrays. Must have the same shape.
+
+ Returns
+ -------
+ np.ndarray
+ sum of input arrays
+ """
+ return np.array(arrs).sum(axis=0)
+
+ @staticmethod
+ def ndarray_mean(arrs: list[np.ndarray]):
+ """calculates mean of input ndarrays along ``axis=0``
+
+ Parameters
+ ----------
+ arrs : list[np.ndarray]
+ list of input arrays. Must have the same shape.
+
+ Returns
+ -------
+ np.ndarray
+ mean of input arrays
+ """
+ return np.array(arrs).mean(axis=0)
+
+ @staticmethod
+ def float_mean(floats: list[float]):
+ """calculates mean of input float values
+
+ Parameters
+ ----------
+ floats : list[float]
+ list of float values
+
+ Returns
+ -------
+ float
+ mean value
+ """
+ return np.array(floats).mean()
+
+ @staticmethod
+ def ndarray_hstack(arrs: list[np.ndarray]):
+ """Performs horizontal stack of input arrays
+
+ Parameters
+ ----------
+ arrs : list[np.ndarray]
+ input numpy arrays
+
+ Returns
+ -------
+ np.ndarray
+ result of stacking
+ """
+ return np.hstack(arrs)
+
+ @staticmethod
+ def ndarray_vstack(arrs: list[np.ndarray]):
+ """Performs vertical stack of input arrays
+
+ Parameters
+ ----------
+ arrs : list[np.ndarray]
+ input numpy arrays
+
+ Returns
+ -------
+ np.ndarray
+ result of stacking
+ """
+ return np.vstack(arrs)
diff --git a/package/MDAnalysis/analysis/rms.py b/package/MDAnalysis/analysis/rms.py
index 36067d76b6d..afb11ed7d2e 100644
--- a/package/MDAnalysis/analysis/rms.py
+++ b/package/MDAnalysis/analysis/rms.py
@@ -166,10 +166,10 @@
import logging
import warnings
-import MDAnalysis.lib.qcprot as qcp
-from MDAnalysis.analysis.base import AnalysisBase
-from MDAnalysis.exceptions import SelectionError, NoDataError
-from MDAnalysis.lib.util import asiterable, iterable, get_weights
+from ..lib import qcprot as qcp
+from ..analysis.base import AnalysisBase, ResultsGroup
+from ..exceptions import SelectionError
+from ..lib.util import asiterable, iterable, get_weights
logger = logging.getLogger('MDAnalysis.analysis.rmsd')
@@ -358,8 +358,17 @@ class RMSD(AnalysisBase):
.. versionchanged:: 2.0.0
:attr:`rmsd` results are now stored in a
:class:`MDAnalysis.analysis.base.Results` instance.
-
+ .. versionchanged:: 2.8.0
+       introduced :meth:`get_supported_backends`, allowing for execution with
+ ``multiprocessing`` and ``dask`` backends.
"""
+ _analysis_algorithm_is_parallelizable = True
+
+ @classmethod
+ def get_supported_backends(cls):
+ return ('serial', 'multiprocessing', 'dask',)
+
+
def __init__(self, atomgroup, reference=None, select='all',
groupselections=None, weights=None, weights_groupselections=False,
tol_mass=0.1, ref_frame=0, **kwargs):
@@ -670,6 +679,9 @@ def _prepare(self):
self._mobile_coordinates64 = self.mobile_atoms.positions.copy().astype(np.float64)
+ def _get_aggregator(self):
+ return ResultsGroup(lookup={'rmsd': ResultsGroup.ndarray_vstack})
+
def _single_frame(self):
mobile_com = self.mobile_atoms.center(self.weights_select).astype(np.float64)
self._mobile_coordinates64[:] = self.mobile_atoms.positions
@@ -739,6 +751,11 @@ class RMSF(AnalysisBase):
in the array :attr:`RMSF.results.rmsf`.
"""
+
+ @classmethod
+ def get_supported_backends(cls):
+ return ('serial',)
+
def __init__(self, atomgroup, **kwargs):
r"""Parameters
----------
diff --git a/package/MDAnalysis/lib/util.py b/package/MDAnalysis/lib/util.py
index 8600c390e1b..666a8c49279 100644
--- a/package/MDAnalysis/lib/util.py
+++ b/package/MDAnalysis/lib/util.py
@@ -49,6 +49,11 @@
.. autofunction:: format_from_filename_extension
.. autofunction:: guess_format
+Modules and packages
+--------------------
+
+.. autofunction:: is_installed
+
Streams
-------
@@ -147,6 +152,7 @@
.. autofunction:: convert_aa_code
.. autofunction:: parse_residue
.. autofunction:: conv_float
+.. autofunction:: atoi
Class decorators
----------------
@@ -205,6 +211,7 @@
from functools import wraps
import textwrap
import weakref
+import importlib
import itertools
import mmtf
@@ -1333,6 +1340,7 @@ def fixedwidth_bins(delta, xmin, xmax):
dx = 0.5 * (N * _delta - _length) # add half of the excess to each end
return {'Nbins': N, 'delta': _delta, 'min': _xmin - dx, 'max': _xmax + dx}
+
def get_weights(atoms, weights):
"""Check that a `weights` argument is compatible with `atoms`.
@@ -2592,3 +2600,17 @@ def atoi(s: str) -> int:
return int(''.join(itertools.takewhile(str.isdigit, s.strip())))
except ValueError:
return 0
+
+
+def is_installed(modulename: str):
+ """Checks if module is installed
+
+ Parameters
+ ----------
+ modulename : str
+ name of the module to be tested
+
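+    Returns
+    -------
+    bool
+        ``True`` if the module is installed, ``False`` otherwise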
+
+ .. versionadded:: 2.8.0
+ """
+ return importlib.util.find_spec(modulename) is not None
diff --git a/package/doc/sphinx/source/conf.py b/package/doc/sphinx/source/conf.py
index dfc6c606a13..40a096c0275 100644
--- a/package/doc/sphinx/source/conf.py
+++ b/package/doc/sphinx/source/conf.py
@@ -351,4 +351,5 @@ class KeyStyle(UnsrtStyle):
'waterdynamics': ('https://www.mdanalysis.org/waterdynamics/', None),
'pathsimanalysis': ('https://www.mdanalysis.org/PathSimAnalysis/', None),
'mdahole2': ('https://www.mdanalysis.org/mdahole2/', None),
+ 'dask': ('https://docs.dask.org/en/stable/', None),
}
diff --git a/package/doc/sphinx/source/documentation_pages/analysis/backends.rst b/package/doc/sphinx/source/documentation_pages/analysis/backends.rst
new file mode 100644
index 00000000000..36f2a40cf11
--- /dev/null
+++ b/package/doc/sphinx/source/documentation_pages/analysis/backends.rst
@@ -0,0 +1,4 @@
+.. automodule:: MDAnalysis.analysis.backends
+ :members:
+ :private-members:
+
diff --git a/package/doc/sphinx/source/documentation_pages/analysis/base.rst b/package/doc/sphinx/source/documentation_pages/analysis/base.rst
index 4eda92ceac6..2c2d884fc71 100644
--- a/package/doc/sphinx/source/documentation_pages/analysis/base.rst
+++ b/package/doc/sphinx/source/documentation_pages/analysis/base.rst
@@ -1,4 +1,4 @@
.. automodule:: MDAnalysis.analysis.base
:members:
-
+ :private-members:
diff --git a/package/doc/sphinx/source/documentation_pages/analysis/parallelization.rst b/package/doc/sphinx/source/documentation_pages/analysis/parallelization.rst
new file mode 100644
index 00000000000..91ae05fceca
--- /dev/null
+++ b/package/doc/sphinx/source/documentation_pages/analysis/parallelization.rst
@@ -0,0 +1,358 @@
+.. -*- coding: utf-8 -*-
+
+.. _parallel-analysis:
+
+=================
+Parallel analysis
+=================
+
+.. versionadded:: 2.8.0
+ Parallelization of analysis classes was added during Google Summer of Code
+   2023 by `@marinegor <https://github.com/marinegor>`_ and MDAnalysis GSoC mentors.
+
+This section explains the implementation and background for
+parallelization with the :class:`MDAnalysis.analysis.base.AnalysisBase`, what
+users and developers need to know, when you should use parallelization (almost
+always!), and when you should abstain from doing so (rarely).
+
+
+How to use parallelization
+==========================
+
+In order to use parallelization in a built-in analysis class ``SomeClass``,
+simply check which backends are available (see :ref:`backends` for backends
+that are generally available), and then enable them by providing
+``backend='multiprocessing'`` and ``n_workers=...`` to the ``SomeClass.run(...)``
+method:
+
+.. code-block:: python
+
+ u = mda.Universe(...)
+ my_run = SomeClass(trajectory)
+ assert SomeClass.get_supported_backends() == ('serial', 'multiprocessing', 'dask')
+
+ my_run.run(backend='multiprocessing', n_workers=12)
+
+For some classes, such as :class:`MDAnalysis.analysis.rms.RMSF` (in its current implementation),
+split-apply-combine parallelization isn't possible, and they can only be run
+with the ``serial`` backend.
+
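+Attempting to use a parallel backend with such a class raises a
+:exc:`ValueError` (a short sketch, reusing ``u`` from the snippet above; the
+atom selection is purely illustrative):
+
+.. code-block:: python
+
+    from MDAnalysis.analysis.rms import RMSF
+
+    rmsf = RMSF(u.select_atoms('name CA'))
+    rmsf.run(backend='multiprocessing', n_workers=2)  # raises ValueError
+    rmsf.run()  # the serial default still works as before
+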
+.. Note::
+
+ Parallelization is getting added to existing analysis classes. Initially,
+ only :class:`MDAnalysis.analysis.rms.RMSD` supports parallel analysis, but
+ we aim to increase support in future releases.
+
+
+How does parallelization work
+=============================
+
+The main idea behind its current implementation is that a trajectory analysis is
+often trivially parallelizable, meaning you can analyze all frames
+independently, and then merge them in a single object. This approach is also
+known as "split-apply-combine", and isn't new to MDAnalysis users, since it was
+first introduced in `PMDA`_ :footcite:p:`Fan2019`.
+Version 2.8.0 of MDAnalysis brings this approach to the main library.
+
+.. _`PMDA`: https://github.com/mdanalysis/pmda
+
+
+split-apply-combine
+-------------------
+
+The following scheme explains the current :meth:`AnalysisBase.run
+<MDAnalysis.analysis.base.AnalysisBase.run>` protocol (user-implemented methods
+are highlighted in orange):
+
+.. figure:: /images/AnalysisBase_parallel.png
+
+
+In short, after checking input parameters and configuring the backend,
+:class:`AnalysisBase <MDAnalysis.analysis.base.AnalysisBase>` splits all the
+frames into *computation groups* (equally sized sequential groups of frames to
+be processed by each worker). All groups then get **split** between workers of
+a backend configured early, the main instance gets serialized and distributed
+between workers, and then the
+:meth:`~MDAnalysis.analysis.base.AnalysisBase._compute` method gets called
+for all frames of a computation group. Within this method, a user-implemented
+:meth:`~MDAnalysis.analysis.base.AnalysisBase._single_frame` method gets
+**applied** to each frame in a computation group. After that, the main
+instance collects the objects returned by the workers and **combines** them:
+all instances get *merged* with an instance of
+:class:`MDAnalysis.analysis.results.ResultsGroup`. Then, a normal
+user-implemented :meth:`~MDAnalysis.analysis.base.AnalysisBase._conclude` method
+is called.
+
+Parallelization is fully compatible with existing code and does *not* break
+any code written pre-2.8.0. The parallelization protocol mimics the
+single-process workflow where possible. Thus, user-implemented methods such as
+:meth:`~MDAnalysis.analysis.base.AnalysisBase._prepare`,
+:meth:`~MDAnalysis.analysis.base.AnalysisBase._single_frame` and
+:meth:`~MDAnalysis.analysis.base.AnalysisBase._conclude` won't need to know
+whether they are operating on an instance within the main Python process, or on a
+remote instance, since the executed code is the same in both cases.
+
+
+Methods in ``AnalysisBase`` for parallelization
+-----------------------------------------------
+
+For developers of new analysis tools
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+If you want to write your own *parallel* analysis class, you have to implement
+:meth:`~MDAnalysis.analysis.base.AnalysisBase._prepare`,
+:meth:`~MDAnalysis.analysis.base.AnalysisBase._single_frame` and
+:meth:`~MDAnalysis.analysis.base.AnalysisBase._conclude`. You also have to
+denote if your analysis can run in parallel by following the steps under
+:ref:`adding-parallelization`.
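+
+As a minimal sketch of what such a class can look like (the ``COMAnalysis``
+class below and everything in it apart from the ``AnalysisBase`` API is made up
+for illustration):
+
+.. code-block:: python
+
+    import numpy as np
+
+    from MDAnalysis.analysis.base import AnalysisBase
+    from MDAnalysis.analysis.results import ResultsGroup
+
+    class COMAnalysis(AnalysisBase):
+        # opt in to parallel execution (it is disabled in AnalysisBase)
+        _analysis_algorithm_is_parallelizable = True
+
+        @classmethod
+        def get_supported_backends(cls):
+            return ('serial', 'multiprocessing', 'dask')
+
+        def __init__(self, atomgroup, **kwargs):
+            super().__init__(atomgroup.universe.trajectory, **kwargs)
+            self._atomgroup = atomgroup
+
+        def _prepare(self):
+            self.results.com = []
+
+        def _single_frame(self):
+            # per-frame work: each frame is processed independently
+            self.results.com.append(self._atomgroup.center_of_mass())
+
+        def _conclude(self):
+            self.results.com = np.vstack(self.results.com)
+
+        def _get_aggregator(self):
+            # concatenate the per-worker lists before _conclude is called
+            return ResultsGroup(lookup={'com': ResultsGroup.flatten_sequence})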
+
+
+For MDAnalysis developers
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+From a developer point of view, there are a few methods that are important in
+order to understand how parallelization is implemented:
+
+#. :meth:`MDAnalysis.analysis.base.AnalysisBase._define_run_frames`
+#. :meth:`MDAnalysis.analysis.base.AnalysisBase._prepare_sliced_trajectory`
+#. :meth:`MDAnalysis.analysis.base.AnalysisBase._configure_backend`
+#. :meth:`MDAnalysis.analysis.base.AnalysisBase._setup_computation_groups`
+#. :meth:`MDAnalysis.analysis.base.AnalysisBase._compute`
+#. :meth:`MDAnalysis.analysis.base.AnalysisBase._get_aggregator`
+
+The first two methods share the functionality of :meth:`_setup_frames`.
+:meth:`_define_run_frames` is run once during the analysis: it checks that the
+input parameters ``start``, ``stop``, ``step`` or ``frames`` are consistent
+with the given trajectory and prepares the ``slicer`` object that defines the
+iteration pattern through the trajectory. :meth:`_prepare_sliced_trajectory`
+assigns the :attr:`self._sliced_trajectory` attribute, computes the number of
+frames in it, and fills the :attr:`self.frames` and :attr:`self.times` arrays.
+In case the computation will later be split between worker processes, this
+method will be called again on each of the computation groups.
+
+The method :meth:`_configure_backend` performs basic health checks for a given
+analysis class -- namely, it compares a given backend (if it's a :class:`str`
+instance, such as ``'multiprocessing'``) with the list of builtin backends (and
+also the backends implemented for a given analysis subclass), and configures a
+:class:`MDAnalysis.analysis.backends.BackendBase` instance accordingly. If the
+user decides to provide a custom backend (any subclass of
+:class:`MDAnalysis.analysis.backends.BackendBase`, or anything with an
+:meth:`apply` method), it ensures that the number of workers wasn't specified
+twice (during backend initialization and in :meth:`run` arguments).
+
+After a backend is configured, :meth:`_setup_computation_groups` splits the
+frames prepared earlier in :meth:`_prepare_sliced_trajectory` into a number of
+groups, by default equal to the number of workers.
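+
+Conceptually (this is only an illustration, not the actual implementation), the
+split into computation groups resembles:
+
+.. code-block:: python
+
+    import numpy as np
+
+    # 98 frames split into 4 roughly equal, sequential computation groups;
+    # each group is later handed to one worker
+    frames = np.arange(98)
+    n_workers = 4
+    computation_groups = np.array_split(frames, n_workers)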
+
+In the :meth:`_compute` method, frames get initialized again with
+:meth:`_prepare_sliced_trajectory`, and attributes necessary for a specific
+analysis get initialized with the :meth:`_prepare` method. Then the method
+iterates over :attr:`self._sliced_trajectory`, assigning
+:attr:`self._frame_index` (the frame index within a computation group) and
+:attr:`self._ts` (the current timestep), and also setting the respective
+:attr:`self.frames` and :attr:`self.times` array values.
+
+After :meth:`_compute` has finished, the main analysis instance calls the
+:meth:`_get_aggregator` method, which merges the :attr:`self.results`
+attributes from all worker processes into a single
+:class:`MDAnalysis.analysis.results.Results` instance, so that to the
+subsequent :meth:`_conclude` method the run looks as if it had been performed
+serially, without parallelization.
+
+
+Helper classes for parallelization
+==================================
+
+``ResultsGroup``
+----------------
+
+:class:`MDAnalysis.analysis.results.ResultsGroup` extends the functionality of
+the :class:`MDAnalysis.analysis.results.Results` class. Since the ``Results``
+class is essentially a dictionary that also keeps track of assigned attributes,
+it is possible to iterate over all these attributes later. ``ResultsGroup``
+does exactly that: given a list of ``Results`` objects with the same
+attributes, it applies a specific aggregation function to every attribute and
+stores the result under the same attribute of the returned object:
+
+.. code-block:: python
+
+ from MDAnalysis.analysis.results import ResultsGroup, Results
+ group = ResultsGroup(lookup={'mass': ResultsGroup.float_mean})
+ obj1 = Results(mass=1)
+ obj2 = Results(mass=3)
+ assert group.merge([obj1, obj2]) == Results(mass=2.0)
+
+
+``BackendBase``
+---------------
+
+:class:`MDAnalysis.analysis.backends.BackendBase` holds all backend attributes
+and implements an :meth:`MDAnalysis.analysis.backends.BackendBase.apply`
+method, which applies a given function to a list of arguments, potentially in
+parallel. Although in ``AnalysisBase`` it is used to apply the ``_compute``
+function, in principle it can apply any arbitrary function to any arguments,
+provided they are serializable.
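+
+For instance, a built-in backend can be used directly to apply a module-level
+function to a list of arguments (a small sketch; ``square`` is defined here
+only for illustration):
+
+.. code-block:: python
+
+    from MDAnalysis.analysis.backends import BackendMultiprocessing
+
+    # the function must be defined at module level so that it can be pickled
+    def square(x):
+        return x ** 2
+
+    backend = BackendMultiprocessing(n_workers=2)
+    assert backend.apply(square, [1, 2, 3]) == [1, 4, 9]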
+
+
+When to use parallelization? (Known limitations)
+================================================
+
+For now, the syntax for running a parallel analysis is explicit, meaning the
+``serial`` version is used unless another backend is requested --
+parallelization is never enabled implicitly. Although we expect parallelization
+to be useful in most cases, there are some known caveats from the initial
+benchmarks.
+
+Fast ``_single_frame`` compared to reading from disk
+----------------------------------------------------
+
+Parallelization does not help when frames are processed faster than they can be
+read from disk -- in that case, reading is the bottleneck. Hence, you will
+benefit from parallelization only if there is relatively much computation per
+frame, or if you have a fast drive, as illustrated below:
+
+.. figure:: /images/parallelization_time.png
+
+In other words, if you have a *fast* analysis (say,
+:class:`MDAnalysis.analysis.rms.RMSD`) **and** a slow HDD, you are unlikely to
+get any benefit from parallelization. Otherwise, you should be fine.
+
+Serialization issues
+--------------------
+
+For built-in analysis classes, the default serialization with both
+:mod:`multiprocessing` and :mod:`dask` is known to work. If you're using a
+custom analysis class that e.g. stores a non-serializable object in one of its
+attributes, you might get a serialization error (:exc:`PicklingError` if you're
+using the ``multiprocessing`` backend). To get around that, we suggest trying
+``backend='dask'`` (it uses the ``dask`` serialization engine instead of
+:mod:`pickle`).
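+
+Switching backends only requires changing the keyword arguments of :meth:`run`
+(a sketch; ``my_run`` stands for an analysis instance as above and ``dask`` is
+assumed to be installed):
+
+.. code-block:: python
+
+    # fails with PicklingError for a hypothetical analysis instance holding a
+    # non-picklable attribute:
+    # my_run.run(backend='multiprocessing', n_workers=4)
+
+    # dask serialization handles a wider range of objects
+    my_run.run(backend='dask', n_workers=4)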
+
+Out of memory issues
+--------------------
+
+If each worker has a large memory footprint, you can run into out-of-memory
+errors (i.e. your machine freezes when executing a run). In this case we
+suggest decreasing the number of workers from all available CPUs (which you can
+get with :func:`multiprocessing.cpu_count`) to a smaller number.
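+
+For example (an illustration only; ``my_run`` stands for an analysis instance,
+and the right number of workers depends on your system and trajectory):
+
+.. code-block:: python
+
+    import multiprocessing
+
+    # use only half of the available CPUs to reduce the total memory footprint
+    n_workers = max(1, multiprocessing.cpu_count() // 2)
+    my_run.run(backend='multiprocessing', n_workers=n_workers)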
+
+Progress bar is missing
+-----------------------
+
+It is not yet possible to get a progress bar running with any parallel backend.
+If you want an ETA for your analysis, we suggest running it in ``serial`` mode
+for the first 10-100 frames with ``verbose=True``, and then running it with
+multiple workers. Processing time scales almost linearly, so you can estimate
+the parallel ETA by dividing the ``serial`` ETA by the number of workers.
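+
+A sketch of this workflow (``my_run`` stands for an analysis instance, and the
+frame counts and worker numbers are arbitrary):
+
+.. code-block:: python
+
+    # short serial run with a progress bar to estimate the time per frame
+    my_run.run(stop=100, verbose=True)
+
+    # full parallel run; expect roughly (serial ETA) / n_workers
+    my_run.run(backend='multiprocessing', n_workers=8)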
+
+
+.. _adding-parallelization:
+
+Adding parallelization to your own analysis class
+=================================================
+
+If you want to add parallelization to your own analysis class, first make sure
+your algorithm allows you to do that, i.e. that each frame can be processed
+independently. Then it's rather simple -- let's look at the actual code that
+added parallelization to :class:`MDAnalysis.analysis.rms.RMSD`:
+
+.. code-block:: python
+
+ from MDAnalysis.analysis.base import AnalysisBase
+ from MDAnalysis.analysis.results import ResultsGroup
+
+    class RMSD(AnalysisBase):
+ @classmethod
+ def get_supported_backends(cls):
+ return ('serial', 'multiprocessing', 'dask',)
+
+ _analysis_algorithm_is_parallelizable = True
+
+ def _get_aggregator(self):
+ return ResultsGroup(lookup={'rmsd': ResultsGroup.ndarray_vstack})
+
+
+That's it! The first two additions are boilerplate --
+:meth:`get_supported_backends` returns a tuple with the built-in backends that
+will work for your class (if there are no serialization issues, it should be
+all three), and ``_analysis_algorithm_is_parallelizable`` is set to ``True``
+(it is ``False`` in ``AnalysisBase``, hence we have to re-define it). The
+:meth:`_get_aggregator` method is used as described earlier. Note that
+:mod:`MDAnalysis.analysis.results` also provides a few convenient functions
+(defined as class methods of
+:class:`~MDAnalysis.analysis.results.ResultsGroup`) for results aggregation:
+
+#. :meth:`~MDAnalysis.analysis.results.ResultsGroup.flatten_sequence`
+#. :meth:`~MDAnalysis.analysis.results.ResultsGroup.ndarray_sum`
+#. :meth:`~MDAnalysis.analysis.results.ResultsGroup.ndarray_mean`
+#. :meth:`~MDAnalysis.analysis.results.ResultsGroup.float_mean`
+#. :meth:`~MDAnalysis.analysis.results.ResultsGroup.ndarray_hstack`
+#. :meth:`~MDAnalysis.analysis.results.ResultsGroup.ndarray_vstack`
+
+
+So you'll likely find appropriate functions for basic aggregation there.
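+
+These aggregators can also be called directly on a list of per-worker values,
+which is handy for checking what a given lookup entry will do (the values below
+are made up):
+
+.. code-block:: python
+
+    import numpy as np
+    from MDAnalysis.analysis.results import ResultsGroup
+
+    ResultsGroup.ndarray_vstack([np.zeros((2, 3)), np.ones((3, 3))]).shape  # (5, 3)
+    ResultsGroup.float_mean([1.0, 3.0])  # 2.0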
+
+Writing custom backends
+=======================
+
+In order to write your custom backend (e.g. using :mod:`dask.distributed`),
+inherit from :class:`MDAnalysis.analysis.backends.BackendBase` and
+(re-)implement the :meth:`__init__` and :meth:`apply` methods. Optionally, you
+can implement methods for validating correct backend initialization --
+:meth:`_get_checks` and :meth:`_get_warnings`, which raise an exception or
+issue a warning, respectively, when a new class instance is created:
+
+#. :meth:`MDAnalysis.analysis.backends.BackendBase._get_checks`
+#. :meth:`MDAnalysis.analysis.backends.BackendBase._get_warnings`
+
+.. code-block:: python
+
+ from MDAnalysis.analysis.backends import BackendBase
+ class ThreadsBackend(BackendBase):
+ def __init__(self, n_workers: int, starting_message: str = "Useless backend"):
+ self.n_workers = n_workers
+ self.starting_message = starting_message
+ self._validate()
+
+ def _get_warnings(self):
+ return {True: 'warning: this backend is useless'}
+
+        def _get_checks(self):
+            return {isinstance(self.n_workers, int): 'error: self.n_workers is not an integer'}
+
+ def apply(self, func, computations):
+ from multiprocessing.dummy import Pool
+
+ with Pool(processes=self.n_workers) as pool:
+ print(self.starting_message)
+ results = pool.map(func, computations)
+ return results
+
+
+In order to use a custom backend with an analysis class that does not
+explicitly support it, you must *explicitly acknowledge* that you're using an
+unsupported backend by passing the keyword argument
+``unsupported_backend=True``:
+
+.. code-block:: python
+
+ from MDAnalysis.analysis.rms import RMSD
+ R = RMSD(...) # setup the run
+ n_workers = 2
+ backend = ThreadsBackend(n_workers=n_workers)
+ R.run(backend=backend, unsupported_backend=True)
+
+In this way, you will override the check for supported backends.
+
+.. Warning::
+
+ When you use ``unsupported_backend=True`` you should make sure that you get
+ the same results as when using a supported backend for which the analysis
+ class was tested.
+
+ Before reporting a problem with an analysis class, make sure you tested it
+ with a supported backend. When reporting *always mention if you used*
+ ``unsupported_backend=True``.
+
+
+.. rubric:: References
+.. footbibliography::
+
diff --git a/package/doc/sphinx/source/documentation_pages/analysis/results.rst b/package/doc/sphinx/source/documentation_pages/analysis/results.rst
new file mode 100644
index 00000000000..22fc9fd81f3
--- /dev/null
+++ b/package/doc/sphinx/source/documentation_pages/analysis/results.rst
@@ -0,0 +1,4 @@
+.. automodule:: MDAnalysis.analysis.results
+ :members:
+
+
diff --git a/package/doc/sphinx/source/documentation_pages/analysis_modules.rst b/package/doc/sphinx/source/documentation_pages/analysis_modules.rst
index ee7f6568758..ec484b0382d 100644
--- a/package/doc/sphinx/source/documentation_pages/analysis_modules.rst
+++ b/package/doc/sphinx/source/documentation_pages/analysis_modules.rst
@@ -1,23 +1,42 @@
.. Contains the formatted docstrings from the analysis modules located in 'mdanalysis/MDAnalysis/analysis', although in some cases the documentation imports functions and docstrings from other files which are also curated to reStructuredText markup.
+.. module:: MDAnalysis.analysis
+
****************
Analysis modules
****************
The :mod:`MDAnalysis.analysis` module contains code to carry out specific
-analysis functionality for MD trajectories.
-It is based on the core functionality (i.e. trajectory
-I/O, selections etc). The analysis modules can be used as examples for how to
-use MDAnalysis but also as working code for research projects; typically all
-contributed code has been used by the authors in their own work.
-An analysis using the available modules
-usually follows the same structure
+analysis functionality for MD trajectories. It is based on the core
+functionality (i.e. trajectory I/O, selections etc). The analysis modules can
+be used as examples for how to use MDAnalysis but also as working code for
+research projects; typically all contributed code has been used by the authors
+in their own work.
+
+Getting started with analysis
+=============================
+
+.. SeeAlso::
+
+ The `User Guide: Analysis`_ contains extensive documentation of the analysis
+ capabilities with user-friendly examples.
+
+.. _`User Guide: Analysis`:
+ https://userguide.mdanalysis.org/stable/examples/analysis/README.html
+
+Using the analysis classes
+--------------------------
+
+Most analysis tools in MDAnalysis are written as a single class. An analysis
+usually follows the same pattern:
#. Import the desired module, since analysis modules are not imported
by default.
#. Initialize the analysis class instance from the previously imported module.
-#. Run the analysis, optionally for specific trajectory slices.
-#. Access the analysis from the :attr:`results` attribute
+#. Run the analysis with the :meth:`~MDAnalysis.analysis.base.AnalysisBase.run`
+ method, optionally for specific trajectory slices.
+#. Access the analysis from the
+ :attr:`~MDAnalysis.analysis.base.AnalysisBase.results` attribute
.. code-block:: python
@@ -31,28 +50,81 @@ usually follows the same structure
Please see the individual module documentation for any specific caveats
and also read and cite the reference papers associated with these algorithms.
-.. rubric:: Additional dependencies
+
+Using parallelization for built-in analysis runs
+------------------------------------------------
+
+.. versionadded:: 2.8.0
+
+:class:`~MDAnalysis.analysis.base.AnalysisBase` subclasses can run on a backend
+that supports parallelization (see :mod:`MDAnalysis.analysis.backends`). All
+analysis runs use ``backend='serial'`` by default, i.e. they do not use
+parallelization, which was the only available mode before release 2.8.0 of
+MDAnalysis.
+
+Without any optional dependencies, only one parallel backend is supported: the
+built-in :mod:`multiprocessing`, which processes parts of a trajectory in
+separate *processes*, i.e. properly utilizing multi-core CPUs.
+
+.. Note::
+
+ For now, parallelization has only been added to
+   :class:`MDAnalysis.analysis.rms.RMSD`, but by release 3.0 it will be
+ introduced to all subclasses that can support it.
+
+In order to use this feature, simply add ``backend='multiprocessing'`` to your
+run, and supply an appropriate ``n_workers`` (use ``multiprocessing.cpu_count()``
+for the maximum available on your machine):
+
+.. code-block:: python
+
+ import multiprocessing
+ import MDAnalysis as mda
+ from MDAnalysisTests.datafiles import PSF, DCD
+ from MDAnalysis.analysis.rms import RMSD
+ from MDAnalysis.analysis.align import AverageStructure
+
+ # initialize the universe
+ u = mda.Universe(PSF, DCD)
+
+ # calculate average structure for reference
+ avg = AverageStructure(mobile=u).run()
+ ref = avg.results.universe
+
+ # initialize RMSD run
+ rmsd = RMSD(u, ref, select='backbone')
+ rmsd.run(backend='multiprocessing', n_workers=multiprocessing.cpu_count())
+
+For now, you have to be explicit and specify both ``backend`` and ``n_workers``,
+since the feature is new and there are no good defaults for it. For example,
+if you specify too large an ``n_workers`` and your trajectory frames are big,
+you might get an out-of-memory error when executing your run.
+
+You can also implement your own backends -- see :mod:`MDAnalysis.analysis.backends`.
+
+
+Additional dependencies
+-----------------------
Some of the modules in :mod:`MDAnalysis.analysis` require additional Python
packages to enable full functionality. For example,
:mod:`MDAnalysis.analysis.encore` provides more options if `scikit-learn`_ is
-installed. If you installed MDAnalysis with
-:program:`pip` (see :ref:`installation-instructions`)
-these packages are *not automatically installed*.
-Although, one can add the ``[analysis]`` tag to the
-:program:`pip` command to force their installation. If you installed
-MDAnalysis with :program:`conda` then a
-*full set of dependencies* is automatically installed.
+installed. If you installed MDAnalysis with :program:`pip` (see
+:ref:`installation-instructions`) these packages are *not automatically
+installed* although one can add the ``[analysis]`` tag to the :program:`pip`
+command to force their installation. If you installed MDAnalysis with
+:program:`conda` then a *full set of dependencies* is automatically installed.
Other modules require external programs. For instance, the
-:mod:`MDAnalysis.analysis.hole2.hole` module requires an installation of the
-HOLE_ suite of programs. You will need to install these external dependencies
-by following their installation instructions before you can use the
-corresponding MDAnalysis module.
+:mod:`MDAnalysis.analysis.hole2` module requires an installation of the HOLE_
+suite of programs. You will need to install these external dependencies by
+following their installation instructions before you can use the corresponding
+MDAnalysis module.
.. _scikit-learn: http://scikit-learn.org/
.. _HOLE: http://www.holeprogram.org/
+
Building blocks for Analysis
============================
@@ -64,6 +136,9 @@ To build your own analysis class start by reading the documentation.
:maxdepth: 1
analysis/base
+ analysis/backends
+ analysis/results
+ analysis/parallelization
Distances and contacts
======================
diff --git a/package/doc/sphinx/source/images/AnalysisBase_parallel.png b/package/doc/sphinx/source/images/AnalysisBase_parallel.png
new file mode 100644
index 00000000000..960e7cc257b
Binary files /dev/null and b/package/doc/sphinx/source/images/AnalysisBase_parallel.png differ
diff --git a/package/doc/sphinx/source/images/parallelization_time.png b/package/doc/sphinx/source/images/parallelization_time.png
new file mode 100644
index 00000000000..09e584555ac
Binary files /dev/null and b/package/doc/sphinx/source/images/parallelization_time.png differ
diff --git a/package/doc/sphinx/source/references.bib b/package/doc/sphinx/source/references.bib
index 5e2167c9867..8fe33a1b64d 100644
--- a/package/doc/sphinx/source/references.bib
+++ b/package/doc/sphinx/source/references.bib
@@ -795,3 +795,15 @@ @article{Linke2018
pages = {5630--5639},
doi = {10.1021/acs.jpcb.7b11988}
}
+
+@inproceedings{Fan2019,
+ title = {{PMDA} - {P}arallel {M}olecular {D}ynamics {A}nalysis},
+ author = {Shujie Fan and Max Linke and Ioannis Paraskevakos and Richard J. Gowers and Michael Gecht and Oliver Beckstein},
+ year = {2019},
+ booktitle = {{P}roceedings of the 18th {P}ython in {S}cience {C}onference},
+ editor = {{C}hris {C}alloway and {D}avid {L}ippa and {D}illon {N}iederhut and {D}avid {S}hupe},
+ organization = {SciPy},
+ address = {Austin, TX},
+ pages = {134 - 142},
+ doi = {10.25080/Majora-7ddc1dd1-013}
+}
diff --git a/package/pyproject.toml b/package/pyproject.toml
index 1880d88fd0d..05b0eb3ba16 100644
--- a/package/pyproject.toml
+++ b/package/pyproject.toml
@@ -94,6 +94,9 @@ doc = [
"pybtex",
"pybtex-docutils",
]
+parallel = [
+ "dask",
+]
[project.urls]
Homepage = "https://www.mdanalysis.org"
diff --git a/testsuite/MDAnalysisTests/analysis/conftest.py b/testsuite/MDAnalysisTests/analysis/conftest.py
new file mode 100644
index 00000000000..55bae7e6bd8
--- /dev/null
+++ b/testsuite/MDAnalysisTests/analysis/conftest.py
@@ -0,0 +1,89 @@
+import pytest
+
+from MDAnalysis.analysis.base import AnalysisBase, AnalysisFromFunction
+from MDAnalysisTests.analysis.test_base import (
+ FrameAnalysis,
+ IncompleteAnalysis,
+ OldAPIAnalysis,
+)
+from MDAnalysis.analysis.rms import RMSD, RMSF
+from MDAnalysis.lib.util import is_installed
+
+
+def params_for_cls(cls, exclude: list[str] = None):
+ """
+ This part contains fixtures for simultaneous testing
+ of all available (=installed & supported) backends
+ for analysis subclasses.
+
+ If for some reason you want to limit backends,
+ simply pass "exclude: list[str]" to the function
+ that parametrizes fixture.
+
+ Parameters
+ ----------
+ exclude : list[str], optional
+ list of backends to exclude from parametrization, by default None
+
+ Returns
+ -------
+ dict
+ dictionary with all tested keyword combinations for the run
+ """
+ exclude = [] if exclude is None else exclude
+ possible_backends = cls.get_supported_backends()
+ installed_backends = [
+ b for b in possible_backends if is_installed(b) and b not in exclude
+ ]
+
+ params = [
+ pytest.param({
+ "backend": backend,
+ "n_workers": nproc
+ }, ) for backend in installed_backends for nproc in (2, )
+ if backend != "serial"
+ ]
+ params.extend([{"backend": "serial"}])
+ return params
+
+
+@pytest.fixture(scope='module', params=params_for_cls(FrameAnalysis))
+def client_FrameAnalysis(request):
+ return request.param
+
+
+@pytest.fixture(scope='module', params=params_for_cls(AnalysisBase))
+def client_AnalysisBase(request):
+ return request.param
+
+
+@pytest.fixture(scope='module', params=params_for_cls(AnalysisFromFunction))
+def client_AnalysisFromFunction(request):
+ return request.param
+
+
+@pytest.fixture(scope='module',
+ params=params_for_cls(AnalysisFromFunction,
+ exclude=['multiprocessing']))
+def client_AnalysisFromFunctionAnalysisClass(request):
+ return request.param
+
+
+@pytest.fixture(scope='module', params=params_for_cls(IncompleteAnalysis))
+def client_IncompleteAnalysis(request):
+ return request.param
+
+
+@pytest.fixture(scope='module', params=params_for_cls(OldAPIAnalysis))
+def client_OldAPIAnalysis(request):
+ return request.param
+
+
+@pytest.fixture(scope='module', params=params_for_cls(RMSD))
+def client_RMSD(request):
+ return request.param
+
+
+@pytest.fixture(scope='module', params=params_for_cls(RMSF))
+def client_RMSF(request):
+ return request.param
diff --git a/testsuite/MDAnalysisTests/analysis/test_backends.py b/testsuite/MDAnalysisTests/analysis/test_backends.py
new file mode 100644
index 00000000000..a4c105e082a
--- /dev/null
+++ b/testsuite/MDAnalysisTests/analysis/test_backends.py
@@ -0,0 +1,72 @@
+import pytest
+from MDAnalysis.analysis import backends
+from MDAnalysis.lib.util import is_installed
+
+
+def square(x: int):
+ return x**2
+
+
+def noop(x):
+ return x
+
+
+def upper(s):
+ return s.upper()
+
+
+class Test_Backends:
+
+ @pytest.mark.parametrize(
+ "backend_cls,n_workers",
+ [
+ (backends.BackendBase, -1),
+ (backends.BackendSerial, None),
+ (backends.BackendMultiprocessing, "string"),
+ (backends.BackendDask, ()),
+ ],
+ )
+ def test_fails_incorrect_n_workers(self, backend_cls, n_workers):
+ with pytest.raises(ValueError):
+ _ = backend_cls(n_workers=n_workers)
+
+ @pytest.mark.parametrize(
+ "func,iterable,answer",
+ [
+ (square, (1, 2, 3), [1, 4, 9]),
+ (square, (), []),
+ (noop, list(range(10)), list(range(10))),
+ (upper, "asdf", list("ASDF")),
+ ],
+ )
+ def test_all_backends_give_correct_results(self, func, iterable, answer):
+ backend_instances = [
+ backends.BackendMultiprocessing(n_workers=2),
+ backends.BackendSerial(n_workers=1),
+ ]
+ if is_installed("dask"):
+ backend_instances.append(backends.BackendDask(n_workers=2))
+
+ backends_dict = {b: b.apply(func, iterable) for b in backend_instances}
+ for answ in backends_dict.values():
+ assert answ == answer
+
+ @pytest.mark.parametrize("backend_cls,params,warning_message", [
+ (backends.BackendSerial, {
+ 'n_workers': 5
+ }, "n_workers is ignored when executing with backend='serial'"),
+ ])
+ def test_get_warnings(self, backend_cls, params, warning_message):
+ with pytest.warns(UserWarning, match=warning_message):
+ backend_cls(**params)
+
+ @pytest.mark.parametrize("backend_cls,params,error_message", [
+ pytest.param(backends.BackendDask, {'n_workers': 2},
+ ("module 'dask' is missing. Please install 'dask': "
+ "https://docs.dask.org/en/stable/install.html"),
+ marks=pytest.mark.skipif(is_installed('dask'),
+ reason='dask is installed'))
+ ])
+ def test_get_errors(self, backend_cls, params, error_message):
+ with pytest.raises(ValueError, match=error_message):
+ backend_cls(**params)
diff --git a/testsuite/MDAnalysisTests/analysis/test_base.py b/testsuite/MDAnalysisTests/analysis/test_base.py
index 84cd4c5b9c2..6e6b7921960 100644
--- a/testsuite/MDAnalysisTests/analysis/test_base.py
+++ b/testsuite/MDAnalysisTests/analysis/test_base.py
@@ -30,134 +30,37 @@
from numpy.testing import assert_equal, assert_allclose
import MDAnalysis as mda
-from MDAnalysis.analysis import base
-
-from MDAnalysisTests.datafiles import PSF, DCD, TPR, XTC
+import numpy as np
+import pytest
+from MDAnalysis.analysis import base, backends
+from MDAnalysisTests.datafiles import DCD, PSF, TPR, XTC
from MDAnalysisTests.util import no_deprecated_call
-
-
-class Test_Results:
-
- @pytest.fixture
- def results(self):
- return base.Results(a=1, b=2)
-
- def test_get(self, results):
- assert results.a == results["a"] == 1
-
- def test_no_attr(self, results):
- msg = "'Results' object has no attribute 'c'"
- with pytest.raises(AttributeError, match=msg):
- results.c
-
- def test_set_attr(self, results):
- value = [1, 2, 3, 4]
- results.c = value
- assert results.c is results["c"] is value
-
- def test_set_key(self, results):
- value = [1, 2, 3, 4]
- results["c"] = value
- assert results.c is results["c"] is value
-
- @pytest.mark.parametrize('key', dir(UserDict) + ["data"])
- def test_existing_dict_attr(self, results, key):
- msg = f"'{key}' is a protected dictionary attribute"
- with pytest.raises(AttributeError, match=msg):
- results[key] = None
-
- @pytest.mark.parametrize('key', dir(UserDict) + ["data"])
- def test_wrong_init_type(self, key):
- msg = f"'{key}' is a protected dictionary attribute"
- with pytest.raises(AttributeError, match=msg):
- base.Results(**{key: None})
-
- @pytest.mark.parametrize('key', ("0123", "0j", "1.1", "{}", "a b"))
- def test_weird_key(self, results, key):
- msg = f"'{key}' is not a valid attribute"
- with pytest.raises(ValueError, match=msg):
- results[key] = None
-
- def test_setattr_modify_item(self, results):
- mylist = [1, 2]
- mylist2 = [3, 4]
- results.myattr = mylist
- assert results.myattr is mylist
- results["myattr"] = mylist2
- assert results.myattr is mylist2
- mylist2.pop(0)
- assert len(results.myattr) == 1
- assert results.myattr is mylist2
-
- def test_setitem_modify_item(self, results):
- mylist = [1, 2]
- mylist2 = [3, 4]
- results["myattr"] = mylist
- assert results.myattr is mylist
- results.myattr = mylist2
- assert results.myattr is mylist2
- mylist2.pop(0)
- assert len(results["myattr"]) == 1
- assert results["myattr"] is mylist2
-
- def test_delattr(self, results):
- assert hasattr(results, "a")
- delattr(results, "a")
- assert not hasattr(results, "a")
-
- def test_missing_delattr(self, results):
- assert not hasattr(results, "d")
- msg = "'Results' object has no attribute 'd'"
- with pytest.raises(AttributeError, match=msg):
- delattr(results, "d")
-
- def test_pop(self, results):
- assert hasattr(results, "a")
- results.pop("a")
- assert not hasattr(results, "a")
-
- def test_update(self, results):
- assert not hasattr(results, "spudda")
- results.update({"spudda": "fett"})
- assert results.spudda == "fett"
-
- def test_update_data_fail(self, results):
- msg = f"'data' is a protected dictionary attribute"
- with pytest.raises(AttributeError, match=msg):
- results.update({"data": 0})
-
- def test_pickle(self, results):
- results_p = pickle.dumps(results)
- results_new = pickle.loads(results_p)
-
- @pytest.mark.parametrize("args, kwargs, length", [
- (({"darth": "tater"},), {}, 1),
- ([], {"darth": "tater"}, 1),
- (({"darth": "tater"},), {"yam": "solo"}, 2),
- (({"darth": "tater"},), {"darth": "vader"}, 1),
- ])
- def test_initialize_arguments(self, args, kwargs, length):
- results = base.Results(*args, **kwargs)
- ref = dict(*args, **kwargs)
- assert ref == results
- assert len(results) == length
-
- def test_different_instances(self, results):
- new_results = base.Results(darth="tater")
- assert new_results.data is not results.data
+from numpy.testing import assert_almost_equal, assert_equal
class FrameAnalysis(base.AnalysisBase):
"""Just grabs frame numbers of frames it goes over"""
+ @classmethod
+ def get_supported_backends(cls): return ('serial', 'dask', 'multiprocessing')
+
+ _analysis_algorithm_is_parallelizable = True
+
def __init__(self, reader, **kwargs):
super(FrameAnalysis, self).__init__(reader, **kwargs)
self.traj = reader
- self.found_frames = []
+
+ def _prepare(self):
+ self.results.found_frames = []
def _single_frame(self):
- self.found_frames.append(self._ts.frame)
+ self.results.found_frames.append(self._ts.frame)
+
+ def _conclude(self):
+ self.found_frames = list(self.results.found_frames)
+ def _get_aggregator(self):
+ return base.ResultsGroup({'found_frames': base.ResultsGroup.ndarray_hstack})
class IncompleteAnalysis(base.AnalysisBase):
def __init__(self, reader, **kwargs):
@@ -173,6 +76,9 @@ def __init__(self, reader, **kwargs):
def _single_frame(self):
pass
+ def _prepare(self):
+ self.results = base.Results()
+
@pytest.fixture(scope='module')
def u():
@@ -187,6 +93,110 @@ def u_xtc():
FRAMES_ERR = 'AnalysisBase.frames is incorrect'
TIMES_ERR = 'AnalysisBase.times is incorrect'
+class Parallelizable(base.AnalysisBase):
+ _analysis_algorithm_is_parallelizable = True
+ @classmethod
+ def get_supported_backends(cls): return ('multiprocessing', 'dask')
+ def _single_frame(self): pass
+
+class SerialOnly(base.AnalysisBase):
+ def _single_frame(self): pass
+
+class ParallelizableWithDaskOnly(base.AnalysisBase):
+ _analysis_algorithm_is_parallelizable = True
+ @classmethod
+ def get_supported_backends(cls): return ('dask',)
+ def _single_frame(self): pass
+
+class CustomSerialBackend(backends.BackendBase):
+ def apply(self, func, computations):
+ return [func(task) for task in computations]
+
+class ManyWorkersBackend(backends.BackendBase):
+ def apply(self, func, computations):
+ return [func(task) for task in computations]
+
+def test_incompatible_n_workers(u):
+ backend = ManyWorkersBackend(n_workers=2)
+ with pytest.raises(ValueError):
+ FrameAnalysis(u).run(backend=backend, n_workers=3)
+
+@pytest.mark.parametrize('run_class,backend,n_workers', [
+ (Parallelizable, 'not-existing-backend', 2),
+ (Parallelizable, 'not-existing-backend', None),
+ (SerialOnly, 'not-existing-backend', 2),
+ (SerialOnly, 'not-existing-backend', None),
+ (SerialOnly, 'multiprocessing', 2),
+ (SerialOnly, 'dask', None),
+ (ParallelizableWithDaskOnly, 'multiprocessing', None),
+ (ParallelizableWithDaskOnly, 'multiprocessing', 2),
+])
+def test_backend_configuration_fails(u, run_class, backend, n_workers):
+ u = mda.Universe(TPR, XTC) # dt = 100
+ with pytest.raises(ValueError):
+ _ = run_class(u.trajectory).run(backend=backend, n_workers=n_workers, stop=0)
+
+@pytest.mark.parametrize('run_class,backend,n_workers', [
+ (Parallelizable, CustomSerialBackend, 2),
+ (ParallelizableWithDaskOnly, CustomSerialBackend, 2),
+])
+def test_backend_configuration_works_when_unsupported_backend(u, run_class, backend, n_workers):
+ u = mda.Universe(TPR, XTC) # dt = 100
+ backend_instance = backend(n_workers=n_workers)
+ _ = run_class(u.trajectory).run(backend=backend_instance, n_workers=n_workers, stop=0, unsupported_backend=True)
+
+@pytest.mark.parametrize('run_class,backend,n_workers', [
+ (Parallelizable, CustomSerialBackend, 1),
+ (ParallelizableWithDaskOnly, CustomSerialBackend, 1),
+])
+def test_custom_backend_works(u, run_class, backend, n_workers):
+ backend_instance = backend(n_workers=n_workers)
+ u = mda.Universe(TPR, XTC) # dt = 100
+ _ = run_class(u.trajectory).run(backend=backend_instance, n_workers=n_workers, unsupported_backend=True)
+
+@pytest.mark.parametrize('run_class,backend_instance,n_workers', [
+ (Parallelizable, map, 1),
+ (SerialOnly, list, 1),
+ (ParallelizableWithDaskOnly, object, 1),
+])
+def test_fails_incorrect_custom_backend(u, run_class, backend_instance, n_workers):
+ u = mda.Universe(TPR, XTC) # dt = 100
+ with pytest.raises(ValueError):
+ _ = run_class(u.trajectory).run(backend=backend_instance, n_workers=n_workers, unsupported_backend=True)
+
+ with pytest.raises(ValueError):
+ _ = run_class(u.trajectory).run(backend=backend_instance, n_workers=n_workers)
+
+@pytest.mark.parametrize('run_class,backend,n_workers', [
+ (SerialOnly, CustomSerialBackend, 1),
+ (SerialOnly, 'multiprocessing', 1),
+ (SerialOnly, 'dask', 1),
+])
+def test_fails_for_unparallelizable(u, run_class, backend, n_workers):
+ u = mda.Universe(TPR, XTC) # dt = 100
+ with pytest.raises(ValueError):
+ if not isinstance(backend, str):
+ backend_instance = backend(n_workers=n_workers)
+ _ = run_class(u.trajectory).run(backend=backend_instance, n_workers=n_workers, unsupported_backend=True)
+ else:
+ _ = run_class(u.trajectory).run(backend=backend, n_workers=n_workers, unsupported_backend=True)
+
+@pytest.mark.parametrize('run_kwargs,frames', [
+ ({}, np.arange(98)),
+ ({'start': 20}, np.arange(20, 98)),
+ ({'stop': 30}, np.arange(30)),
+ ({'step': 10}, np.arange(0, 98, 10))
+])
+def test_start_stop_step_parallel(u, run_kwargs, frames, client_FrameAnalysis):
+ # client_FrameAnalysis is defined [here](testsuite/MDAnalysisTests/analysis/conftest.py),
+ # and determines a set of parameters ('backend', 'n_workers'), taking only backends
+ # that are implemented for a given subclass, to run the test against.
+ an = FrameAnalysis(u.trajectory).run(**run_kwargs, **client_FrameAnalysis)
+ assert an.n_frames == len(frames)
+ assert_equal(an.found_frames, frames)
+ assert_equal(an.frames, frames, err_msg=FRAMES_ERR)
+ assert_almost_equal(an.times, frames+1, decimal=4, err_msg=TIMES_ERR)
+
@pytest.mark.parametrize('run_kwargs,frames', [
({}, np.arange(98)),
@@ -217,6 +227,22 @@ def test_frame_slice(u_xtc, run_kwargs, frames):
assert_equal(an.frames, frames, err_msg=FRAMES_ERR)
+@pytest.mark.parametrize('run_kwargs, frames', [
+ ({'frames': [4, 5, 6, 7, 8, 9]}, np.arange(4, 10)),
+ ({'frames': [0, 2, 4, 6, 8]}, np.arange(0, 10, 2)),
+ ({'frames': [4, 6, 8]}, np.arange(4, 10, 2)),
+ ({'frames': [0, 3, 4, 3, 5]}, [0, 3, 4, 3, 5]),
+ ({'frames': [True, True, False, True, False, True, True, False, True,
+ False]}, (0, 1, 3, 5, 6, 8)),
+])
+def test_frame_slice_parallel(run_kwargs, frames, client_FrameAnalysis):
+ u = mda.Universe(TPR, XTC) # dt = 100
+ an = FrameAnalysis(u.trajectory).run(**run_kwargs, **client_FrameAnalysis)
+ assert an.n_frames == len(frames)
+ assert_equal(an.found_frames, frames)
+ assert_equal(an.frames, frames, err_msg=FRAMES_ERR)
+
+
@pytest.mark.parametrize('run_kwargs', [
({'start': 4, 'frames': [4, 5, 6, 7, 8, 9]}),
({'stop': 6, 'frames': [0, 1, 2, 3, 4, 5]}),
@@ -226,28 +252,44 @@ def test_frame_slice(u_xtc, run_kwargs, frames):
({'start': 4, 'step': 2, 'frames': [4, 6, 8]}),
({'start': 0, 'stop': 0, 'step': 0, 'frames': [4, 6, 8]}),
])
-def test_frame_fail(u, run_kwargs):
+def test_frame_fail(u, run_kwargs, client_FrameAnalysis):
an = FrameAnalysis(u.trajectory)
msg = 'start/stop/step cannot be combined with frames'
with pytest.raises(ValueError, match=msg):
- an.run(**run_kwargs)
+ an.run(**client_FrameAnalysis, **run_kwargs)
+
+def test_parallelizable_transformations():
+ # pick any transformation that would allow
+ # for parallelizable attribute
+ from MDAnalysis.transformations import NoJump
+ u = mda.Universe(XTC)
+ u.trajectory.add_transformations(NoJump())
+ # test that serial works
+ FrameAnalysis(u.trajectory).run()
+
+ # test that parallel fails
+ with pytest.raises(ValueError):
+ FrameAnalysis(u.trajectory).run(backend='multiprocessing')
-def test_frame_bool_fail(u_xtc):
- an = FrameAnalysis(u_xtc.trajectory)
+def test_frame_bool_fail(client_FrameAnalysis):
+ u = mda.Universe(TPR, XTC) # dt = 100
+ an = FrameAnalysis(u.trajectory)
frames = [True, True, False]
msg = 'boolean index did not match indexed array along (axis|dimension) 0'
with pytest.raises(IndexError, match=msg):
- an.run(frames=frames)
+ an.run(**client_FrameAnalysis, frames=frames)
-def test_rewind(u_xtc):
- FrameAnalysis(u_xtc.trajectory).run(frames=[0, 2, 3, 5, 9])
- assert_equal(u_xtc.trajectory.ts.frame, 0)
+def test_rewind(client_FrameAnalysis):
+ u = mda.Universe(TPR, XTC) # dt = 100
+ an = FrameAnalysis(u.trajectory).run(**client_FrameAnalysis, frames=[0, 2, 3, 5, 9])
+ assert_equal(u.trajectory.ts.frame, 0)
-def test_frames_times(u_xtc):
- an = FrameAnalysis(u_xtc.trajectory).run(start=1, stop=8, step=2)
+def test_frames_times(client_FrameAnalysis):
+ u = mda.Universe(TPR, XTC) # dt = 100
+ an = FrameAnalysis(u.trajectory).run(start=1, stop=8, step=2, **client_FrameAnalysis)
frames = np.array([1, 3, 5, 7])
assert an.n_frames == len(frames)
assert_equal(an.found_frames, frames)
@@ -260,6 +302,24 @@ def test_verbose(u):
assert a._verbose
+def test_warn_nparts_nworkers(u):
+ a = FrameAnalysis(u.trajectory)
+ with pytest.warns(UserWarning):
+ a.run(backend='multiprocessing', n_workers=3, n_parts=2)
+
+
+@pytest.mark.parametrize(
+ "classname,is_parallelizable",
+ [
+ (base.AnalysisBase, False),
+ (base.AnalysisFromFunction, True),
+ (FrameAnalysis, True)
+ ]
+)
+def test_not_parallelizable(u, classname, is_parallelizable):
+ assert classname._analysis_algorithm_is_parallelizable == is_parallelizable
+
+
def test_verbose_progressbar(u, capsys):
FrameAnalysis(u.trajectory).run()
_, err = capsys.readouterr()
@@ -283,6 +343,12 @@ def test_verbose_progressbar_run_with_kwargs(u, capsys):
actual = err.strip().split('\r')[-1]
assert actual[:30] == expected[:30]
+
+def test_progressbar_multiprocessing(u):
+ with pytest.raises(ValueError):
+ FrameAnalysis(u.trajectory).run(backend='multiprocessing', verbose=True)
+
+
def test_incomplete_defined_analysis(u):
with pytest.raises(NotImplementedError):
IncompleteAnalysis(u.trajectory).run()
@@ -332,15 +398,18 @@ def test_results_type(u):
(20, 50, 2, 15),
(20, 50, None, 30)
])
-def test_AnalysisFromFunction(u, start, stop, step, nframes):
+def test_AnalysisFromFunction(u, start, stop, step, nframes, client_AnalysisFromFunction):
+ # client_AnalysisFromFunction is defined [here](testsuite/MDAnalysisTests/analysis/conftest.py),
+ # and determines a set of parameters ('backend', 'n_workers'), taking only backends
+ # that are implemented for a given subclass, to run the test against.
ana1 = base.AnalysisFromFunction(simple_function, mobile=u.atoms)
- ana1.run(start=start, stop=stop, step=step)
+ ana1.run(start=start, stop=stop, step=step, **client_AnalysisFromFunction)
ana2 = base.AnalysisFromFunction(simple_function, u.atoms)
- ana2.run(start=start, stop=stop, step=step)
+ ana2.run(start=start, stop=stop, step=step, **client_AnalysisFromFunction)
ana3 = base.AnalysisFromFunction(simple_function, u.trajectory, u.atoms)
- ana3.run(start=start, stop=stop, step=step)
+ ana3.run(start=start, stop=stop, step=step, **client_AnalysisFromFunction)
frames = []
times = []
@@ -367,26 +436,27 @@ def mass_xyz(atomgroup1, atomgroup2, masses):
return atomgroup1.positions * masses
-def test_AnalysisFromFunction_args_content(u):
+def test_AnalysisFromFunction_args_content(u, client_AnalysisFromFunction):
protein = u.select_atoms('protein')
masses = protein.masses.reshape(-1, 1)
another = mda.Universe(TPR, XTC).select_atoms("protein")
ans = base.AnalysisFromFunction(mass_xyz, protein, another, masses)
assert len(ans.args) == 3
- result = np.sum(ans.run().results.timeseries)
+ result = np.sum(ans.run(**client_AnalysisFromFunction).results.timeseries)
assert_allclose(result, -317054.67757345125, rtol=0, atol=1.5e-6)
+ assert_almost_equal(result, -317054.67757345125, decimal=6)
assert (ans.args[0] is protein) and (ans.args[1] is another)
assert ans._trajectory is protein.universe.trajectory
-def test_analysis_class():
+def test_analysis_class(client_AnalysisFromFunctionAnalysisClass):
ana_class = base.analysis_class(simple_function)
assert issubclass(ana_class, base.AnalysisBase)
assert issubclass(ana_class, base.AnalysisFromFunction)
u = mda.Universe(PSF, DCD)
step = 2
- ana = ana_class(u.atoms).run(step=step)
+ ana = ana_class(u.atoms).run(step=step, **client_AnalysisFromFunctionAnalysisClass)
results = []
for ts in u.trajectory[::step]:
diff --git a/testsuite/MDAnalysisTests/analysis/test_pca.py b/testsuite/MDAnalysisTests/analysis/test_pca.py
index c72971c30ed..ec874b900fe 100644
--- a/testsuite/MDAnalysisTests/analysis/test_pca.py
+++ b/testsuite/MDAnalysisTests/analysis/test_pca.py
@@ -133,6 +133,12 @@ def test_no_frames(u):
PCA(u, select=SELECTION).run(stop=1)
+def test_can_run_frames(u):
+ atoms = u.select_atoms(SELECTION)
+ u.transfer_to_memory()
+ PCA(u, select=SELECTION).run(frames=[0,1])
+
+
def test_can_run_frames(u):
atoms = u.select_atoms(SELECTION)
u.transfer_to_memory()
diff --git a/testsuite/MDAnalysisTests/analysis/test_results.py b/testsuite/MDAnalysisTests/analysis/test_results.py
new file mode 100644
index 00000000000..97d299de101
--- /dev/null
+++ b/testsuite/MDAnalysisTests/analysis/test_results.py
@@ -0,0 +1,173 @@
+import pickle
+from collections import UserDict
+
+import numpy as np
+import pytest
+from MDAnalysis.analysis import results as results_module
+from numpy.testing import assert_equal
+
+
+class Test_Results:
+ @pytest.fixture
+ def results(self):
+ return results_module.Results(a=1, b=2)
+
+ def test_get(self, results):
+ assert results.a == results["a"] == 1
+
+ def test_no_attr(self, results):
+ msg = "'Results' object has no attribute 'c'"
+ with pytest.raises(AttributeError, match=msg):
+ results.c
+
+ def test_set_attr(self, results):
+ value = [1, 2, 3, 4]
+ results.c = value
+ assert results.c is results["c"] is value
+
+ def test_set_key(self, results):
+ value = [1, 2, 3, 4]
+ results["c"] = value
+ assert results.c is results["c"] is value
+
+ @pytest.mark.parametrize("key", dir(UserDict) + ["data"])
+ def test_existing_dict_attr(self, results, key):
+ msg = f"'{key}' is a protected dictionary attribute"
+ with pytest.raises(AttributeError, match=msg):
+ results[key] = None
+
+ @pytest.mark.parametrize("key", dir(UserDict) + ["data"])
+ def test_wrong_init_type(self, key):
+ msg = f"'{key}' is a protected dictionary attribute"
+ with pytest.raises(AttributeError, match=msg):
+ results_module.Results(**{key: None})
+
+ @pytest.mark.parametrize("key", ("0123", "0j", "1.1", "{}", "a b"))
+ def test_weird_key(self, results, key):
+ msg = f"'{key}' is not a valid attribute"
+ with pytest.raises(ValueError, match=msg):
+ results[key] = None
+
+ def test_setattr_modify_item(self, results):
+ mylist = [1, 2]
+ mylist2 = [3, 4]
+ results.myattr = mylist
+ assert results.myattr is mylist
+ results["myattr"] = mylist2
+ assert results.myattr is mylist2
+ mylist2.pop(0)
+ assert len(results.myattr) == 1
+ assert results.myattr is mylist2
+
+ def test_setitem_modify_item(self, results):
+ mylist = [1, 2]
+ mylist2 = [3, 4]
+ results["myattr"] = mylist
+ assert results.myattr is mylist
+ results.myattr = mylist2
+ assert results.myattr is mylist2
+ mylist2.pop(0)
+ assert len(results["myattr"]) == 1
+ assert results["myattr"] is mylist2
+
+ def test_delattr(self, results):
+ assert hasattr(results, "a")
+ delattr(results, "a")
+ assert not hasattr(results, "a")
+
+ def test_missing_delattr(self, results):
+ assert not hasattr(results, "d")
+ msg = "'Results' object has no attribute 'd'"
+ with pytest.raises(AttributeError, match=msg):
+ delattr(results, "d")
+
+ def test_pop(self, results):
+ assert hasattr(results, "a")
+ results.pop("a")
+ assert not hasattr(results, "a")
+
+ def test_update(self, results):
+ assert not hasattr(results, "spudda")
+ results.update({"spudda": "fett"})
+ assert results.spudda == "fett"
+
+ def test_update_data_fail(self, results):
+ msg = f"'data' is a protected dictionary attribute"
+ with pytest.raises(AttributeError, match=msg):
+ results.update({"data": 0})
+
+ def test_pickle(self, results):
+ results_p = pickle.dumps(results)
+ results_new = pickle.loads(results_p)
+
+ @pytest.mark.parametrize(
+ "args, kwargs, length",
+ [
+ (({"darth": "tater"},), {}, 1),
+ ([], {"darth": "tater"}, 1),
+ (({"darth": "tater"},), {"yam": "solo"}, 2),
+ (({"darth": "tater"},), {"darth": "vader"}, 1),
+ ],
+ )
+ def test_initialize_arguments(self, args, kwargs, length):
+ results = results_module.Results(*args, **kwargs)
+ ref = dict(*args, **kwargs)
+ assert ref == results
+ assert len(results) == length
+
+ def test_different_instances(self, results):
+ new_results = results_module.Results(darth="tater")
+ assert new_results.data is not results.data
+
+
+class Test_ResultsGroup:
+ @pytest.fixture
+ def results_0(self):
+ return results_module.Results(
+ sequence=[0],
+ ndarray_mean=np.array([0, 0, 0]),
+ ndarray_sum=np.array([0, 0, 0, 0]),
+ float=0.0,
+ float_sum=0.0,
+ )
+
+ @pytest.fixture
+ def results_1(self):
+ return results_module.Results(
+ sequence=[1],
+ ndarray_mean=np.array([1, 1, 1]),
+ ndarray_sum=np.array([1, 1, 1, 1]),
+ float=1.0,
+ float_sum=1.0,
+ )
+
+ @pytest.fixture
+ def merger(self):
+ RG = results_module.ResultsGroup
+ lookup = {
+ "sequence": RG.flatten_sequence,
+ "ndarray_mean": RG.ndarray_mean,
+ "ndarray_sum": RG.ndarray_sum,
+ "float": RG.float_mean,
+ "float_sum": lambda floats: sum(floats),
+ }
+ return RG(lookup=lookup)
+
+ @pytest.mark.parametrize("n", [1, 2, 5, 14])
+ def test_all_results(self, results_0, results_1, merger, n):
+ from itertools import cycle
+
+ objects = [obj for obj, _ in zip(cycle([results_0, results_1]), range(n))]
+
+ arr = [i for _, i in zip(range(n), cycle([0, 1]))]
+ answers = {
+ "sequence": arr,
+ "ndarray_mean": [np.mean(arr) for _ in range(3)],
+ "ndarray_sum": [np.sum(arr) for _ in range(4)],
+ "float": np.mean(arr),
+ "float_sum": np.sum(arr),
+ }
+
+ results = merger.merge(objects)
+ for attr, merged_value in results.items():
+ assert_equal(merged_value, answers.get(attr), err_msg=f"{attr=}, {merged_value=}, {arr=}, {objects=}")
diff --git a/testsuite/MDAnalysisTests/analysis/test_rms.py b/testsuite/MDAnalysisTests/analysis/test_rms.py
index 56d2e459e94..d42993feb46 100644
--- a/testsuite/MDAnalysisTests/analysis/test_rms.py
+++ b/testsuite/MDAnalysisTests/analysis/test_rms.py
@@ -186,95 +186,104 @@ def correct_values_backbone_group(self):
return [[0, 1, 0, 0, 0],
[49, 50, 4.6997, 1.9154, 2.7139]]
- def test_rmsd(self, universe, correct_values):
+ def test_rmsd(self, universe, correct_values, client_RMSD):
+ # client_RMSD is defined in testsuite/analysis/conftest.py
+ # among with other testing fixtures. During testing, it will
+ # collect all possible backends and reasonable number of workers
+ # for a given AnalysisBase subclass, and extend the tests
+ # to run with all of them.
RMSD = MDAnalysis.analysis.rms.RMSD(universe, select='name CA')
- RMSD.run(step=49)
+ RMSD.run(step=49, **client_RMSD)
assert_almost_equal(RMSD.results.rmsd, correct_values, 4,
err_msg="error: rmsd profile should match" +
"test values")
- def test_rmsd_frames(self, universe, correct_values):
+ def test_rmsd_frames(self, universe, correct_values, client_RMSD):
RMSD = MDAnalysis.analysis.rms.RMSD(universe, select='name CA')
- RMSD.run(frames=[0, 49])
+ RMSD.run(frames=[0, 49], **client_RMSD)
assert_almost_equal(RMSD.results.rmsd, correct_values, 4,
err_msg="error: rmsd profile should match" +
"test values")
- def test_rmsd_unicode_selection(self, universe, correct_values):
+ def test_rmsd_unicode_selection(self, universe, correct_values, client_RMSD):
RMSD = MDAnalysis.analysis.rms.RMSD(universe, select=u'name CA')
- RMSD.run(step=49)
+ RMSD.run(step=49, **client_RMSD)
assert_almost_equal(RMSD.results.rmsd, correct_values, 4,
err_msg="error: rmsd profile should match" +
"test values")
- def test_rmsd_atomgroup_selections(self, universe):
+ def test_rmsd_atomgroup_selections(self, universe, client_RMSD):
# see Issue #1684
R1 = MDAnalysis.analysis.rms.RMSD(universe.atoms,
- select="resid 1-30").run()
+ select="resid 1-30").run(**client_RMSD)
R2 = MDAnalysis.analysis.rms.RMSD(universe.atoms.select_atoms("name CA"),
- select="resid 1-30").run()
+ select="resid 1-30").run(**client_RMSD)
assert not np.allclose(R1.results.rmsd[:, 2], R2.results.rmsd[:, 2])
- def test_rmsd_single_frame(self, universe):
+ def test_rmsd_single_frame(self, universe, client_RMSD):
RMSD = MDAnalysis.analysis.rms.RMSD(universe, select='name CA',
- ).run(start=5, stop=6)
+ ).run(start=5, stop=6, **client_RMSD)
single_frame = [[5, 6, 0.91544906]]
assert_almost_equal(RMSD.results.rmsd, single_frame, 4,
err_msg="error: rmsd profile should match" +
"test values")
- def test_mass_weighted(self, universe, correct_values):
+ def test_mass_weighted(self, universe, correct_values, client_RMSD):
# mass weighting the CA should give the same answer as weighing
# equally because all CA have the same mass
RMSD = MDAnalysis.analysis.rms.RMSD(universe, select='name CA',
- weights='mass').run(step=49)
+ weights='mass').run(step=49, **client_RMSD)
assert_almost_equal(RMSD.results.rmsd, correct_values, 4,
err_msg="error: rmsd profile should match"
"test values")
- def test_custom_weighted(self, universe, correct_values_mass):
- RMSD = MDAnalysis.analysis.rms.RMSD(universe, weights="mass").run(step=49)
+ def test_custom_weighted(self, universe, correct_values_mass, client_RMSD):
+ RMSD = MDAnalysis.analysis.rms.RMSD(universe, weights="mass").run(step=49, **client_RMSD)
assert_almost_equal(RMSD.results.rmsd, correct_values_mass, 4,
err_msg="error: rmsd profile should match"
"test values")
- def test_weights_mass_is_mass_weighted(self, universe):
+ def test_weights_mass_is_mass_weighted(self, universe, client_RMSD):
RMSD_mass = MDAnalysis.analysis.rms.RMSD(universe,
- weights="mass").run(step=49)
+ weights="mass").run(step=49, **client_RMSD)
RMSD_cust = MDAnalysis.analysis.rms.RMSD(universe,
- weights=universe.atoms.masses).run(step=49)
+ weights=universe.atoms.masses).run(step=49, **client_RMSD)
assert_almost_equal(RMSD_mass.results.rmsd, RMSD_cust.results.rmsd, 4,
err_msg="error: rmsd profiles should match for 'mass' "
"and universe.atoms.masses")
- def test_custom_weighted_list(self, universe, correct_values_mass):
+ def test_custom_weighted_list(self, universe, correct_values_mass, client_RMSD):
weights = universe.atoms.masses
RMSD = MDAnalysis.analysis.rms.RMSD(universe,
- weights=list(weights)).run(step=49)
+ weights=list(weights)).run(step=49, **client_RMSD)
assert_almost_equal(RMSD.results.rmsd, correct_values_mass, 4,
err_msg="error: rmsd profile should match" +
"test values")
- def test_custom_groupselection_weights_applied_1D_array(self, universe):
+ def test_custom_groupselection_weights_applied_1D_array(self, universe, client_RMSD):
RMSD = MDAnalysis.analysis.rms.RMSD(universe,
select='backbone',
groupselections=['name CA and resid 1-5', 'name CA and resid 1'],
weights=None,
- weights_groupselections=[[1, 0, 0, 0, 0], None]).run(step=49)
+ weights_groupselections=[[1, 0, 0, 0, 0], None]).run(step=49,
+ **client_RMSD
+ )
assert_almost_equal(RMSD.results.rmsd.T[3], RMSD.results.rmsd.T[4], 4,
err_msg="error: rmsd profile should match "
"for applied weight array and selected resid")
- def test_custom_groupselection_weights_applied_mass(self, universe, correct_values_mass):
+ def test_custom_groupselection_weights_applied_mass(self, universe, correct_values_mass, client_RMSD):
RMSD = MDAnalysis.analysis.rms.RMSD(universe,
select='backbone',
groupselections=['all', 'all'],
weights=None,
weights_groupselections=['mass',
- universe.atoms.masses]).run(step=49)
+ universe.atoms.masses]).run(step=49,
+ **client_RMSD
+ )
assert_almost_equal(RMSD.results.rmsd.T[3], RMSD.results.rmsd.T[4], 4,
err_msg="error: rmsd profile should match "
@@ -315,22 +324,23 @@ def test_rmsd_list_of_weights_wrong_length(self, universe):
weights='mass',
weights_groupselections=[None])
- def test_rmsd_group_selections(self, universe, correct_values_group):
+ def test_rmsd_group_selections(self, universe, correct_values_group, client_RMSD):
RMSD = MDAnalysis.analysis.rms.RMSD(universe,
groupselections=['backbone', 'name CA']
- ).run(step=49)
+ ).run(step=49, **client_RMSD)
assert_almost_equal(RMSD.results.rmsd, correct_values_group, 4,
err_msg="error: rmsd profile should match"
"test values")
def test_rmsd_backbone_and_group_selection(self, universe,
- correct_values_backbone_group):
+ correct_values_backbone_group,
+ client_RMSD):
RMSD = MDAnalysis.analysis.rms.RMSD(
universe,
reference=universe,
select="backbone",
groupselections=['backbone and resid 1:10',
- 'backbone and resid 10:20']).run(step=49)
+ 'backbone and resid 10:20']).run(step=49, **client_RMSD)
assert_almost_equal(
RMSD.results.rmsd, correct_values_backbone_group, 4,
err_msg="error: rmsd profile should match test values")
@@ -349,7 +359,7 @@ def test_mass_mismatches(self, universe):
RMSD = MDAnalysis.analysis.rms.RMSD(universe,
reference=reference)
- def test_ref_mobile_mass_mismapped(self, universe,correct_values_mass_add_ten):
+ def test_ref_mobile_mass_mismapped(self, universe,correct_values_mass_add_ten, client_RMSD):
reference = MDAnalysis.Universe(PSF, DCD)
universe.atoms.masses = universe.atoms.masses + 10
RMSD = MDAnalysis.analysis.rms.RMSD(universe,
@@ -357,7 +367,7 @@ def test_ref_mobile_mass_mismapped(self, universe,correct_values_mass_add_ten):
select='all',
weights='mass',
tol_mass=100)
- RMSD.run(step=49)
+ RMSD.run(step=49, **client_RMSD)
assert_almost_equal(RMSD.results.rmsd, correct_values_mass_add_ten, 4,
err_msg="error: rmsd profile should match "
"between true values and calculated values")
@@ -370,9 +380,9 @@ def test_group_selections_unequal_len(self, universe):
reference=reference,
groupselections=['resname MET', 'type NH3'])
- def test_rmsd_attr_warning(self, universe):
+ def test_rmsd_attr_warning(self, universe, client_RMSD):
RMSD = MDAnalysis.analysis.rms.RMSD(
- universe, select='name CA').run(stop=2)
+ universe, select='name CA').run(stop=2, **client_RMSD)
wmsg = "The `rmsd` attribute was deprecated in MDAnalysis 2.0.0"
with pytest.warns(DeprecationWarning, match=wmsg):
@@ -384,22 +394,22 @@ class TestRMSF(object):
def universe(self):
return mda.Universe(GRO, XTC)
- def test_rmsf(self, universe):
+ def test_rmsf(self, universe, client_RMSF):
rmsfs = rms.RMSF(universe.select_atoms('name CA'))
- rmsfs.run()
+ rmsfs.run(**client_RMSF)
test_rmsfs = np.load(rmsfArray)
assert_almost_equal(rmsfs.results.rmsf, test_rmsfs, 5,
err_msg="error: rmsf profile should match test "
"values")
- def test_rmsf_single_frame(self, universe):
- rmsfs = rms.RMSF(universe.select_atoms('name CA')).run(start=5, stop=6)
+ def test_rmsf_single_frame(self, universe, client_RMSF):
+ rmsfs = rms.RMSF(universe.select_atoms('name CA')).run(start=5, stop=6, **client_RMSF)
assert_almost_equal(rmsfs.results.rmsf, 0, 5,
err_msg="error: rmsfs should all be zero")
- def test_rmsf_identical_frames(self, universe, tmpdir):
+ def test_rmsf_identical_frames(self, universe, tmpdir, client_RMSF):
outfile = os.path.join(str(tmpdir), 'rmsf.xtc')
@@ -410,13 +420,35 @@ def test_rmsf_identical_frames(self, universe, tmpdir):
universe = mda.Universe(GRO, outfile)
rmsfs = rms.RMSF(universe.select_atoms('name CA'))
- rmsfs.run()
+ rmsfs.run(**client_RMSF)
assert_almost_equal(rmsfs.results.rmsf, 0, 5,
err_msg="error: rmsfs should all be 0")
- def test_rmsf_attr_warning(self, universe):
- rmsfs = rms.RMSF(universe.select_atoms('name CA')).run(stop=2)
+ def test_rmsf_attr_warning(self, universe, client_RMSF):
+ rmsfs = rms.RMSF(universe.select_atoms('name CA')).run(stop=2, **client_RMSF)
wmsg = "The `rmsf` attribute was deprecated in MDAnalysis 2.0.0"
with pytest.warns(DeprecationWarning, match=wmsg):
assert_equal(rmsfs.rmsf, rmsfs.results.rmsf)
+
+
+@pytest.mark.parametrize(
+ "classname,is_parallelizable",
+ [
+ (MDAnalysis.analysis.rms.RMSD, True),
+ (MDAnalysis.analysis.rms.RMSF, False),
+ ]
+)
+def test_not_parallelizable(classname, is_parallelizable):
+ assert classname._analysis_algorithm_is_parallelizable == is_parallelizable
+
+
+@pytest.mark.parametrize(
+ "classname,backends",
+ [
+ (MDAnalysis.analysis.rms.RMSD, ('serial', 'multiprocessing', 'dask',)),
+ (MDAnalysis.analysis.rms.RMSF, ('serial',)),
+ ]
+)
+def test_supported_backends(classname, backends):
+ assert classname.get_supported_backends() == backends
diff --git a/testsuite/MDAnalysisTests/lib/test_util.py b/testsuite/MDAnalysisTests/lib/test_util.py
index 4ff1832b7de..cd641133586 100644
--- a/testsuite/MDAnalysisTests/lib/test_util.py
+++ b/testsuite/MDAnalysisTests/lib/test_util.py
@@ -2232,3 +2232,11 @@ def test_which():
with pytest.warns(DeprecationWarning, match=wmsg):
assert util.which('python') == shutil.which('python')
+
+
+@pytest.mark.parametrize(
+ "modulename,is_installed",
+ [("math", True), ("sys", True), ("some_weird_module_name", False)],
+)
+def test_is_installed(modulename, is_installed):
+ assert util.is_installed(modulename) == is_installed