diff --git a/.github/actions/setup-deps/action.yaml b/.github/actions/setup-deps/action.yaml index e5a0794450c..cceae40f99c 100644 --- a/.github/actions/setup-deps/action.yaml +++ b/.github/actions/setup-deps/action.yaml @@ -62,6 +62,8 @@ inputs: default: 'chemfiles-python>=0.9' clustalw: default: 'clustalw=2.1' + dask: + default: 'dask' distopia: default: 'distopia>=0.2.0' h5py: @@ -134,6 +136,7 @@ runs: ${{ inputs.biopython }} ${{ inputs.chemfiles-python }} ${{ inputs.clustalw }} + ${{ inputs.dask }} ${{ inputs.distopia }} ${{ inputs.gsd }} ${{ inputs.h5py }} diff --git a/package/CHANGELOG b/package/CHANGELOG index 750ddcf1cba..056282c6214 100644 --- a/package/CHANGELOG +++ b/package/CHANGELOG @@ -52,6 +52,8 @@ Fixes * Fix groups.py doctests using sphinx directives (Issue #3925, PR #4374) Enhancements + * Introduce parallelization API to `AnalysisBase` and to `analysis.rms.RMSD` class + (Issue #4158, PR #4304) * Improve error message for `AtomGroup.unwrap()` when bonds are not present.(Issue #4436, PR #4642) * Add `analysis.DSSP` module for protein secondary structure assignment, based on [pydssp](https://github.com/ShintaroMinami/PyDSSP) * Added a tqdm progress bar for `MDAnalysis.analysis.pca.PCA.transform()` diff --git a/package/MDAnalysis/analysis/__init__.py b/package/MDAnalysis/analysis/__init__.py index 5983bb8b61c..25450b759b9 100644 --- a/package/MDAnalysis/analysis/__init__.py +++ b/package/MDAnalysis/analysis/__init__.py @@ -23,6 +23,7 @@ __all__ = [ 'align', + 'backends', 'base', 'contacts', 'density', @@ -45,6 +46,7 @@ 'pca', 'psa', 'rdf', + 'results', 'rms', 'waterdynamics', ] diff --git a/package/MDAnalysis/analysis/backends.py b/package/MDAnalysis/analysis/backends.py new file mode 100644 index 00000000000..38917cb2ae7 --- /dev/null +++ b/package/MDAnalysis/analysis/backends.py @@ -0,0 +1,333 @@ +"""Analysis backends --- :mod:`MDAnalysis.analysis.backends` +============================================================ + +.. versionadded:: 2.8.0 + + +The :mod:`backends` module provides :class:`BackendBase` base class to +implement custom execution backends for +:meth:`MDAnalysis.analysis.base.AnalysisBase.run` and its +subclasses. + +.. SeeAlso:: :ref:`parallel-analysis` + +.. _backends: + +Backends +-------- + +Three built-in backend classes are provided: + +* *serial*: :class:`BackendSerial`, that is equivalent to using no + parallelization and is the default + +* *multiprocessing*: :class:`BackendMultiprocessing` that supports + parallelization via standard Python :mod:`multiprocessing` module + and uses default :mod:`pickle` serialization + +* *dask*: :class:`BackendDask`, that uses the same process-based + parallelization as :class:`BackendMultiprocessing`, but different + serialization algorithm via `dask `_ (see `dask + serialization algorithms + `_ for details) + +Classes +------- + +""" +import warnings +from typing import Callable +from MDAnalysis.lib.util import is_installed + + +class BackendBase: + """Base class for backend implementation. + + Initializes an instance and performs checks for its validity, such as + ``n_workers`` and possibly other ones. + + Parameters + ---------- + n_workers : int + number of workers (usually, processes) over which the work is split + + Examples + -------- + .. 
code-block:: python + + from MDAnalysis.analysis.backends import BackendBase + + class ThreadsBackend(BackendBase): + def apply(self, func, computations): + from multiprocessing.dummy import Pool + + with Pool(processes=self.n_workers) as pool: + results = pool.map(func, computations) + return results + + import MDAnalysis as mda + from MDAnalysis.tests.datafiles import PSF, DCD + from MDAnalysis.analysis.rms import RMSD + + u = mda.Universe(PSF, DCD) + ref = mda.Universe(PSF, DCD) + + R = RMSD(u, ref) + + n_workers = 2 + backend = ThreadsBackend(n_workers=n_workers) + R.run(backend=backend, unsupported_backend=True) + + .. warning:: + Using `ThreadsBackend` above will lead to erroneous results, since it + is an educational example. Do not use it for real analysis. + + + .. versionadded:: 2.8.0 + + """ + + def __init__(self, n_workers: int): + self.n_workers = n_workers + self._validate() + + def _get_checks(self): + """Get dictionary with ``condition: error_message`` pairs that ensure the + validity of the backend instance + + Returns + ------- + dict + dictionary with ``condition: error_message`` pairs that will get + checked during ``_validate()`` run + """ + return { + isinstance(self.n_workers, int) and self.n_workers > 0: + f"n_workers should be positive integer, got {self.n_workers=}", + } + + def _get_warnings(self): + """Get dictionary with ``condition: warning_message`` pairs that ensure + the good usage of the backend instance + + Returns + ------- + dict + dictionary with ``condition: warning_message`` pairs that will get + checked during ``_validate()`` run + """ + return dict() + + def _validate(self): + """Check correctness (e.g. ``dask`` is installed if using ``backend='dask'``) + and good usage (e.g. ``n_workers=1`` if backend is serial) of the backend + + Raises + ------ + ValueError + if one of the conditions in :meth:`_get_checks` is ``True`` + """ + for check, msg in self._get_checks().items(): + if not check: + raise ValueError(msg) + for check, msg in self._get_warnings().items(): + if not check: + warnings.warn(msg) + + def apply(self, func: Callable, computations: list) -> list: + """map function `func` to all tasks in the `computations` list + + Main method that will get called when using an instance of + ``BackendBase``. It is equivalent to running ``[func(item) for item in + computations]`` while using the parallel backend capabilities. + + Parameters + ---------- + func : Callable + function to be called on each of the tasks in computations list + computations : list + computation tasks to apply function to + + Returns + ------- + list + list of results of the function + + """ + raise NotImplementedError + + +class BackendSerial(BackendBase): + """A built-in backend that does serial execution of the function, without any + parallelization. + + Parameters + ---------- + n_workers : int + Is ignored in this class, and if ``n_workers`` > 1, a warning will be + given. + + + .. versionadded:: 2.8.0 + """ + + def _get_warnings(self): + """Get dictionary with ``condition: warning_message`` pairs that ensure + the good usage of the backend instance. Here, it checks if the number + of workers is not 1, otherwise gives warning. 
+ + Returns + ------- + dict + dictionary with ``condition: warning_message`` pairs that will get + checked during ``_validate()`` run + """ + return { + self.n_workers == 1: + "n_workers is ignored when executing with backend='serial'" + } + + def apply(self, func: Callable, computations: list) -> list: + """ + Serially applies `func` to each task object in ``computations``. + + Parameters + ---------- + func : Callable + function to be called on each of the tasks in computations list + computations : list + computation tasks to apply function to + + Returns + ------- + list + list of results of the function + """ + return [func(task) for task in computations] + + +class BackendMultiprocessing(BackendBase): + """A built-in backend that executes a given function using the + :meth:`multiprocessing.Pool.map ` method. + + Parameters + ---------- + n_workers : int + number of processes in :class:`multiprocessing.Pool + ` to distribute the workload + between. Must be a positive integer. + + Examples + -------- + + .. code-block:: python + + from MDAnalysis.analysis.backends import BackendMultiprocessing + import multiprocessing as mp + + backend_obj = BackendMultiprocessing(n_workers=mp.cpu_count()) + + + .. versionadded:: 2.8.0 + + """ + + def apply(self, func: Callable, computations: list) -> list: + """Applies `func` to each object in ``computations`` using `multiprocessing`'s `Pool.map`. + + Parameters + ---------- + func : Callable + function to be called on each of the tasks in computations list + computations : list + computation tasks to apply function to + + Returns + ------- + list + list of results of the function + """ + from multiprocessing import Pool + + with Pool(processes=self.n_workers) as pool: + results = pool.map(func, computations) + return results + + +class BackendDask(BackendBase): + """A built-in backend that executes a given function with *dask*. + + Execution is performed with the :func:`dask.compute` function of + :class:`dask.delayed.Delayed` object (created with + :func:`dask.delayed.delayed`) with ``scheduler='processes'`` and + ``chunksize=1`` (this ensures uniform distribution of tasks among + processes). Requires the `dask package `_ + to be `installed `_. + + Parameters + ---------- + n_workers : int + number of processes in to distribute the workload + between. Must be a positive integer. Workers are actually + :class:`multiprocessing.pool.Pool` processes, but they use a different and + more flexible `serialization protocol + `_. + + Examples + -------- + + .. code-block:: python + + from MDAnalysis.analysis.backends import BackendDask + import multiprocessing as mp + + backend_obj = BackendDask(n_workers=mp.cpu_count()) + + + .. versionadded:: 2.8.0 + + """ + + def apply(self, func: Callable, computations: list) -> list: + """Applies `func` to each object in ``computations``. + + Parameters + ---------- + func : Callable + function to be called on each of the tasks in computations list + computations : list + computation tasks to apply function to + + Returns + ------- + list + list of results of the function + """ + from dask.delayed import delayed + import dask + + computations = [delayed(func)(task) for task in computations] + results = dask.compute(computations, + scheduler="processes", + chunksize=1, + num_workers=self.n_workers)[0] + return results + + def _get_checks(self): + """Get dictionary with ``condition: error_message`` pairs that ensure the + validity of the backend instance. Here checks if ``dask`` module is + installed in the environment. 
+ + Returns + ------- + dict + dictionary with ``condition: error_message`` pairs that will get + checked during ``_validate()`` run + """ + base_checks = super()._get_checks() + checks = { + is_installed("dask"): + ("module 'dask' is missing. Please install 'dask': " + "https://docs.dask.org/en/stable/install.html") + } + return base_checks | checks diff --git a/package/MDAnalysis/analysis/base.py b/package/MDAnalysis/analysis/base.py index 04a0d25b506..47a7eccd137 100644 --- a/package/MDAnalysis/analysis/base.py +++ b/package/MDAnalysis/analysis/base.py @@ -103,16 +103,42 @@ contains a step-by-step example for writing analysis tools with :class:`AnalysisBase`. - .. _`Writing your own trajectory analysis`: https://userguide.mdanalysis.org/stable/examples/analysis/custom_trajectory_analysis.html +If your analysis is operating independently on each frame, you might consider +making it **parallelizable** via adding a :meth:`get_supported_backends` method, +and appropriate aggregation function for each of its results. For example, if +you have your :meth:`_single_frame` method storing important values under +:attr:`self.results.timeseries`, you will write: + +.. code-block:: python + + class MyAnalysis(AnalysisBase): + _analysis_algorithm_is_parallelizable = True + + @classmethod + def get_supported_backends(cls): + return ('serial', 'multiprocessing', 'dask',) + + + def _get_aggregator(self): + return ResultsGroup(lookup={'timeseries': ResultsGroup.ndarray_vstack}) + +See :mod:`MDAnalysis.analysis.results` for more on aggregating results. + +.. SeeAlso:: + + :ref:`parallel-analysis` + + + Classes ------- -The :class:`Results` and :class:`AnalysisBase` classes are the essential -building blocks for almost all MDAnalysis tools in the +The :class:`MDAnalysis.results.Results` and :class:`AnalysisBase` classes +are the essential building blocks for almost all MDAnalysis tools in the :mod:`MDAnalysis.analysis` module. They aim to be easily useable and extendable. @@ -121,103 +147,22 @@ tools if only the single-frame analysis function needs to be written. """ -from collections import UserDict import inspect -import logging import itertools +import logging +import warnings +from functools import partial +from typing import Iterable, Union import numpy as np -from MDAnalysis import coordinates -from MDAnalysis.core.groups import AtomGroup -from MDAnalysis.lib.log import ProgressBar - -logger = logging.getLogger(__name__) - - -class Results(UserDict): - r"""Container object for storing results. - - :class:`Results` are dictionaries that provide two ways by which values - can be accessed: by dictionary key ``results["value_key"]`` or by object - attribute, ``results.value_key``. :class:`Results` stores all results - obtained from an analysis after calling :meth:`~AnalysisBase.run()`. - - The implementation is similar to the :class:`sklearn.utils.Bunch` - class in `scikit-learn`_. - - .. _`scikit-learn`: https://scikit-learn.org/ - - Raises - ------ - AttributeError - If an assigned attribute has the same name as a default attribute. - - ValueError - If a key is not of type ``str`` and therefore is not able to be - accessed by attribute. - - Examples - -------- - >>> from MDAnalysis.analysis.base import Results - >>> results = Results(a=1, b=2) - >>> results['b'] - 2 - >>> results.b - 2 - >>> results.a = 3 - >>> results['a'] - 3 - >>> results.c = [1, 2, 3, 4] - >>> results['c'] - [1, 2, 3, 4] - - - .. 
versionadded:: 2.0.0 - """ - - def _validate_key(self, key): - if key in dir(self): - raise AttributeError(f"'{key}' is a protected dictionary " - "attribute") - elif isinstance(key, str) and not key.isidentifier(): - raise ValueError(f"'{key}' is not a valid attribute") - - def __init__(self, *args, **kwargs): - kwargs = dict(*args, **kwargs) - if "data" in kwargs.keys(): - raise AttributeError(f"'data' is a protected dictionary attribute") - self.__dict__["data"] = {} - self.update(kwargs) - - def __setitem__(self, key, item): - self._validate_key(key) - super().__setitem__(key, item) - - def __setattr__(self, attr, val): - if attr == 'data': - super().__setattr__(attr, val) - else: - self.__setitem__(attr, val) - - def __getattr__(self, attr): - try: - return self[attr] - except KeyError as err: - raise AttributeError("'Results' object has no " - f"attribute '{attr}'") from err - - def __delattr__(self, attr): - try: - del self[attr] - except KeyError as err: - raise AttributeError("'Results' object has no " - f"attribute '{attr}'") from err +from .. import coordinates +from ..core.groups import AtomGroup +from ..lib.log import ProgressBar - def __getstate__(self): - return self.data +from .backends import BackendDask, BackendMultiprocessing, BackendSerial, BackendBase +from .results import Results, ResultsGroup - def __setstate__(self, state): - self.data = state +logger = logging.getLogger(__name__) class AnalysisBase(object): @@ -231,8 +176,8 @@ class AnalysisBase(object): To define a new Analysis, :class:`AnalysisBase` needs to be subclassed and :meth:`_single_frame` must be defined. It is also possible to define :meth:`_prepare` and :meth:`_conclude` for pre- and post-processing. - All results should be stored as attributes of the :class:`Results` - container. + All results should be stored as attributes of the + :class:`MDAnalysis.analysis.results.Results` container. Parameters ---------- @@ -310,15 +255,140 @@ def _conclude(self): .. versionchanged:: 2.0.0 Added :attr:`results` + .. versionchanged:: 2.8.0 + Added ability to run analysis in parallel using either a + built-in backend (`multiprocessing` or `dask`) or a custom + `backends.BackendBase` instance with an implemented `apply` method + that is used to run the computations. """ + @classmethod + def get_supported_backends(cls): + """Tuple with backends supported by the core library for a given class. + User can pass either one of these values as ``backend=...`` to + :meth:`run()` method, or a custom object that has ``apply`` method + (see documentation for :meth:`run()`): + + - 'serial': no parallelization + - 'multiprocessing': parallelization using `multiprocessing.Pool` + - 'dask': parallelization using `dask.delayed.compute()`. Requires + installation of `mdanalysis[dask]` + + If you want to add your own backend to an existing class, pass a + :class:`backends.BackendBase` subclass (see its documentation to learn + how to implement it properly), and specify ``unsupported_backend=True``. + + Returns + ------- + tuple + names of built-in backends that can be used in :meth:`run(backend=...)` + + + .. 
versionadded:: 2.8.0 + """ + return ("serial",) + + # class authors: override _analysis_algorithm_is_parallelizable + # in derived classes and only set to True if you have confirmed + # that your algorithm works reliably when parallelized with + # the split-apply-combine approach (see docs) + _analysis_algorithm_is_parallelizable = False + + @property + def parallelizable(self): + """Boolean mark showing that a given class can be parallelizable with + split-apply-combine procedure. Namely, if we can safely distribute + :meth:`_single_frame` to multiple workers and then combine them with a + proper :meth:`_conclude` call. If set to ``False``, no backends except + for ``serial`` are supported. + + .. note:: If you want to check parallelizability of the whole class, without + explicitly creating an instance of the class, see + :attr:`_analysis_algorithm_is_parallelizable`. Note that you + setting it to other value will break things if the algorithm + behind the analysis is not trivially parallelizable. + + + Returns + ------- + bool + if a given ``AnalysisBase`` subclass instance + is parallelizable with split-apply-combine, or not + + + .. versionadded:: 2.8.0 + """ + return self._analysis_algorithm_is_parallelizable + def __init__(self, trajectory, verbose=False, **kwargs): self._trajectory = trajectory self._verbose = verbose self.results = Results() - def _setup_frames(self, trajectory, start=None, stop=None, step=None, - frames=None): + def _define_run_frames(self, trajectory, + start=None, stop=None, step=None, frames=None + ) -> Union[slice, np.ndarray]: + """Defines limits for the whole run, as passed by self.run() arguments + + Parameters + ---------- + trajectory : mda.Reader + a trajectory Reader + start : int, optional + start frame of analysis, by default None + stop : int, optional + stop frame of analysis, by default None + step : int, optional + number of frames to skip between each analysed frame, by default None + frames : array_like, optional + array of integers or booleans to slice trajectory; cannot be + combined with ``start``, ``stop``, ``step``; by default None + + Returns + ------- + Union[slice, np.ndarray] + Appropriate slicer for the trajectory that would give correct iteraction + order via trajectory[slicer] + + Raises + ------ + ValueError + if *both* `frames` and at least one of ``start``, ``stop``, or ``step`` + is provided (i.e. set to not ``None`` value). + + + .. versionadded:: 2.8.0 + """ + self._trajectory = trajectory + if frames is not None: + if not all(opt is None for opt in [start, stop, step]): + raise ValueError("start/stop/step cannot be combined with frames") + slicer = frames + else: + start, stop, step = trajectory.check_slice_indices(start, stop, step) + slicer = slice(start, stop, step) + self.start, self.stop, self.step = start, stop, step + return slicer + + def _prepare_sliced_trajectory(self, slicer: Union[slice, np.ndarray]): + """Prepares sliced trajectory for use in subsequent parallel computations: + namely, assigns self._sliced_trajectory and its appropriate attributes, + self.n_frames, self.frames and self.times. + + Parameters + ---------- + slicer : Union[slice, np.ndarray] + appropriate slicer for the trajectory + + + .. 
versionadded:: 2.8.0 + """ + self._sliced_trajectory = self._trajectory[slicer] + self.n_frames = len(self._sliced_trajectory) + self.frames = np.zeros(self.n_frames, dtype=int) + self.times = np.zeros(self.n_frames) + + def _setup_frames(self, trajectory, start=None, stop=None, step=None, frames=None): """Pass a Reader object and define the desired iteration pattern through the trajectory @@ -334,63 +404,303 @@ def _setup_frames(self, trajectory, start=None, stop=None, step=None, number of frames to skip between each analysed frame frames : array_like, optional array of integers or booleans to slice trajectory; cannot be - combined with `start`, `stop`, `step` + combined with ``start``, ``stop``, ``step`` .. versionadded:: 2.2.0 Raises ------ ValueError - if *both* `frames` and at least one of `start`, `stop`, or `frames` - is provided (i.e., set to another value than ``None``) + if *both* `frames` and at least one of ``start``, ``stop``, or + ``frames`` is provided (i.e., set to another value than ``None``) .. versionchanged:: 1.0.0 Added .frames and .times arrays as attributes - + .. versionchanged:: 2.2.0 Added ability to iterate through trajectory by passing a list of frame indices in the `frames` keyword argument - + + .. versionchanged:: 2.8.0 + Split function into two: :meth:`_define_run_frames` and + :meth:`_prepare_sliced_trajectory`: first one defines the limits + for the whole run and is executed once during :meth:`run` in + :meth:`_setup_frames`, second one prepares sliced trajectory for + each of the workers and gets executed twice: one time in + :meth:`_setup_frames` for the whole trajectory, second time in + :meth:`_compute` for each of the computation groups. """ - self._trajectory = trajectory - if frames is not None: - if not all(opt is None for opt in [start, stop, step]): - raise ValueError("start/stop/step cannot be combined with " - "frames") - slicer = frames - else: - start, stop, step = trajectory.check_slice_indices(start, stop, - step) - slicer = slice(start, stop, step) - self._sliced_trajectory = trajectory[slicer] - self.start = start - self.stop = stop - self.step = step - self.n_frames = len(self._sliced_trajectory) - self.frames = np.zeros(self.n_frames, dtype=int) - self.times = np.zeros(self.n_frames) + slicer = self._define_run_frames(trajectory, start, stop, step, frames) + self._prepare_sliced_trajectory(slicer) def _single_frame(self): """Calculate data from a single frame of trajectory Don't worry about normalising, just deal with a single frame. + Attributes accessible during your calculations: + + - ``self._frame_index``: index of the frame in results array + - ``self._ts`` -- Timestep instance + - ``self._sliced_trajectory`` -- trajectory that you're iterating over + - ``self.results`` -- :class:`MDAnalysis.analysis.results.Results` instance + holding run results initialized in :meth:`_prepare`. """ raise NotImplementedError("Only implemented in child classes") def _prepare(self): - """Set things up before the analysis loop begins""" + """ + Set things up before the analysis loop begins. + + Notes + ----- + ``self.results`` is initialized already in :meth:`self.__init__` with an + empty instance of :class:`MDAnalysis.analysis.results.Results` object. + You can still call your attributes as if they were usual ones, + ``Results`` just keeps track of that to be able to run a proper + aggregation after a parallel run, if necessary. + """ pass # pylint: disable=unnecessary-pass def _conclude(self): """Finalize the results you've gathered. 
Called at the end of the :meth:`run` method to finish everything up. + + Notes + ----- + Aggregation of results from individual workers happens in + :meth:`self.run()`, so here you have to implement everything as if you + had a non-parallel run. If you want to enable proper aggregation for + parallel runs for you analysis class, implement ``self._get_aggregator`` + and check :mod:`MDAnalysis.analysis.results` for how to use it. """ pass # pylint: disable=unnecessary-pass - def run(self, start=None, stop=None, step=None, frames=None, - verbose=None, *, progressbar_kwargs=None): + def _compute(self, indexed_frames: np.ndarray, + verbose: bool = None, + *, progressbar_kwargs={}) -> "AnalysisBase": + """Perform the calculation on a balanced slice of frames + that have been setup prior to that using _setup_computation_groups() + + Parameters + ---------- + indexed_frames : np.ndarray + np.ndarray of (n, 2) shape, where first column is frame iteration + indices and second is frame numbers + + verbose : bool, optional + Turn on verbosity + + progressbar_kwargs : dict, optional + ProgressBar keywords with custom parameters regarding progress bar + position, etc; see :class:`MDAnalysis.lib.log.ProgressBar` + for full list. + + + .. versionadded:: 2.8.0 + """ + logger.info("Choosing frames to analyze") + # if verbose unchanged, use class default + verbose = getattr(self, "_verbose", False) if verbose is None else verbose + + frames = indexed_frames[:, 1] + + logger.info("Starting preparation") + self._prepare_sliced_trajectory(slicer=frames) + self._prepare() + if len(frames) == 0: # if `frames` were empty in `run` or `stop=0` + return self + + for idx, ts in enumerate(ProgressBar( + self._sliced_trajectory, + verbose=verbose, + **progressbar_kwargs)): + self._frame_index = idx # accessed later by subclasses + self._ts = ts + self.frames[idx] = ts.frame + self.times[idx] = ts.time + self._single_frame() + logger.info("Finishing up") + return self + + def _setup_computation_groups( + self, n_parts: int, + start: int = None, stop: int = None, step: int = None, + frames: Union[slice, np.ndarray] = None + ) -> list[np.ndarray]: + """ + Splits the trajectory frames, defined by ``start/stop/step`` or + ``frames``, into ``n_parts`` even groups, preserving their indices. + + Parameters + ---------- + n_parts : int + number of parts to split the workload into + start : int, optional + start frame + stop : int, optional + stop frame + step : int, optional + step size for analysis (1 means to read every frame) + frames : array_like, optional + array of integers or booleans to slice trajectory; ``frames`` can + only be used *instead* of ``start``, ``stop``, and ``step``. Setting + *both* ``frames`` and at least one of ``start``, ``stop``, ``step`` + to a non-default value will raise a :exc:`ValueError`. + + Raises + ------ + ValueError + if *both* ``frames`` and at least one of ``start``, ``stop``, or + ``frames`` is provided (i.e., set to another value than ``None``) + + Returns + ------- + computation_groups : list[np.ndarray] + list of (n, 2) shaped np.ndarrays with frame indices and numbers + + + .. 
versionadded:: 2.8.0 + """ + if frames is None: + start, stop, step = self._trajectory.check_slice_indices(start, stop, step) + used_frames = np.arange(start, stop, step) + elif not all(opt is None for opt in [start, stop, step]): + raise ValueError("start/stop/step cannot be combined with frames") + else: + used_frames = frames + + if all(isinstance(obj, bool) for obj in used_frames): + arange = np.arange(len(used_frames)) + used_frames = arange[used_frames] + + # similar to list(enumerate(frames)) + enumerated_frames = np.vstack([np.arange(len(used_frames)), used_frames]).T + + return np.array_split(enumerated_frames, n_parts) + + def _configure_backend( + self, + backend: Union[str, BackendBase], + n_workers: int, + unsupported_backend: bool = False + ) -> BackendBase: + """Matches a passed backend string value with class attributes + :attr:`parallelizable` and :meth:`get_supported_backends()` + to check if downstream calculations can be performed. + + Parameters + ---------- + backend : Union[str, BackendBase] + backend to be used: + - ``str`` is matched to a builtin backend (one of "serial", + "multiprocessing" and "dask") + - ``BackendBase`` subclass is checked for the presence of + an :meth:`apply` method + n_workers : int + positive integer with number of workers (processes, in case of + built-in backends) to split the work between + unsupported_backend : bool, optional + if you want to run your custom backend on a parallelizable class + that has not been tested by developers, by default ``False`` + + Returns + ------- + BackendBase + instance of a ``BackendBase`` class that will be used for computations + + Raises + ------ + ValueError + if :attr:`parallelizable` is set to ``False`` but backend is + not ``serial`` + ValueError + if :attr:`parallelizable` is ``True`` and custom backend instance is used + without specifying ``unsupported_backend=True`` + ValueError + if your trajectory has associated parallelizable transformations + but backend is not serial + ValueError + if ``n_workers`` was specified twice -- in the run() method and durin + ``__init__`` of a custom backend + ValueError + if your backend object instance doesn't have an ``apply`` method + + + .. 
versionadded:: 2.8.0 + """ + builtin_backends = { + "serial": BackendSerial, + "multiprocessing": BackendMultiprocessing, + "dask": BackendDask + } + + backend_class = builtin_backends.get(backend, backend) + supported_backend_classes = [ + builtin_backends.get(b) + for b in self.get_supported_backends() + ] + + # check for serial-only classes + if not self.parallelizable and backend_class is not BackendSerial: + raise ValueError(f"Can not parallelize class {self.__class__}") + + # make sure user enabled 'unsupported_backend=True' for custom classes + if (not unsupported_backend and self.parallelizable + and backend_class not in supported_backend_classes): + raise ValueError(( + f"Must specify 'unsupported_backend=True'" + f"if you want to use a custom {backend_class=} for {self.__class__}" + )) + + # check for the presence of parallelizable transformations + if backend_class is not BackendSerial and any( + not t.parallelizable + for t in self._trajectory.transformations): + raise ValueError(( + "Trajectory should not have " + "associated unparallelizable transformations")) + + # conclude mapping from string to backend class if it's a builtin backend + if isinstance(backend, str): + return backend_class(n_workers=n_workers) + + # make sure we haven't specified n_workers twice + if ( + isinstance(backend, BackendBase) + and n_workers is not None + and hasattr(backend, 'n_workers') + and backend.n_workers != n_workers + ): + raise ValueError(( + f"n_workers specified twice: in {backend.n_workers=}" + f"and in run({n_workers=}). Remove it from run()" + )) + + # or pass along an instance of the class itself + # after ensuring it has apply method + if not isinstance(backend, BackendBase) or not hasattr(backend, "apply"): + raise ValueError(( + f"{backend=} is invalid: should have 'apply' method " + "and be instance of MDAnalysis.analysis.backends.BackendBase" + )) + return backend + + def run( + self, + start: int = None, + stop: int = None, + step: int = None, + frames: Iterable = None, + verbose: bool = None, + n_workers: int = None, + n_parts: int = None, + backend: Union[str, BackendBase] = None, + *, + unsupported_backend: bool = False, + progressbar_kwargs=None, + ): """Perform the calculation Parameters @@ -402,20 +712,42 @@ def run(self, start=None, stop=None, step=None, frames=None, step : int, optional number of frames to skip between each analysed frame frames : array_like, optional - array of integers or booleans to slice trajectory; `frames` can - only be used *instead* of `start`, `stop`, and `step`. Setting - *both* `frames` and at least one of `start`, `stop`, `step` to a - non-default value will raise a :exc:`ValueError`. + array of integers or booleans to slice trajectory; ``frames`` can + only be used *instead* of ``start``, ``stop``, and ``step``. Setting + *both* ``frames`` and at least one of ``start``, ``stop``, ``step`` + to a non-default value will raise a :exc:`ValueError`. .. versionadded:: 2.2.0 - verbose : bool, optional Turn on verbosity progressbar_kwargs : dict, optional ProgressBar keywords with custom parameters regarding progress bar - position, etc; see :class:`MDAnalysis.lib.log.ProgressBar` for full - list. + position, etc; see :class:`MDAnalysis.lib.log.ProgressBar` + for full list. Available only for ``backend='serial'`` + backend : Union[str, BackendBase], optional + By default, performs calculations in a serial fashion. 
+ Otherwise, user can choose a backend: ``str`` is matched to a + builtin backend (one of ``serial``, ``multiprocessing`` and + ``dask``), or a :class:`MDAnalysis.analysis.results.BackendBase` + subclass. + + .. versionadded:: 2.8.0 + n_workers : int + positive integer with number of workers (processes, in case of + built-in backends) to split the work between + + .. versionadded:: 2.8.0 + n_parts : int, optional + number of parts to split computations across. Can be more than + number of workers. + + .. versionadded:: 2.8.0 + unsupported_backend : bool, optional + if you want to run your custom backend on a parallelizable class + that has not been tested by developers, by default False + + .. versionadded:: 2.8.0 .. versionchanged:: 2.2.0 @@ -425,34 +757,88 @@ def run(self, start=None, stop=None, step=None, frames=None, .. versionchanged:: 2.5.0 Add `progressbar_kwargs` parameter, allowing to modify description, position etc of tqdm progressbars - """ - logger.info("Choosing frames to analyze") - # if verbose unchanged, use class default - verbose = getattr(self, '_verbose', - False) if verbose is None else verbose - self._setup_frames(self._trajectory, start=start, stop=stop, - step=step, frames=frames) - logger.info("Starting preparation") - self._prepare() - logger.info("Starting analysis loop over %d trajectory frames", - self.n_frames) - if progressbar_kwargs is None: - progressbar_kwargs = {} + .. versionchanged:: 2.8.0 + Introduced ``backend``, ``n_workers``, ``n_parts`` and + ``unsupported_backend`` keywords, and refactored the method logic to + support parallelizable execution. + """ + # default to serial execution + backend = "serial" if backend is None else backend + + progressbar_kwargs = {} if progressbar_kwargs is None else progressbar_kwargs + if ((progressbar_kwargs or verbose) and + not (backend == "serial" or + isinstance(backend, BackendSerial))): + raise ValueError("Can not display progressbar with non-serial backend") + + # if number of workers not specified, try getting the number from + # the backend instance if possible, or set to 1 + if n_workers is None: + n_workers = ( + backend.n_workers + if isinstance(backend, BackendBase) and hasattr(backend, "n_workers") + else 1 + ) + + # set n_parts and check that is has a reasonable value + n_parts = n_workers if n_parts is None else n_parts + + # do this as early as possible to check client parameters + # before any computations occur + executor = self._configure_backend( + backend=backend, + n_workers=n_workers, + unsupported_backend=unsupported_backend) + if ( + hasattr(executor, "n_workers") and n_parts < executor.n_workers + ): # using executor's value here for non-default executors + warnings.warn(( + f"Analysis not making use of all workers: " + f"{executor.n_workers=} is greater than {n_parts=}")) + + # start preparing the run + worker_func = partial( + self._compute, + progressbar_kwargs=progressbar_kwargs, + verbose=verbose) + self._setup_frames( + trajectory=self._trajectory, + start=start, stop=stop, step=step, frames=frames) + computation_groups = self._setup_computation_groups( + start=start, stop=stop, step=step, frames=frames, n_parts=n_parts + ) + + # get all results from workers in other processes. 
+ # we need `AnalysisBase` classes + # since they hold `frames`, `times` and `results` attributes + remote_objects: list["AnalysisBase"] = executor.apply( + worker_func, computation_groups) + self.frames = np.hstack([obj.frames for obj in remote_objects]) + self.times = np.hstack([obj.times for obj in remote_objects]) + + # aggregate results from results obtained in remote workers + remote_results = [obj.results for obj in remote_objects] + results_aggregator = self._get_aggregator() + self.results = results_aggregator.merge(remote_results) - for i, ts in enumerate(ProgressBar( - self._sliced_trajectory, - verbose=verbose, - **progressbar_kwargs)): - self._frame_index = i - self._ts = ts - self.frames[i] = ts.frame - self.times[i] = ts.time - self._single_frame() - logger.info("Finishing up") self._conclude() return self + def _get_aggregator(self) -> ResultsGroup: + """Returns a default aggregator that takes entire results + if there is a single object, and raises ValueError otherwise + + Returns + ------- + ResultsGroup + aggregating object + + + .. versionadded:: 2.8.0 + """ + return ResultsGroup(lookup=None) + class AnalysisFromFunction(AnalysisBase): r"""Create an :class:`AnalysisBase` from a function working on AtomGroups @@ -503,8 +889,18 @@ def rotation_matrix(mobile, ref): .. versionchanged:: 2.0.0 Former :attr:`results` are now stored as :attr:`results.timeseries` + + .. versionchanged:: 2.8.0 + Added :meth:`get_supported_backends()`, introducing 'serial', 'multiprocessing' + and 'dask' backends. """ + _analysis_algorithm_is_parallelizable = True + + @classmethod + def get_supported_backends(cls): + return ("serial", "multiprocessing", "dask") + def __init__(self, function, trajectory=None, *args, **kwargs): if (trajectory is not None) and (not isinstance( trajectory, coordinates.base.ProtoReader)): @@ -531,9 +927,11 @@ def __init__(self, function, trajectory=None, *args, **kwargs): def _prepare(self): self.results.timeseries = [] + def _get_aggregator(self): + return ResultsGroup({"timeseries": ResultsGroup.flatten_sequence}) + def _single_frame(self): - self.results.timeseries.append(self.function(*self.args, - **self.kwargs)) + self.results.timeseries.append(self.function(*self.args, **self.kwargs)) def _conclude(self): self.results.frames = self.frames @@ -594,8 +992,11 @@ def RotationMatrix(mobile, ref): class WrapperClass(AnalysisFromFunction): def __init__(self, trajectory=None, *args, **kwargs): - super(WrapperClass, self).__init__(function, trajectory, - *args, **kwargs) + super(WrapperClass, self).__init__(function, trajectory, *args, **kwargs) + + @classmethod + def get_supported_backends(cls): + return ("serial", "dask") return WrapperClass @@ -633,9 +1034,11 @@ def _filter_baseanalysis_kwargs(function, kwargs): base_argspec = inspect.getargspec(AnalysisBase.__init__) n_base_defaults = len(base_argspec.defaults) - base_kwargs = {name: val - for name, val in zip(base_argspec.args[-n_base_defaults:], - base_argspec.defaults)} + base_kwargs = { + name: val + for name, val in zip(base_argspec.args[-n_base_defaults:], + base_argspec.defaults) + } try: # pylint: disable=deprecated-method @@ -648,7 +1051,8 @@ def _filter_baseanalysis_kwargs(function, kwargs): if base_kw in argspec.args: raise ValueError( "argument name '{}' clashes with AnalysisBase argument." 
- "Now allowed are: {}".format(base_kw, base_kwargs.keys())) + "Now allowed are: {}".format(base_kw, base_kwargs.keys()) + ) base_args = {} for argname, default in base_kwargs.items(): diff --git a/package/MDAnalysis/analysis/results.py b/package/MDAnalysis/analysis/results.py new file mode 100644 index 00000000000..8aa2062d2bc --- /dev/null +++ b/package/MDAnalysis/analysis/results.py @@ -0,0 +1,321 @@ +"""Analysis results and their aggregation --- :mod:`MDAnalysis.analysis.results` +================================================================================ + +Module introduces two classes, :class:`Results` and :class:`ResultsGroup`, +used for storing and aggregating data in +:meth:`MDAnalysis.analysis.base.AnalysisBase.run()`, respectively. + + +Classes +------- + +The :class:`Results` class is an extension of a built-in dictionary +type, that holds all assigned attributes in :attr:`self.data` and +allows for access either via dict-like syntax, or via class-like syntax: + +.. code-block:: python + + from MDAnalysis.analysis.results import Results + r = Results() + r.array = [1, 2, 3, 4] + assert r['array'] == r.array == [1, 2, 3, 4] + + +The :class:`ResultsGroup` can merge multiple :class:`Results` objects. +It is mainly used by :class:`MDAnalysis.analysis.base.AnalysisBase` class, +that uses :meth:`ResultsGroup.merge()` method to aggregate results from +multiple workers, initialized during a parallel run: + +.. code-block:: python + + from MDAnalysis.analysis.results import Results, ResultsGroup + import numpy as np + + r1, r2 = Results(), Results() + r1.masses = [1, 2, 3, 4, 5] + r2.masses = [0, 0, 0, 0] + r1.vectors = np.arange(10).reshape(5, 2) + r2.vectors = np.arange(8).reshape(4, 2) + + group = ResultsGroup( + lookup = { + 'masses': ResultsGroup.flatten_sequence, + 'vectors': ResultsGroup.ndarray_vstack + } + ) + + r = group.merge([r1, r2]) + assert r.masses == list((*r1.masses, *r2.masses)) + assert (r.vectors == np.vstack([r1.vectors, r2.vectors])).all() +""" +from collections import UserDict +import numpy as np +from typing import Callable, Sequence + + +class Results(UserDict): + r"""Container object for storing results. + + :class:`Results` are dictionaries that provide two ways by which values + can be accessed: by dictionary key ``results["value_key"]`` or by object + attribute, ``results.value_key``. :class:`Results` stores all results + obtained from an analysis after calling :meth:`~AnalysisBase.run()`. + + The implementation is similar to the :class:`sklearn.utils.Bunch` + class in `scikit-learn`_. + + .. _`scikit-learn`: https://scikit-learn.org/ + .. _`sklearn.utils.Bunch`: https://scikit-learn.org/stable/modules/generated/sklearn.utils.Bunch.html + + Raises + ------ + AttributeError + If an assigned attribute has the same name as a default attribute. + + ValueError + If a key is not of type ``str`` and therefore is not able to be + accessed by attribute. + + Examples + -------- + >>> from MDAnalysis.analysis.base import Results + >>> results = Results(a=1, b=2) + >>> results['b'] + 2 + >>> results.b + 2 + >>> results.a = 3 + >>> results['a'] + 3 + >>> results.c = [1, 2, 3, 4] + >>> results['c'] + [1, 2, 3, 4] + + + .. versionadded:: 2.0.0 + + .. 
versionchanged:: 2.8.0 + Moved :class:`Results` to :mod:`MDAnalysis.analysis.results` + """ + + def _validate_key(self, key): + if key in dir(self): + raise AttributeError(f"'{key}' is a protected dictionary attribute") + elif isinstance(key, str) and not key.isidentifier(): + raise ValueError(f"'{key}' is not a valid attribute") + + def __init__(self, *args, **kwargs): + kwargs = dict(*args, **kwargs) + if "data" in kwargs.keys(): + raise AttributeError(f"'data' is a protected dictionary attribute") + self.__dict__["data"] = {} + self.update(kwargs) + + def __setitem__(self, key, item): + self._validate_key(key) + super().__setitem__(key, item) + + def __setattr__(self, attr, val): + if attr == "data": + super().__setattr__(attr, val) + else: + self.__setitem__(attr, val) + + def __getattr__(self, attr): + try: + return self[attr] + except KeyError as err: + raise AttributeError(f"'Results' object has no attribute '{attr}'") from err + + def __delattr__(self, attr): + try: + del self[attr] + except KeyError as err: + raise AttributeError(f"'Results' object has no attribute '{attr}'") from err + + def __getstate__(self): + return self.data + + def __setstate__(self, state): + self.data = state + + +class ResultsGroup: + """ + Management and aggregation of results stored in :class:`Results` instances. + + A :class:`ResultsGroup` is an optional description for :class:`Result` "dictionaries" + that are used in analysis classes based on :class:`AnalysisBase`. For each *key* in a + :class:`Result` it describes how multiple pieces of the data held under the key are + to be aggregated. This approach is necessary when parts of a trajectory are analyzed + independently (e.g., in parallel) and then need to me merged (with :meth:`merge`) to + obtain a complete data set. + + Parameters + ---------- + lookup : dict[str, Callable], optional + aggregation functions lookup dict, by default None + + Examples + -------- + + .. code-block:: python + + from MDAnalysis.analysis.results import ResultsGroup, Results + group = ResultsGroup(lookup={'mass': ResultsGroup.float_mean}) + obj1 = Results(mass=1) + obj2 = Results(mass=3) + assert {'mass': 2.0} == group.merge([obj1, obj2]) + + + .. code-block:: python + + # you can also set `None` for those attributes that you want to skip + lookup = {'mass': ResultsGroup.float_mean, 'trajectory': None} + group = ResultsGroup(lookup) + objects = [Results(mass=1, skip=None), Results(mass=3, skip=object)] + assert group.merge(objects, require_all_aggregators=False) == {'mass': 2.0} + + .. versionadded:: 2.8.0 + """ + + def __init__(self, lookup: dict[str, Callable] = None): + self._lookup = lookup + + def merge(self, objects: Sequence[Results], require_all_aggregators: bool = True) -> Results: + """Merge multiple Results into a single Results instance. + + Merge multiple :class:`Results` instances into a single one, using the + `lookup` dictionary to determine the appropriate aggregator functions for + each named results attribute. If the resulting object only contains a single + element, it just returns it without using any aggregators. + + Parameters + ---------- + objects : Sequence[Results] + Multiple :class:`Results` instances with the same data attributes. + require_all_aggregators : bool, optional + if True, raise an exception when no aggregation function for a + particular argument is found. Allows to skip aggregation for the + parameters that aren't needed in the final object -- + see :class:`ResultsGroup`. 
+ + Returns + ------- + Results + merged :class:`Results` + + Raises + ------ + ValueError + if no aggregation function for a key is found and ``require_all_aggregators=True`` + """ + if len(objects) == 1: + merged_results = objects[0] + return merged_results + + merged_results = Results() + for key in objects[0].keys(): + agg_function = self._lookup.get(key, None) + if agg_function is not None: + results_of_t = [obj[key] for obj in objects] + merged_results[key] = agg_function(results_of_t) + elif require_all_aggregators: + raise ValueError(f"No aggregation function for {key=}") + return merged_results + + @staticmethod + def flatten_sequence(arrs: list[list]): + """Flatten a list of lists into a list + + Parameters + ---------- + arrs : list[list] + list of lists + + Returns + ------- + list + flattened list + """ + return [item for sublist in arrs for item in sublist] + + @staticmethod + def ndarray_sum(arrs: list[np.ndarray]): + """sums an ndarray along ``axis=0`` + + Parameters + ---------- + arrs : list[np.ndarray] + list of input arrays. Must have the same shape. + + Returns + ------- + np.ndarray + sum of input arrays + """ + return np.array(arrs).sum(axis=0) + + @staticmethod + def ndarray_mean(arrs: list[np.ndarray]): + """calculates mean of input ndarrays along ``axis=0`` + + Parameters + ---------- + arrs : list[np.ndarray] + list of input arrays. Must have the same shape. + + Returns + ------- + np.ndarray + mean of input arrays + """ + return np.array(arrs).mean(axis=0) + + @staticmethod + def float_mean(floats: list[float]): + """calculates mean of input float values + + Parameters + ---------- + floats : list[float] + list of float values + + Returns + ------- + float + mean value + """ + return np.array(floats).mean() + + @staticmethod + def ndarray_hstack(arrs: list[np.ndarray]): + """Performs horizontal stack of input arrays + + Parameters + ---------- + arrs : list[np.ndarray] + input numpy arrays + + Returns + ------- + np.ndarray + result of stacking + """ + return np.hstack(arrs) + + @staticmethod + def ndarray_vstack(arrs: list[np.ndarray]): + """Performs vertical stack of input arrays + + Parameters + ---------- + arrs : list[np.ndarray] + input numpy arrays + + Returns + ------- + np.ndarray + result of stacking + """ + return np.vstack(arrs) diff --git a/package/MDAnalysis/analysis/rms.py b/package/MDAnalysis/analysis/rms.py index 36067d76b6d..afb11ed7d2e 100644 --- a/package/MDAnalysis/analysis/rms.py +++ b/package/MDAnalysis/analysis/rms.py @@ -166,10 +166,10 @@ import logging import warnings -import MDAnalysis.lib.qcprot as qcp -from MDAnalysis.analysis.base import AnalysisBase -from MDAnalysis.exceptions import SelectionError, NoDataError -from MDAnalysis.lib.util import asiterable, iterable, get_weights +from ..lib import qcprot as qcp +from ..analysis.base import AnalysisBase, ResultsGroup +from ..exceptions import SelectionError +from ..lib.util import asiterable, iterable, get_weights logger = logging.getLogger('MDAnalysis.analysis.rmsd') @@ -358,8 +358,17 @@ class RMSD(AnalysisBase): .. versionchanged:: 2.0.0 :attr:`rmsd` results are now stored in a :class:`MDAnalysis.analysis.base.Results` instance. - + .. versionchanged:: 2.8.0 + introduced a :meth:`get_supported_backends` allowing for execution on with + ``multiprocessing`` and ``dask`` backends. 
""" + _analysis_algorithm_is_parallelizable = True + + @classmethod + def get_supported_backends(cls): + return ('serial', 'multiprocessing', 'dask',) + + def __init__(self, atomgroup, reference=None, select='all', groupselections=None, weights=None, weights_groupselections=False, tol_mass=0.1, ref_frame=0, **kwargs): @@ -670,6 +679,9 @@ def _prepare(self): self._mobile_coordinates64 = self.mobile_atoms.positions.copy().astype(np.float64) + def _get_aggregator(self): + return ResultsGroup(lookup={'rmsd': ResultsGroup.ndarray_vstack}) + def _single_frame(self): mobile_com = self.mobile_atoms.center(self.weights_select).astype(np.float64) self._mobile_coordinates64[:] = self.mobile_atoms.positions @@ -739,6 +751,11 @@ class RMSF(AnalysisBase): in the array :attr:`RMSF.results.rmsf`. """ + + @classmethod + def get_supported_backends(cls): + return ('serial',) + def __init__(self, atomgroup, **kwargs): r"""Parameters ---------- diff --git a/package/MDAnalysis/lib/util.py b/package/MDAnalysis/lib/util.py index 8600c390e1b..666a8c49279 100644 --- a/package/MDAnalysis/lib/util.py +++ b/package/MDAnalysis/lib/util.py @@ -49,6 +49,11 @@ .. autofunction:: format_from_filename_extension .. autofunction:: guess_format +Modules and packages +-------------------- + +.. autofunction:: is_installed + Streams ------- @@ -147,6 +152,7 @@ .. autofunction:: convert_aa_code .. autofunction:: parse_residue .. autofunction:: conv_float +.. autofunction:: atoi Class decorators ---------------- @@ -205,6 +211,7 @@ from functools import wraps import textwrap import weakref +import importlib import itertools import mmtf @@ -1333,6 +1340,7 @@ def fixedwidth_bins(delta, xmin, xmax): dx = 0.5 * (N * _delta - _length) # add half of the excess to each end return {'Nbins': N, 'delta': _delta, 'min': _xmin - dx, 'max': _xmax + dx} + def get_weights(atoms, weights): """Check that a `weights` argument is compatible with `atoms`. @@ -2592,3 +2600,17 @@ def atoi(s: str) -> int: return int(''.join(itertools.takewhile(str.isdigit, s.strip()))) except ValueError: return 0 + + +def is_installed(modulename: str): + """Checks if module is installed + + Parameters + ---------- + modulename : str + name of the module to be tested + + + .. versionadded:: 2.8.0 + """ + return importlib.util.find_spec(modulename) is not None diff --git a/package/doc/sphinx/source/conf.py b/package/doc/sphinx/source/conf.py index dfc6c606a13..40a096c0275 100644 --- a/package/doc/sphinx/source/conf.py +++ b/package/doc/sphinx/source/conf.py @@ -351,4 +351,5 @@ class KeyStyle(UnsrtStyle): 'waterdynamics': ('https://www.mdanalysis.org/waterdynamics/', None), 'pathsimanalysis': ('https://www.mdanalysis.org/PathSimAnalysis/', None), 'mdahole2': ('https://www.mdanalysis.org/mdahole2/', None), + 'dask': ('https://docs.dask.org/en/stable/', None), } diff --git a/package/doc/sphinx/source/documentation_pages/analysis/backends.rst b/package/doc/sphinx/source/documentation_pages/analysis/backends.rst new file mode 100644 index 00000000000..36f2a40cf11 --- /dev/null +++ b/package/doc/sphinx/source/documentation_pages/analysis/backends.rst @@ -0,0 +1,4 @@ +.. automodule:: MDAnalysis.analysis.backends + :members: + :private-members: + diff --git a/package/doc/sphinx/source/documentation_pages/analysis/base.rst b/package/doc/sphinx/source/documentation_pages/analysis/base.rst index 4eda92ceac6..2c2d884fc71 100644 --- a/package/doc/sphinx/source/documentation_pages/analysis/base.rst +++ b/package/doc/sphinx/source/documentation_pages/analysis/base.rst @@ -1,4 +1,4 @@ .. 
automodule:: MDAnalysis.analysis.base :members: - + :private-members: diff --git a/package/doc/sphinx/source/documentation_pages/analysis/parallelization.rst b/package/doc/sphinx/source/documentation_pages/analysis/parallelization.rst new file mode 100644 index 00000000000..91ae05fceca --- /dev/null +++ b/package/doc/sphinx/source/documentation_pages/analysis/parallelization.rst @@ -0,0 +1,358 @@ +.. -*- coding: utf-8 -*- + +.. _parallel-analysis: + +================= +Parallel analysis +================= + +.. versionadded:: 2.8.0 + Parallelization of analysis classes was added during Google Summer of Code + 2023 by `@marinegor `_ and MDAnalysis GSoC mentors. + +This section explains the implementation and background for +parallelization with the :class:`MDAnalysis.analysis.base.AnalysisBase`, what +users and developers need to know, when you should use parallelization (almost +always!), and when you should abstain from doing so (rarely). + + +How to use parallelization +========================== + +In order to use parallelization in a built-in analysis class ``SomeAnalysisClass``, +simply check which backends are available (see :ref:`backends` for backends +that are generally available), and then just enable them by providing +``backend='multiprocessing'`` and ``n_workers=...`` to ``SomeClass.run(...)`` +method: + +.. code-block:: python + + u = mda.Universe(...) + my_run = SomeClass(trajectory) + assert SomeClass.get_supported_backends() == ('serial', 'multiprocessing', 'dask') + + my_run.run(backend='multiprocessing', n_workers=12) + +For some classes, such as :class:`MDAnalysis.analysis.rms.RMSF` (in its current implementation), +split-apply-combine parallelization isn't possible, and running them will be +impossible with any but the ``serial`` backend. + +.. Note:: + + Parallelization is getting added to existing analysis classes. Initially, + only :class:`MDAnalysis.analysis.rms.RMSD` supports parallel analysis, but + we aim to increase support in future releases. + + +How does parallelization work +============================= + +The main idea behind its current implementation is that a trajectory analysis is +often trivially parallelizable, meaning you can analyze all frames +independently, and then merge them in a single object. This approach is also +known as "split-apply-combine", and isn't new to MDAnalysis users, since it was +first introduced in `PMDA`_ :footcite:p:`Fan2019`. +Version 2.8.0 of MDAnalysis brings this approach to the main library. + +.. _`PMDA`: https://github.com/mdanalysis/pmda + + +split-apply-combine +------------------- + +The following scheme explains the current :meth:`AnalysisBase.run +` protocol (user-implemented methods +are highlighted in orange): + +.. figure:: /images/AnalysisBase_parallel.png + + +In short, after checking input parameters and configuring the backend, +:class:`AnalysisBase <` splits all the +frames into *computation groups* (equally sized sequential groups of frames to +be processed by each worker). All groups then get **split** between workers of +a backend configured early, the main instance gets serialized and distributed +between workers, and then +:meth:`~MDAnalysis.analysis.base.AnalysisBase._compute()` method gets called +for all frames of a computation group. Within this method, a user-implemented +:meth:`~MDAnalysis.analysis.base.AnalysisBase._single_frame` method gets +**applied** to each frame in a computation group. 
After that, the main instance obtains an object that will **combine** the results from the other workers: the per-worker instances are *merged* using an instance of :class:`MDAnalysis.analysis.results.ResultsGroup`. Then, the normal user-implemented :meth:`~MDAnalysis.analysis.base.AnalysisBase._conclude` method is called.

Parallelization is fully compatible with existing code and does *not* break any existing code pre-2.8.0. The parallelization protocol mimics the single-process workflow where possible. Thus, user-implemented methods such as :meth:`~MDAnalysis.analysis.base.AnalysisBase._prepare`, :meth:`~MDAnalysis.analysis.base.AnalysisBase._single_frame` and :meth:`~MDAnalysis.analysis.base.AnalysisBase._conclude` do not need to know whether they are operating on an instance within the main Python process or on a remote instance, since the executed code is the same in both cases.

Methods in ``AnalysisBase`` for parallelization
-----------------------------------------------

For developers of new analysis tools
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

If you want to write your own *parallel* analysis class, you have to implement :meth:`~MDAnalysis.analysis.base.AnalysisBase._prepare`, :meth:`~MDAnalysis.analysis.base.AnalysisBase._single_frame` and :meth:`~MDAnalysis.analysis.base.AnalysisBase._conclude`. You also have to declare whether your analysis can run in parallel by following the steps under :ref:`adding-parallelization`.

For MDAnalysis developers
~~~~~~~~~~~~~~~~~~~~~~~~~

From a developer point of view, a few methods are important for understanding how parallelization is implemented:

#. :meth:`MDAnalysis.analysis.base.AnalysisBase._define_run_frames`
#. :meth:`MDAnalysis.analysis.base.AnalysisBase._prepare_sliced_trajectory`
#. :meth:`MDAnalysis.analysis.base.AnalysisBase._configure_backend`
#. :meth:`MDAnalysis.analysis.base.AnalysisBase._setup_computation_groups`
#. :meth:`MDAnalysis.analysis.base.AnalysisBase._compute`
#. :meth:`MDAnalysis.analysis.base.AnalysisBase._get_aggregator`

The first two methods share the functionality of :meth:`_setup_frames`. :meth:`_define_run_frames` is run once per analysis: it checks that the input parameters `start`, `stop`, `step` or `frames` are consistent with the given trajectory and prepares the ``slicer`` object that defines the iteration pattern through the trajectory. :meth:`_prepare_sliced_trajectory` assigns the :attr:`self._sliced_trajectory` attribute, computes the number of frames in it, and fills the :attr:`self.frames` and :attr:`self.times` arrays. If the computation is later split between processes, this method is called again on each of the computation groups.

The method :meth:`_configure_backend` performs basic health checks for a given analysis class -- namely, it compares a given backend (if it is a :class:`str` instance, such as ``'multiprocessing'``) with the list of builtin backends (and also the backends implemented for the given analysis subclass), and configures a :class:`MDAnalysis.analysis.backends.BackendBase` instance accordingly. If the user provides a custom backend (any subclass of :class:`MDAnalysis.analysis.backends.BackendBase`, or anything with an :meth:`apply` method), it also ensures that the number of workers was not specified twice (during backend initialization and again in the :meth:`run` arguments).
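For illustration, the sketch below shows what these checks mean in practice for a supported class (here :class:`MDAnalysis.analysis.rms.RMSD` with the PSF/DCD test files shipped with MDAnalysis; any universe would do). The exact exception messages are internal and may change between releases:

.. code-block:: python

    import MDAnalysis as mda
    from MDAnalysis.analysis.rms import RMSD
    from MDAnalysis.analysis.backends import BackendMultiprocessing
    from MDAnalysis.tests.datafiles import PSF, DCD

    u = mda.Universe(PSF, DCD)
    R = RMSD(u, u)

    # built-in backends are selected by name; n_workers is passed to run()
    R.run(backend='multiprocessing', n_workers=4)

    # a pre-instantiated backend object already carries its own n_workers and,
    # like any backend instance, requires the explicit opt-in flag
    R.run(backend=BackendMultiprocessing(n_workers=4), unsupported_backend=True)

    # specifying a *different* n_workers in run() as well raises a ValueError:
    # R.run(backend=BackendMultiprocessing(n_workers=4), n_workers=2,
    #       unsupported_backend=True)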
+
+After a backend is configured, :meth:`_setup_computation_groups` splits the
+frames prepared earlier by :meth:`_prepare_sliced_trajectory` into a number of
+groups, by default equal to the number of workers.
+
+In the :meth:`_compute` method, frames get initialized again with
+:meth:`_prepare_sliced_trajectory`, and attributes necessary for a specific
+analysis get initialized with the :meth:`_prepare` method. Then the method
+iterates over :attr:`self._sliced_trajectory`, assigning
+:attr:`self._frame_index` and :attr:`self._ts` as the frame index (within a
+computation group) and the current timestep, and also setting the respective
+values in the :attr:`self.frames` and :attr:`self.times` arrays.
+
+After :meth:`_compute` has finished, the main analysis instance calls the
+:meth:`_get_aggregator` method, which merges the :attr:`self.results`
+attributes from the other processes into a single
+:class:`MDAnalysis.analysis.results.Results` instance, so that to the
+subsequent :meth:`_conclude` method the run looks as if it had been performed
+serially, without parallelization.
+
+
+Helper classes for parallelization
+==================================
+
+``ResultsGroup``
+----------------
+
+:class:`MDAnalysis.analysis.results.ResultsGroup` extends the functionality of
+the :class:`MDAnalysis.analysis.results.Results` class. Since the ``Results``
+class is basically a dictionary that also keeps track of assigned attributes,
+it is possible to iterate over all these attributes later. ``ResultsGroup``
+does exactly that: given a list of ``Results`` objects with the same
+attributes, it applies a specific aggregation function to every attribute, and
+stores the result as the same attribute of the returned object:
+
+.. code-block:: python
+
+    from MDAnalysis.analysis.results import ResultsGroup, Results
+    group = ResultsGroup(lookup={'mass': ResultsGroup.float_mean})
+    obj1 = Results(mass=1)
+    obj2 = Results(mass=3)
+    assert group.merge([obj1, obj2]) == Results(mass=2.0)
+
+
+``BackendBase``
+---------------
+
+:class:`MDAnalysis.analysis.backends.BackendBase` holds all backend
+attributes, and also implements an
+:meth:`MDAnalysis.analysis.backends.BackendBase.apply` method that applies a
+given function to a list of inputs, potentially in parallel. Although in
+``AnalysisBase`` it is used to apply a ``_compute`` function, in principle it
+can be applied to any function and arguments, provided they are serializable.
+
+
+When to use parallelization? (Known limitations)
+================================================
+
+For now, the syntax for running a parallel analysis is explicit: the
+``serial`` backend is used unless you request another one, so parallelization
+is never enabled silently. Although we expect parallelization to be useful in
+most cases, there are some known caveats from the initial benchmarks.
+
+Fast ``_single_frame`` compared to reading from disk
+----------------------------------------------------
+
+Parallelization does not help when frames are processed faster than they are
+read from disk -- in that case, reading is the bottleneck. Hence, you will
+benefit from parallelization only if you have relatively expensive per-frame
+computation, or a fast drive, as illustrated below:
+
+.. figure:: /images/parallelization_time.png
+
+In other words, if you have a *fast* analysis (say,
+:class:`MDAnalysis.analysis.rms.RMSD`) **and** a slow HDD, you are unlikely to
+get any benefit from parallelization. Otherwise, you should be fine.
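+
+If in doubt, you can benchmark both modes on a short slice of your trajectory
+before committing to a long run. The snippet below is a minimal sketch that
+assumes an already set-up analysis instance ``R`` (e.g. an
+:class:`~MDAnalysis.analysis.rms.RMSD` object) and simply compares wall-clock
+times:
+
+.. code-block:: python
+
+    import time
+
+    for kwargs in ({'backend': 'serial'},
+                   {'backend': 'multiprocessing', 'n_workers': 4}):
+        t0 = time.time()
+        R.run(stop=100, **kwargs)  # analyze only the first 100 frames
+        print(kwargs, f"took {time.time() - t0:.2f} s")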
+
+Serialization issues
+--------------------
+
+For built-in analysis classes, the default serialization with both
+:mod:`multiprocessing` and :mod:`dask` is known to work. If you're using a
+custom analysis class that e.g. stores a non-serializable object in one of its
+attributes, you might get a serialization error (:exc:`PicklingError` if
+you're using the ``multiprocessing`` backend). If you want to get around that,
+we suggest trying ``backend='dask'`` (it uses the ``dask`` serialization
+engine instead of :mod:`pickle`).
+
+Out of memory issues
+--------------------
+
+If each worker has a large memory footprint, you can run into out-of-memory
+errors (e.g. your machine freezes when executing a run). In this case we
+suggest decreasing the number of workers from all available CPUs (which you
+can get with :func:`multiprocessing.cpu_count`) to a smaller number.
+
+Progress bar is missing
+-----------------------
+
+It is not yet possible to get a progress bar running with any parallel
+backend. If you want an ETA for your analysis, we suggest running it in
+``serial`` mode for the first 10-100 frames with ``verbose=True``, and then
+running it with multiple workers. Processing time scales almost linearly, so
+you can get your ETA by dividing the ``serial`` ETA by the number of workers.
+
+
+.. _adding-parallelization:
+
+Adding parallelization to your own analysis class
+=================================================
+
+If you want to add parallelization to your own analysis class, first make sure
+your algorithm allows you to do that, i.e. each frame can be processed
+independently. Then it's rather simple -- let's look at the actual code that
+added parallelization to the :class:`MDAnalysis.analysis.rms.RMSD` class:
+
+.. code-block:: python
+
+    from MDAnalysis.analysis.base import AnalysisBase
+    from MDAnalysis.analysis.results import ResultsGroup
+
+    class RMSD(AnalysisBase):
+        @classmethod
+        def get_supported_backends(cls):
+            return ('serial', 'multiprocessing', 'dask',)
+
+        _analysis_algorithm_is_parallelizable = True
+
+        def _get_aggregator(self):
+            return ResultsGroup(lookup={'rmsd': ResultsGroup.ndarray_vstack})
+
+
+That's it! The first two pieces are boilerplate --
+:meth:`get_supported_backends` returns a tuple with the built-in backends that
+will work for your class (if there are no serialization issues, it should be
+all three), and ``_analysis_algorithm_is_parallelizable`` is set to ``True``
+(it is ``False`` in ``AnalysisBase``, hence we have to re-define it). The
+:meth:`_get_aggregator` method will be used as described earlier. Note that
+:mod:`MDAnalysis.analysis.results` also provides a few convenient functions
+(defined as class methods of
+:class:`~MDAnalysis.analysis.results.ResultsGroup`) for results aggregation:
+
+#. :meth:`~MDAnalysis.analysis.results.ResultsGroup.flatten_sequence`
+#. :meth:`~MDAnalysis.analysis.results.ResultsGroup.ndarray_sum`
+#. :meth:`~MDAnalysis.analysis.results.ResultsGroup.ndarray_mean`
+#. :meth:`~MDAnalysis.analysis.results.ResultsGroup.float_mean`
+#. :meth:`~MDAnalysis.analysis.results.ResultsGroup.ndarray_hstack`
+#. :meth:`~MDAnalysis.analysis.results.ResultsGroup.ndarray_vstack`
+
+
+So you'll likely find appropriate functions for basic aggregation there.
+
+Writing custom backends
+=======================
+
+In order to write your custom backend (e.g. using :mod:`dask.distributed`),
+inherit from :class:`MDAnalysis.analysis.backends.BackendBase` and
+(re-)implement the :meth:`__init__` and :meth:`apply` methods.
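+
+For example, a backend built on :mod:`dask.distributed` could look roughly
+like the sketch below. ``DistributedBackend`` is not part of MDAnalysis and is
+only meant as an illustration; it assumes that :mod:`dask.distributed` is
+installed and, if no client is passed in, starts a local cluster:
+
+.. code-block:: python
+
+    from dask.distributed import Client
+
+    from MDAnalysis.analysis.backends import BackendBase
+
+    class DistributedBackend(BackendBase):
+        def __init__(self, n_workers: int, client: Client = None):
+            self.n_workers = n_workers
+            # reuse an existing client, or spin up a local cluster
+            self.client = client if client is not None else Client(n_workers=n_workers)
+            self._validate()
+
+        def apply(self, func, computations):
+            # submit one task per computation group and block until all finish
+            futures = self.client.map(func, computations)
+            return self.client.gather(futures)
+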
+Optionally, you can implement methods for validation of correct backend
+initialization -- :meth:`_get_checks` and :meth:`_get_warnings`, which raise
+an exception or give a warning, respectively, when a new class instance is
+created:
+
+#. :meth:`MDAnalysis.analysis.backends.BackendBase._get_checks`
+#. :meth:`MDAnalysis.analysis.backends.BackendBase._get_warnings`
+
+.. code-block:: python
+
+    from MDAnalysis.analysis.backends import BackendBase
+
+    class ThreadsBackend(BackendBase):
+        def __init__(self, n_workers: int, starting_message: str = "Useless backend"):
+            self.n_workers = n_workers
+            self.starting_message = starting_message
+            self._validate()
+
+        def _get_warnings(self):
+            return {True: 'warning: this backend is useless'}
+
+        def _get_checks(self):
+            return {isinstance(self.n_workers, int):
+                    'error: self.n_workers is not an integer'}
+
+        def apply(self, func, computations):
+            from multiprocessing.dummy import Pool
+
+            with Pool(processes=self.n_workers) as pool:
+                print(self.starting_message)
+                results = pool.map(func, computations)
+            return results
+
+
+In order to use a custom backend with another analysis class that does not
+explicitly support it, you must *explicitly state* that you're about to use an
+unsupported backend by passing the keyword argument
+``unsupported_backend=True``:
+
+.. code-block:: python
+
+    from MDAnalysis.analysis.rms import RMSD
+
+    R = RMSD(...)  # setup the run
+    n_workers = 2
+    backend = ThreadsBackend(n_workers=n_workers)
+    R.run(backend=backend, unsupported_backend=True)
+
+In this way, you will override the check for supported backends.
+
+.. Warning::
+
+    When you use ``unsupported_backend=True`` you should make sure that you get
+    the same results as when using a supported backend for which the analysis
+    class was tested.
+
+    Before reporting a problem with an analysis class, make sure you tested it
+    with a supported backend. When reporting, *always mention if you used*
+    ``unsupported_backend=True``.
+
+
+.. rubric:: References
+.. footbibliography::
+
diff --git a/package/doc/sphinx/source/documentation_pages/analysis/results.rst b/package/doc/sphinx/source/documentation_pages/analysis/results.rst
new file mode 100644
index 00000000000..22fc9fd81f3
--- /dev/null
+++ b/package/doc/sphinx/source/documentation_pages/analysis/results.rst
@@ -0,0 +1,4 @@
+.. automodule:: MDAnalysis.analysis.results
+   :members:
+
+
diff --git a/package/doc/sphinx/source/documentation_pages/analysis_modules.rst b/package/doc/sphinx/source/documentation_pages/analysis_modules.rst
index ee7f6568758..ec484b0382d 100644
--- a/package/doc/sphinx/source/documentation_pages/analysis_modules.rst
+++ b/package/doc/sphinx/source/documentation_pages/analysis_modules.rst
@@ -1,23 +1,42 @@
 .. Contains the formatted docstrings from the analysis modules located in
    'mdanalysis/MDAnalysis/analysis', although in some cases the documentation
    imports functions and docstrings from other files which are also curated to
    reStructuredText markup.
 
+.. module:: MDAnalysis.analysis
+
 ****************
 Analysis modules
 ****************
 
 The :mod:`MDAnalysis.analysis` module contains code to carry out specific
-analysis functionality for MD trajectories.
-It is based on the core functionality (i.e. trajectory
-I/O, selections etc). The analysis modules can be used as examples for how to
-use MDAnalysis but also as working code for research projects; typically all
-contributed code has been used by the authors in their own work.
-An analysis using the available modules -usually follows the same structure +analysis functionality for MD trajectories. It is based on the core +functionality (i.e. trajectory I/O, selections etc). The analysis modules can +be used as examples for how to use MDAnalysis but also as working code for +research projects; typically all contributed code has been used by the authors +in their own work. + +Getting started with analysis +============================= + +.. SeeAlso:: + + The `User Guide: Analysis`_ contains extensive documentation of the analysis + capabilities with user-friendly examples. + +.. _`User Guide: Analysis`: + https://userguide.mdanalysis.org/stable/examples/analysis/README.html + +Using the analysis classes +-------------------------- + +Most analysis tools in MDAnalysis are written as a single class. An analysis +usually follows the same pattern: #. Import the desired module, since analysis modules are not imported by default. #. Initialize the analysis class instance from the previously imported module. -#. Run the analysis, optionally for specific trajectory slices. -#. Access the analysis from the :attr:`results` attribute +#. Run the analysis with the :meth:`~MDAnalysis.analysis.base.AnalysisBase.run` + method, optionally for specific trajectory slices. +#. Access the analysis from the + :attr:`~MDAnalysis.analysis.base.AnalysisBase.results` attribute .. code-block:: python @@ -31,28 +50,81 @@ usually follows the same structure Please see the individual module documentation for any specific caveats and also read and cite the reference papers associated with these algorithms. -.. rubric:: Additional dependencies + +Using parallelization for built-in analysis runs +------------------------------------------------ + +.. versionadded:: 2.8.0 + +:class:`~MDAnalysis.analysis.base.AnalysisBase` subclasses can run on a backend +that supports parallelization (see :mod:`MDAnalysis.analysis.backends`). All +analysis runs use ``backend='serial'`` by default, i.e., they do not use +parallelization by default, which has been standard before release 2.8.0 of +MDAnalysis. + +Without any dependencies, only one backend is supported -- built-in +:mod:`multiprocessing`, that processes parts of a trajectory running separate +*processes*, i.e. utilizing multi-core processors properly. + +.. Note:: + + For now, parallelization has only been added to + :class:`MDAnalysis.analysis.rms.RMSD`, but by release 3.0 version it will be + introduced to all subclasses that can support it. + +In order to use that feature, simply add ``backend='multiprocessing'`` to your +run, and supply it with proper ``n_workers`` (use ``multiprocessing.cpu_count()`` +for maximum available on your machine): + +.. code-block:: python + + import multiprocessing + import MDAnalysis as mda + from MDAnalysisTests.datafiles import PSF, DCD + from MDAnalysis.analysis.rms import RMSD + from MDAnalysis.analysis.align import AverageStructure + + # initialize the universe + u = mda.Universe(PSF, DCD) + + # calculate average structure for reference + avg = AverageStructure(mobile=u).run() + ref = avg.results.universe + + # initialize RMSD run + rmsd = RMSD(u, ref, select='backbone') + rmsd.run(backend='multiprocessing', n_workers=multiprocessing.cpu_count()) + +For now, you have to be explicit and specify both ``backend`` and ``n_workers``, +since the feature is new and there are no good defaults for it. 
For example, +if you specify a too big `n_workers`, and your trajectory frames are big, +you might get and out-of-memory error when executing your run. + +You can also implement your own backends -- see :mod:`MDAnalysis.analysis.backends`. + + +Additional dependencies +----------------------- Some of the modules in :mod:`MDAnalysis.analysis` require additional Python packages to enable full functionality. For example, :mod:`MDAnalysis.analysis.encore` provides more options if `scikit-learn`_ is -installed. If you installed MDAnalysis with -:program:`pip` (see :ref:`installation-instructions`) -these packages are *not automatically installed*. -Although, one can add the ``[analysis]`` tag to the -:program:`pip` command to force their installation. If you installed -MDAnalysis with :program:`conda` then a -*full set of dependencies* is automatically installed. +installed. If you installed MDAnalysis with :program:`pip` (see +:ref:`installation-instructions`) these packages are *not automatically +installed* although one can add the ``[analysis]`` tag to the :program:`pip` +command to force their installation. If you installed MDAnalysis with +:program:`conda` then a *full set of dependencies* is automatically installed. Other modules require external programs. For instance, the -:mod:`MDAnalysis.analysis.hole2.hole` module requires an installation of the -HOLE_ suite of programs. You will need to install these external dependencies -by following their installation instructions before you can use the -corresponding MDAnalysis module. +:mod:`MDAnalysis.analysis.hole2` module requires an installation of the HOLE_ +suite of programs. You will need to install these external dependencies by +following their installation instructions before you can use the corresponding +MDAnalysis module. .. _scikit-learn: http://scikit-learn.org/ .. _HOLE: http://www.holeprogram.org/ + Building blocks for Analysis ============================ @@ -64,6 +136,9 @@ To build your own analysis class start by reading the documentation. :maxdepth: 1 analysis/base + analysis/backends + analysis/results + analysis/parallelization Distances and contacts ====================== diff --git a/package/doc/sphinx/source/images/AnalysisBase_parallel.png b/package/doc/sphinx/source/images/AnalysisBase_parallel.png new file mode 100644 index 00000000000..960e7cc257b Binary files /dev/null and b/package/doc/sphinx/source/images/AnalysisBase_parallel.png differ diff --git a/package/doc/sphinx/source/images/parallelization_time.png b/package/doc/sphinx/source/images/parallelization_time.png new file mode 100644 index 00000000000..09e584555ac Binary files /dev/null and b/package/doc/sphinx/source/images/parallelization_time.png differ diff --git a/package/doc/sphinx/source/references.bib b/package/doc/sphinx/source/references.bib index 5e2167c9867..8fe33a1b64d 100644 --- a/package/doc/sphinx/source/references.bib +++ b/package/doc/sphinx/source/references.bib @@ -795,3 +795,15 @@ @article{Linke2018 pages = {5630--5639}, doi = {10.1021/acs.jpcb.7b11988} } + +@inproceedings{Fan2019, + title = {{PMDA} - {P}arallel {M}olecular {D}ynamics {A}nalysis}, + author = {Shujie Fan and Max Linke and Ioannis Paraskevakos and Richard J. 
Gowers and Michael Gecht and Oliver Beckstein}, + year = {2019}, + booktitle = {{P}roceedings of the 18th {P}ython in {S}cience {C}onference}, + editor = {{C}hris {C}alloway and {D}avid {L}ippa and {D}illon {N}iederhut and {D}avid {S}hupe}, + organization = {SciPy}, + address = {Austin, TX}, + pages = {134 - 142}, + doi = {10.25080/Majora-7ddc1dd1-013} +} diff --git a/package/pyproject.toml b/package/pyproject.toml index 1880d88fd0d..05b0eb3ba16 100644 --- a/package/pyproject.toml +++ b/package/pyproject.toml @@ -94,6 +94,9 @@ doc = [ "pybtex", "pybtex-docutils", ] +parallel = [ + "dask", +] [project.urls] Homepage = "https://www.mdanalysis.org" diff --git a/testsuite/MDAnalysisTests/analysis/conftest.py b/testsuite/MDAnalysisTests/analysis/conftest.py new file mode 100644 index 00000000000..55bae7e6bd8 --- /dev/null +++ b/testsuite/MDAnalysisTests/analysis/conftest.py @@ -0,0 +1,89 @@ +import pytest + +from MDAnalysis.analysis.base import AnalysisBase, AnalysisFromFunction +from MDAnalysisTests.analysis.test_base import ( + FrameAnalysis, + IncompleteAnalysis, + OldAPIAnalysis, +) +from MDAnalysis.analysis.rms import RMSD, RMSF +from MDAnalysis.lib.util import is_installed + + +def params_for_cls(cls, exclude: list[str] = None): + """ + This part contains fixtures for simultaneous testing + of all available (=installed & supported) backends + for analysis subclasses. + + If for some reason you want to limit backends, + simply pass "exclude: list[str]" to the function + that parametrizes fixture. + + Parameters + ---------- + exclude : list[str], optional + list of backends to exclude from parametrization, by default None + + Returns + ------- + dict + dictionary with all tested keyword combinations for the run + """ + exclude = [] if exclude is None else exclude + possible_backends = cls.get_supported_backends() + installed_backends = [ + b for b in possible_backends if is_installed(b) and b not in exclude + ] + + params = [ + pytest.param({ + "backend": backend, + "n_workers": nproc + }, ) for backend in installed_backends for nproc in (2, ) + if backend != "serial" + ] + params.extend([{"backend": "serial"}]) + return params + + +@pytest.fixture(scope='module', params=params_for_cls(FrameAnalysis)) +def client_FrameAnalysis(request): + return request.param + + +@pytest.fixture(scope='module', params=params_for_cls(AnalysisBase)) +def client_AnalysisBase(request): + return request.param + + +@pytest.fixture(scope='module', params=params_for_cls(AnalysisFromFunction)) +def client_AnalysisFromFunction(request): + return request.param + + +@pytest.fixture(scope='module', + params=params_for_cls(AnalysisFromFunction, + exclude=['multiprocessing'])) +def client_AnalysisFromFunctionAnalysisClass(request): + return request.param + + +@pytest.fixture(scope='module', params=params_for_cls(IncompleteAnalysis)) +def client_IncompleteAnalysis(request): + return request.param + + +@pytest.fixture(scope='module', params=params_for_cls(OldAPIAnalysis)) +def client_OldAPIAnalysis(request): + return request.param + + +@pytest.fixture(scope='module', params=params_for_cls(RMSD)) +def client_RMSD(request): + return request.param + + +@pytest.fixture(scope='module', params=params_for_cls(RMSF)) +def client_RMSF(request): + return request.param diff --git a/testsuite/MDAnalysisTests/analysis/test_backends.py b/testsuite/MDAnalysisTests/analysis/test_backends.py new file mode 100644 index 00000000000..a4c105e082a --- /dev/null +++ b/testsuite/MDAnalysisTests/analysis/test_backends.py @@ -0,0 +1,72 @@ +import 
pytest +from MDAnalysis.analysis import backends +from MDAnalysis.lib.util import is_installed + + +def square(x: int): + return x**2 + + +def noop(x): + return x + + +def upper(s): + return s.upper() + + +class Test_Backends: + + @pytest.mark.parametrize( + "backend_cls,n_workers", + [ + (backends.BackendBase, -1), + (backends.BackendSerial, None), + (backends.BackendMultiprocessing, "string"), + (backends.BackendDask, ()), + ], + ) + def test_fails_incorrect_n_workers(self, backend_cls, n_workers): + with pytest.raises(ValueError): + _ = backend_cls(n_workers=n_workers) + + @pytest.mark.parametrize( + "func,iterable,answer", + [ + (square, (1, 2, 3), [1, 4, 9]), + (square, (), []), + (noop, list(range(10)), list(range(10))), + (upper, "asdf", list("ASDF")), + ], + ) + def test_all_backends_give_correct_results(self, func, iterable, answer): + backend_instances = [ + backends.BackendMultiprocessing(n_workers=2), + backends.BackendSerial(n_workers=1), + ] + if is_installed("dask"): + backend_instances.append(backends.BackendDask(n_workers=2)) + + backends_dict = {b: b.apply(func, iterable) for b in backend_instances} + for answ in backends_dict.values(): + assert answ == answer + + @pytest.mark.parametrize("backend_cls,params,warning_message", [ + (backends.BackendSerial, { + 'n_workers': 5 + }, "n_workers is ignored when executing with backend='serial'"), + ]) + def test_get_warnings(self, backend_cls, params, warning_message): + with pytest.warns(UserWarning, match=warning_message): + backend_cls(**params) + + @pytest.mark.parametrize("backend_cls,params,error_message", [ + pytest.param(backends.BackendDask, {'n_workers': 2}, + ("module 'dask' is missing. Please install 'dask': " + "https://docs.dask.org/en/stable/install.html"), + marks=pytest.mark.skipif(is_installed('dask'), + reason='dask is installed')) + ]) + def test_get_errors(self, backend_cls, params, error_message): + with pytest.raises(ValueError, match=error_message): + backend_cls(**params) diff --git a/testsuite/MDAnalysisTests/analysis/test_base.py b/testsuite/MDAnalysisTests/analysis/test_base.py index 84cd4c5b9c2..6e6b7921960 100644 --- a/testsuite/MDAnalysisTests/analysis/test_base.py +++ b/testsuite/MDAnalysisTests/analysis/test_base.py @@ -30,134 +30,37 @@ from numpy.testing import assert_equal, assert_allclose import MDAnalysis as mda -from MDAnalysis.analysis import base - -from MDAnalysisTests.datafiles import PSF, DCD, TPR, XTC +import numpy as np +import pytest +from MDAnalysis.analysis import base, backends +from MDAnalysisTests.datafiles import DCD, PSF, TPR, XTC from MDAnalysisTests.util import no_deprecated_call - - -class Test_Results: - - @pytest.fixture - def results(self): - return base.Results(a=1, b=2) - - def test_get(self, results): - assert results.a == results["a"] == 1 - - def test_no_attr(self, results): - msg = "'Results' object has no attribute 'c'" - with pytest.raises(AttributeError, match=msg): - results.c - - def test_set_attr(self, results): - value = [1, 2, 3, 4] - results.c = value - assert results.c is results["c"] is value - - def test_set_key(self, results): - value = [1, 2, 3, 4] - results["c"] = value - assert results.c is results["c"] is value - - @pytest.mark.parametrize('key', dir(UserDict) + ["data"]) - def test_existing_dict_attr(self, results, key): - msg = f"'{key}' is a protected dictionary attribute" - with pytest.raises(AttributeError, match=msg): - results[key] = None - - @pytest.mark.parametrize('key', dir(UserDict) + ["data"]) - def test_wrong_init_type(self, key): - 
msg = f"'{key}' is a protected dictionary attribute" - with pytest.raises(AttributeError, match=msg): - base.Results(**{key: None}) - - @pytest.mark.parametrize('key', ("0123", "0j", "1.1", "{}", "a b")) - def test_weird_key(self, results, key): - msg = f"'{key}' is not a valid attribute" - with pytest.raises(ValueError, match=msg): - results[key] = None - - def test_setattr_modify_item(self, results): - mylist = [1, 2] - mylist2 = [3, 4] - results.myattr = mylist - assert results.myattr is mylist - results["myattr"] = mylist2 - assert results.myattr is mylist2 - mylist2.pop(0) - assert len(results.myattr) == 1 - assert results.myattr is mylist2 - - def test_setitem_modify_item(self, results): - mylist = [1, 2] - mylist2 = [3, 4] - results["myattr"] = mylist - assert results.myattr is mylist - results.myattr = mylist2 - assert results.myattr is mylist2 - mylist2.pop(0) - assert len(results["myattr"]) == 1 - assert results["myattr"] is mylist2 - - def test_delattr(self, results): - assert hasattr(results, "a") - delattr(results, "a") - assert not hasattr(results, "a") - - def test_missing_delattr(self, results): - assert not hasattr(results, "d") - msg = "'Results' object has no attribute 'd'" - with pytest.raises(AttributeError, match=msg): - delattr(results, "d") - - def test_pop(self, results): - assert hasattr(results, "a") - results.pop("a") - assert not hasattr(results, "a") - - def test_update(self, results): - assert not hasattr(results, "spudda") - results.update({"spudda": "fett"}) - assert results.spudda == "fett" - - def test_update_data_fail(self, results): - msg = f"'data' is a protected dictionary attribute" - with pytest.raises(AttributeError, match=msg): - results.update({"data": 0}) - - def test_pickle(self, results): - results_p = pickle.dumps(results) - results_new = pickle.loads(results_p) - - @pytest.mark.parametrize("args, kwargs, length", [ - (({"darth": "tater"},), {}, 1), - ([], {"darth": "tater"}, 1), - (({"darth": "tater"},), {"yam": "solo"}, 2), - (({"darth": "tater"},), {"darth": "vader"}, 1), - ]) - def test_initialize_arguments(self, args, kwargs, length): - results = base.Results(*args, **kwargs) - ref = dict(*args, **kwargs) - assert ref == results - assert len(results) == length - - def test_different_instances(self, results): - new_results = base.Results(darth="tater") - assert new_results.data is not results.data +from numpy.testing import assert_almost_equal, assert_equal class FrameAnalysis(base.AnalysisBase): """Just grabs frame numbers of frames it goes over""" + @classmethod + def get_supported_backends(cls): return ('serial', 'dask', 'multiprocessing') + + _analysis_algorithm_is_parallelizable = True + def __init__(self, reader, **kwargs): super(FrameAnalysis, self).__init__(reader, **kwargs) self.traj = reader - self.found_frames = [] + + def _prepare(self): + self.results.found_frames = [] def _single_frame(self): - self.found_frames.append(self._ts.frame) + self.results.found_frames.append(self._ts.frame) + + def _conclude(self): + self.found_frames = list(self.results.found_frames) + def _get_aggregator(self): + return base.ResultsGroup({'found_frames': base.ResultsGroup.ndarray_hstack}) class IncompleteAnalysis(base.AnalysisBase): def __init__(self, reader, **kwargs): @@ -173,6 +76,9 @@ def __init__(self, reader, **kwargs): def _single_frame(self): pass + def _prepare(self): + self.results = base.Results() + @pytest.fixture(scope='module') def u(): @@ -187,6 +93,110 @@ def u_xtc(): FRAMES_ERR = 'AnalysisBase.frames is incorrect' TIMES_ERR = 
'AnalysisBase.times is incorrect' +class Parallelizable(base.AnalysisBase): + _analysis_algorithm_is_parallelizable = True + @classmethod + def get_supported_backends(cls): return ('multiprocessing', 'dask') + def _single_frame(self): pass + +class SerialOnly(base.AnalysisBase): + def _single_frame(self): pass + +class ParallelizableWithDaskOnly(base.AnalysisBase): + _analysis_algorithm_is_parallelizable = True + @classmethod + def get_supported_backends(cls): return ('dask',) + def _single_frame(self): pass + +class CustomSerialBackend(backends.BackendBase): + def apply(self, func, computations): + return [func(task) for task in computations] + +class ManyWorkersBackend(backends.BackendBase): + def apply(self, func, computations): + return [func(task) for task in computations] + +def test_incompatible_n_workers(u): + backend = ManyWorkersBackend(n_workers=2) + with pytest.raises(ValueError): + FrameAnalysis(u).run(backend=backend, n_workers=3) + +@pytest.mark.parametrize('run_class,backend,n_workers', [ + (Parallelizable, 'not-existing-backend', 2), + (Parallelizable, 'not-existing-backend', None), + (SerialOnly, 'not-existing-backend', 2), + (SerialOnly, 'not-existing-backend', None), + (SerialOnly, 'multiprocessing', 2), + (SerialOnly, 'dask', None), + (ParallelizableWithDaskOnly, 'multiprocessing', None), + (ParallelizableWithDaskOnly, 'multiprocessing', 2), +]) +def test_backend_configuration_fails(u, run_class, backend, n_workers): + u = mda.Universe(TPR, XTC) # dt = 100 + with pytest.raises(ValueError): + _ = run_class(u.trajectory).run(backend=backend, n_workers=n_workers, stop=0) + +@pytest.mark.parametrize('run_class,backend,n_workers', [ + (Parallelizable, CustomSerialBackend, 2), + (ParallelizableWithDaskOnly, CustomSerialBackend, 2), +]) +def test_backend_configuration_works_when_unsupported_backend(u, run_class, backend, n_workers): + u = mda.Universe(TPR, XTC) # dt = 100 + backend_instance = backend(n_workers=n_workers) + _ = run_class(u.trajectory).run(backend=backend_instance, n_workers=n_workers, stop=0, unsupported_backend=True) + +@pytest.mark.parametrize('run_class,backend,n_workers', [ + (Parallelizable, CustomSerialBackend, 1), + (ParallelizableWithDaskOnly, CustomSerialBackend, 1), +]) +def test_custom_backend_works(u, run_class, backend, n_workers): + backend_instance = backend(n_workers=n_workers) + u = mda.Universe(TPR, XTC) # dt = 100 + _ = run_class(u.trajectory).run(backend=backend_instance, n_workers=n_workers, unsupported_backend=True) + +@pytest.mark.parametrize('run_class,backend_instance,n_workers', [ + (Parallelizable, map, 1), + (SerialOnly, list, 1), + (ParallelizableWithDaskOnly, object, 1), +]) +def test_fails_incorrect_custom_backend(u, run_class, backend_instance, n_workers): + u = mda.Universe(TPR, XTC) # dt = 100 + with pytest.raises(ValueError): + _ = run_class(u.trajectory).run(backend=backend_instance, n_workers=n_workers, unsupported_backend=True) + + with pytest.raises(ValueError): + _ = run_class(u.trajectory).run(backend=backend_instance, n_workers=n_workers) + +@pytest.mark.parametrize('run_class,backend,n_workers', [ + (SerialOnly, CustomSerialBackend, 1), + (SerialOnly, 'multiprocessing', 1), + (SerialOnly, 'dask', 1), +]) +def test_fails_for_unparallelizable(u, run_class, backend, n_workers): + u = mda.Universe(TPR, XTC) # dt = 100 + with pytest.raises(ValueError): + if not isinstance(backend, str): + backend_instance = backend(n_workers=n_workers) + _ = run_class(u.trajectory).run(backend=backend_instance, n_workers=n_workers, 
unsupported_backend=True) + else: + _ = run_class(u.trajectory).run(backend=backend, n_workers=n_workers, unsupported_backend=True) + +@pytest.mark.parametrize('run_kwargs,frames', [ + ({}, np.arange(98)), + ({'start': 20}, np.arange(20, 98)), + ({'stop': 30}, np.arange(30)), + ({'step': 10}, np.arange(0, 98, 10)) +]) +def test_start_stop_step_parallel(u, run_kwargs, frames, client_FrameAnalysis): + # client_FrameAnalysis is defined [here](testsuite/MDAnalysisTests/analysis/conftest.py), + # and determines a set of parameters ('backend', 'n_workers'), taking only backends + # that are implemented for a given subclass, to run the test against. + an = FrameAnalysis(u.trajectory).run(**run_kwargs, **client_FrameAnalysis) + assert an.n_frames == len(frames) + assert_equal(an.found_frames, frames) + assert_equal(an.frames, frames, err_msg=FRAMES_ERR) + assert_almost_equal(an.times, frames+1, decimal=4, err_msg=TIMES_ERR) + @pytest.mark.parametrize('run_kwargs,frames', [ ({}, np.arange(98)), @@ -217,6 +227,22 @@ def test_frame_slice(u_xtc, run_kwargs, frames): assert_equal(an.frames, frames, err_msg=FRAMES_ERR) +@pytest.mark.parametrize('run_kwargs, frames', [ + ({'frames': [4, 5, 6, 7, 8, 9]}, np.arange(4, 10)), + ({'frames': [0, 2, 4, 6, 8]}, np.arange(0, 10, 2)), + ({'frames': [4, 6, 8]}, np.arange(4, 10, 2)), + ({'frames': [0, 3, 4, 3, 5]}, [0, 3, 4, 3, 5]), + ({'frames': [True, True, False, True, False, True, True, False, True, + False]}, (0, 1, 3, 5, 6, 8)), +]) +def test_frame_slice_parallel(run_kwargs, frames, client_FrameAnalysis): + u = mda.Universe(TPR, XTC) # dt = 100 + an = FrameAnalysis(u.trajectory).run(**run_kwargs, **client_FrameAnalysis) + assert an.n_frames == len(frames) + assert_equal(an.found_frames, frames) + assert_equal(an.frames, frames, err_msg=FRAMES_ERR) + + @pytest.mark.parametrize('run_kwargs', [ ({'start': 4, 'frames': [4, 5, 6, 7, 8, 9]}), ({'stop': 6, 'frames': [0, 1, 2, 3, 4, 5]}), @@ -226,28 +252,44 @@ def test_frame_slice(u_xtc, run_kwargs, frames): ({'start': 4, 'step': 2, 'frames': [4, 6, 8]}), ({'start': 0, 'stop': 0, 'step': 0, 'frames': [4, 6, 8]}), ]) -def test_frame_fail(u, run_kwargs): +def test_frame_fail(u, run_kwargs, client_FrameAnalysis): an = FrameAnalysis(u.trajectory) msg = 'start/stop/step cannot be combined with frames' with pytest.raises(ValueError, match=msg): - an.run(**run_kwargs) + an.run(**client_FrameAnalysis, **run_kwargs) + +def test_parallelizable_transformations(): + # pick any transformation that would allow + # for parallelizable attribute + from MDAnalysis.transformations import NoJump + u = mda.Universe(XTC) + u.trajectory.add_transformations(NoJump()) + # test that serial works + FrameAnalysis(u.trajectory).run() + + # test that parallel fails + with pytest.raises(ValueError): + FrameAnalysis(u.trajectory).run(backend='multiprocessing') -def test_frame_bool_fail(u_xtc): - an = FrameAnalysis(u_xtc.trajectory) +def test_frame_bool_fail(client_FrameAnalysis): + u = mda.Universe(TPR, XTC) # dt = 100 + an = FrameAnalysis(u.trajectory) frames = [True, True, False] msg = 'boolean index did not match indexed array along (axis|dimension) 0' with pytest.raises(IndexError, match=msg): - an.run(frames=frames) + an.run(**client_FrameAnalysis, frames=frames) -def test_rewind(u_xtc): - FrameAnalysis(u_xtc.trajectory).run(frames=[0, 2, 3, 5, 9]) - assert_equal(u_xtc.trajectory.ts.frame, 0) +def test_rewind(client_FrameAnalysis): + u = mda.Universe(TPR, XTC) # dt = 100 + an = FrameAnalysis(u.trajectory).run(**client_FrameAnalysis, frames=[0, 
2, 3, 5, 9]) + assert_equal(u.trajectory.ts.frame, 0) -def test_frames_times(u_xtc): - an = FrameAnalysis(u_xtc.trajectory).run(start=1, stop=8, step=2) +def test_frames_times(client_FrameAnalysis): + u = mda.Universe(TPR, XTC) # dt = 100 + an = FrameAnalysis(u.trajectory).run(start=1, stop=8, step=2, **client_FrameAnalysis) frames = np.array([1, 3, 5, 7]) assert an.n_frames == len(frames) assert_equal(an.found_frames, frames) @@ -260,6 +302,24 @@ def test_verbose(u): assert a._verbose +def test_warn_nparts_nworkers(u): + a = FrameAnalysis(u.trajectory) + with pytest.warns(UserWarning): + a.run(backend='multiprocessing', n_workers=3, n_parts=2) + + +@pytest.mark.parametrize( + "classname,is_parallelizable", + [ + (base.AnalysisBase, False), + (base.AnalysisFromFunction, True), + (FrameAnalysis, True) + ] +) +def test_not_parallelizable(u, classname, is_parallelizable): + assert classname._analysis_algorithm_is_parallelizable == is_parallelizable + + def test_verbose_progressbar(u, capsys): FrameAnalysis(u.trajectory).run() _, err = capsys.readouterr() @@ -283,6 +343,12 @@ def test_verbose_progressbar_run_with_kwargs(u, capsys): actual = err.strip().split('\r')[-1] assert actual[:30] == expected[:30] + +def test_progressbar_multiprocessing(u): + with pytest.raises(ValueError): + FrameAnalysis(u.trajectory).run(backend='multiprocessing', verbose=True) + + def test_incomplete_defined_analysis(u): with pytest.raises(NotImplementedError): IncompleteAnalysis(u.trajectory).run() @@ -332,15 +398,18 @@ def test_results_type(u): (20, 50, 2, 15), (20, 50, None, 30) ]) -def test_AnalysisFromFunction(u, start, stop, step, nframes): +def test_AnalysisFromFunction(u, start, stop, step, nframes, client_AnalysisFromFunction): + # client_AnalysisFromFunction is defined [here](testsuite/MDAnalysisTests/analysis/conftest.py), + # and determines a set of parameters ('backend', 'n_workers'), taking only backends + # that are implemented for a given subclass, to run the test against. 
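+    # (a typical expansion of client_AnalysisFromFunction is
+    # {'backend': 'multiprocessing', 'n_workers': 2} or {'backend': 'serial'},
+    # depending on which backends are installed)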
ana1 = base.AnalysisFromFunction(simple_function, mobile=u.atoms) - ana1.run(start=start, stop=stop, step=step) + ana1.run(start=start, stop=stop, step=step, **client_AnalysisFromFunction) ana2 = base.AnalysisFromFunction(simple_function, u.atoms) - ana2.run(start=start, stop=stop, step=step) + ana2.run(start=start, stop=stop, step=step, **client_AnalysisFromFunction) ana3 = base.AnalysisFromFunction(simple_function, u.trajectory, u.atoms) - ana3.run(start=start, stop=stop, step=step) + ana3.run(start=start, stop=stop, step=step, **client_AnalysisFromFunction) frames = [] times = [] @@ -367,26 +436,27 @@ def mass_xyz(atomgroup1, atomgroup2, masses): return atomgroup1.positions * masses -def test_AnalysisFromFunction_args_content(u): +def test_AnalysisFromFunction_args_content(u, client_AnalysisFromFunction): protein = u.select_atoms('protein') masses = protein.masses.reshape(-1, 1) another = mda.Universe(TPR, XTC).select_atoms("protein") ans = base.AnalysisFromFunction(mass_xyz, protein, another, masses) assert len(ans.args) == 3 - result = np.sum(ans.run().results.timeseries) + result = np.sum(ans.run(**client_AnalysisFromFunction).results.timeseries) assert_allclose(result, -317054.67757345125, rtol=0, atol=1.5e-6) + assert_almost_equal(result, -317054.67757345125, decimal=6) assert (ans.args[0] is protein) and (ans.args[1] is another) assert ans._trajectory is protein.universe.trajectory -def test_analysis_class(): +def test_analysis_class(client_AnalysisFromFunctionAnalysisClass): ana_class = base.analysis_class(simple_function) assert issubclass(ana_class, base.AnalysisBase) assert issubclass(ana_class, base.AnalysisFromFunction) u = mda.Universe(PSF, DCD) step = 2 - ana = ana_class(u.atoms).run(step=step) + ana = ana_class(u.atoms).run(step=step, **client_AnalysisFromFunctionAnalysisClass) results = [] for ts in u.trajectory[::step]: diff --git a/testsuite/MDAnalysisTests/analysis/test_pca.py b/testsuite/MDAnalysisTests/analysis/test_pca.py index c72971c30ed..ec874b900fe 100644 --- a/testsuite/MDAnalysisTests/analysis/test_pca.py +++ b/testsuite/MDAnalysisTests/analysis/test_pca.py @@ -133,6 +133,12 @@ def test_no_frames(u): PCA(u, select=SELECTION).run(stop=1) +def test_can_run_frames(u): + atoms = u.select_atoms(SELECTION) + u.transfer_to_memory() + PCA(u, select=SELECTION).run(frames=[0,1]) + + def test_can_run_frames(u): atoms = u.select_atoms(SELECTION) u.transfer_to_memory() diff --git a/testsuite/MDAnalysisTests/analysis/test_results.py b/testsuite/MDAnalysisTests/analysis/test_results.py new file mode 100644 index 00000000000..97d299de101 --- /dev/null +++ b/testsuite/MDAnalysisTests/analysis/test_results.py @@ -0,0 +1,173 @@ +import pickle +from collections import UserDict + +import numpy as np +import pytest +from MDAnalysis.analysis import results as results_module +from numpy.testing import assert_equal + + +class Test_Results: + @pytest.fixture + def results(self): + return results_module.Results(a=1, b=2) + + def test_get(self, results): + assert results.a == results["a"] == 1 + + def test_no_attr(self, results): + msg = "'Results' object has no attribute 'c'" + with pytest.raises(AttributeError, match=msg): + results.c + + def test_set_attr(self, results): + value = [1, 2, 3, 4] + results.c = value + assert results.c is results["c"] is value + + def test_set_key(self, results): + value = [1, 2, 3, 4] + results["c"] = value + assert results.c is results["c"] is value + + @pytest.mark.parametrize("key", dir(UserDict) + ["data"]) + def test_existing_dict_attr(self, 
results, key): + msg = f"'{key}' is a protected dictionary attribute" + with pytest.raises(AttributeError, match=msg): + results[key] = None + + @pytest.mark.parametrize("key", dir(UserDict) + ["data"]) + def test_wrong_init_type(self, key): + msg = f"'{key}' is a protected dictionary attribute" + with pytest.raises(AttributeError, match=msg): + results_module.Results(**{key: None}) + + @pytest.mark.parametrize("key", ("0123", "0j", "1.1", "{}", "a b")) + def test_weird_key(self, results, key): + msg = f"'{key}' is not a valid attribute" + with pytest.raises(ValueError, match=msg): + results[key] = None + + def test_setattr_modify_item(self, results): + mylist = [1, 2] + mylist2 = [3, 4] + results.myattr = mylist + assert results.myattr is mylist + results["myattr"] = mylist2 + assert results.myattr is mylist2 + mylist2.pop(0) + assert len(results.myattr) == 1 + assert results.myattr is mylist2 + + def test_setitem_modify_item(self, results): + mylist = [1, 2] + mylist2 = [3, 4] + results["myattr"] = mylist + assert results.myattr is mylist + results.myattr = mylist2 + assert results.myattr is mylist2 + mylist2.pop(0) + assert len(results["myattr"]) == 1 + assert results["myattr"] is mylist2 + + def test_delattr(self, results): + assert hasattr(results, "a") + delattr(results, "a") + assert not hasattr(results, "a") + + def test_missing_delattr(self, results): + assert not hasattr(results, "d") + msg = "'Results' object has no attribute 'd'" + with pytest.raises(AttributeError, match=msg): + delattr(results, "d") + + def test_pop(self, results): + assert hasattr(results, "a") + results.pop("a") + assert not hasattr(results, "a") + + def test_update(self, results): + assert not hasattr(results, "spudda") + results.update({"spudda": "fett"}) + assert results.spudda == "fett" + + def test_update_data_fail(self, results): + msg = f"'data' is a protected dictionary attribute" + with pytest.raises(AttributeError, match=msg): + results.update({"data": 0}) + + def test_pickle(self, results): + results_p = pickle.dumps(results) + results_new = pickle.loads(results_p) + + @pytest.mark.parametrize( + "args, kwargs, length", + [ + (({"darth": "tater"},), {}, 1), + ([], {"darth": "tater"}, 1), + (({"darth": "tater"},), {"yam": "solo"}, 2), + (({"darth": "tater"},), {"darth": "vader"}, 1), + ], + ) + def test_initialize_arguments(self, args, kwargs, length): + results = results_module.Results(*args, **kwargs) + ref = dict(*args, **kwargs) + assert ref == results + assert len(results) == length + + def test_different_instances(self, results): + new_results = results_module.Results(darth="tater") + assert new_results.data is not results.data + + +class Test_ResultsGroup: + @pytest.fixture + def results_0(self): + return results_module.Results( + sequence=[0], + ndarray_mean=np.array([0, 0, 0]), + ndarray_sum=np.array([0, 0, 0, 0]), + float=0.0, + float_sum=0.0, + ) + + @pytest.fixture + def results_1(self): + return results_module.Results( + sequence=[1], + ndarray_mean=np.array([1, 1, 1]), + ndarray_sum=np.array([1, 1, 1, 1]), + float=1.0, + float_sum=1.0, + ) + + @pytest.fixture + def merger(self): + RG = results_module.ResultsGroup + lookup = { + "sequence": RG.flatten_sequence, + "ndarray_mean": RG.ndarray_mean, + "ndarray_sum": RG.ndarray_sum, + "float": RG.float_mean, + "float_sum": lambda floats: sum(floats), + } + return RG(lookup=lookup) + + @pytest.mark.parametrize("n", [1, 2, 5, 14]) + def test_all_results(self, results_0, results_1, merger, n): + from itertools import cycle + + objects = [obj 
for obj, _ in zip(cycle([results_0, results_1]), range(n))] + + arr = [i for _, i in zip(range(n), cycle([0, 1]))] + answers = { + "sequence": arr, + "ndarray_mean": [np.mean(arr) for _ in range(3)], + "ndarray_sum": [np.sum(arr) for _ in range(4)], + "float": np.mean(arr), + "float_sum": np.sum(arr), + } + + results = merger.merge(objects) + for attr, merged_value in results.items(): + assert_equal(merged_value, answers.get(attr), err_msg=f"{attr=}, {merged_value=}, {arr=}, {objects=}") diff --git a/testsuite/MDAnalysisTests/analysis/test_rms.py b/testsuite/MDAnalysisTests/analysis/test_rms.py index 56d2e459e94..d42993feb46 100644 --- a/testsuite/MDAnalysisTests/analysis/test_rms.py +++ b/testsuite/MDAnalysisTests/analysis/test_rms.py @@ -186,95 +186,104 @@ def correct_values_backbone_group(self): return [[0, 1, 0, 0, 0], [49, 50, 4.6997, 1.9154, 2.7139]] - def test_rmsd(self, universe, correct_values): + def test_rmsd(self, universe, correct_values, client_RMSD): + # client_RMSD is defined in testsuite/analysis/conftest.py + # among with other testing fixtures. During testing, it will + # collect all possible backends and reasonable number of workers + # for a given AnalysisBase subclass, and extend the tests + # to run with all of them. RMSD = MDAnalysis.analysis.rms.RMSD(universe, select='name CA') - RMSD.run(step=49) + RMSD.run(step=49, **client_RMSD) assert_almost_equal(RMSD.results.rmsd, correct_values, 4, err_msg="error: rmsd profile should match" + "test values") - def test_rmsd_frames(self, universe, correct_values): + def test_rmsd_frames(self, universe, correct_values, client_RMSD): RMSD = MDAnalysis.analysis.rms.RMSD(universe, select='name CA') - RMSD.run(frames=[0, 49]) + RMSD.run(frames=[0, 49], **client_RMSD) assert_almost_equal(RMSD.results.rmsd, correct_values, 4, err_msg="error: rmsd profile should match" + "test values") - def test_rmsd_unicode_selection(self, universe, correct_values): + def test_rmsd_unicode_selection(self, universe, correct_values, client_RMSD): RMSD = MDAnalysis.analysis.rms.RMSD(universe, select=u'name CA') - RMSD.run(step=49) + RMSD.run(step=49, **client_RMSD) assert_almost_equal(RMSD.results.rmsd, correct_values, 4, err_msg="error: rmsd profile should match" + "test values") - def test_rmsd_atomgroup_selections(self, universe): + def test_rmsd_atomgroup_selections(self, universe, client_RMSD): # see Issue #1684 R1 = MDAnalysis.analysis.rms.RMSD(universe.atoms, - select="resid 1-30").run() + select="resid 1-30").run(**client_RMSD) R2 = MDAnalysis.analysis.rms.RMSD(universe.atoms.select_atoms("name CA"), - select="resid 1-30").run() + select="resid 1-30").run(**client_RMSD) assert not np.allclose(R1.results.rmsd[:, 2], R2.results.rmsd[:, 2]) - def test_rmsd_single_frame(self, universe): + def test_rmsd_single_frame(self, universe, client_RMSD): RMSD = MDAnalysis.analysis.rms.RMSD(universe, select='name CA', - ).run(start=5, stop=6) + ).run(start=5, stop=6, **client_RMSD) single_frame = [[5, 6, 0.91544906]] assert_almost_equal(RMSD.results.rmsd, single_frame, 4, err_msg="error: rmsd profile should match" + "test values") - def test_mass_weighted(self, universe, correct_values): + def test_mass_weighted(self, universe, correct_values, client_RMSD): # mass weighting the CA should give the same answer as weighing # equally because all CA have the same mass RMSD = MDAnalysis.analysis.rms.RMSD(universe, select='name CA', - weights='mass').run(step=49) + weights='mass').run(step=49, **client_RMSD) assert_almost_equal(RMSD.results.rmsd, correct_values, 4, 
err_msg="error: rmsd profile should match" "test values") - def test_custom_weighted(self, universe, correct_values_mass): - RMSD = MDAnalysis.analysis.rms.RMSD(universe, weights="mass").run(step=49) + def test_custom_weighted(self, universe, correct_values_mass, client_RMSD): + RMSD = MDAnalysis.analysis.rms.RMSD(universe, weights="mass").run(step=49, **client_RMSD) assert_almost_equal(RMSD.results.rmsd, correct_values_mass, 4, err_msg="error: rmsd profile should match" "test values") - def test_weights_mass_is_mass_weighted(self, universe): + def test_weights_mass_is_mass_weighted(self, universe, client_RMSD): RMSD_mass = MDAnalysis.analysis.rms.RMSD(universe, - weights="mass").run(step=49) + weights="mass").run(step=49, **client_RMSD) RMSD_cust = MDAnalysis.analysis.rms.RMSD(universe, - weights=universe.atoms.masses).run(step=49) + weights=universe.atoms.masses).run(step=49, **client_RMSD) assert_almost_equal(RMSD_mass.results.rmsd, RMSD_cust.results.rmsd, 4, err_msg="error: rmsd profiles should match for 'mass' " "and universe.atoms.masses") - def test_custom_weighted_list(self, universe, correct_values_mass): + def test_custom_weighted_list(self, universe, correct_values_mass, client_RMSD): weights = universe.atoms.masses RMSD = MDAnalysis.analysis.rms.RMSD(universe, - weights=list(weights)).run(step=49) + weights=list(weights)).run(step=49, **client_RMSD) assert_almost_equal(RMSD.results.rmsd, correct_values_mass, 4, err_msg="error: rmsd profile should match" + "test values") - def test_custom_groupselection_weights_applied_1D_array(self, universe): + def test_custom_groupselection_weights_applied_1D_array(self, universe, client_RMSD): RMSD = MDAnalysis.analysis.rms.RMSD(universe, select='backbone', groupselections=['name CA and resid 1-5', 'name CA and resid 1'], weights=None, - weights_groupselections=[[1, 0, 0, 0, 0], None]).run(step=49) + weights_groupselections=[[1, 0, 0, 0, 0], None]).run(step=49, + **client_RMSD + ) assert_almost_equal(RMSD.results.rmsd.T[3], RMSD.results.rmsd.T[4], 4, err_msg="error: rmsd profile should match " "for applied weight array and selected resid") - def test_custom_groupselection_weights_applied_mass(self, universe, correct_values_mass): + def test_custom_groupselection_weights_applied_mass(self, universe, correct_values_mass, client_RMSD): RMSD = MDAnalysis.analysis.rms.RMSD(universe, select='backbone', groupselections=['all', 'all'], weights=None, weights_groupselections=['mass', - universe.atoms.masses]).run(step=49) + universe.atoms.masses]).run(step=49, + **client_RMSD + ) assert_almost_equal(RMSD.results.rmsd.T[3], RMSD.results.rmsd.T[4], 4, err_msg="error: rmsd profile should match " @@ -315,22 +324,23 @@ def test_rmsd_list_of_weights_wrong_length(self, universe): weights='mass', weights_groupselections=[None]) - def test_rmsd_group_selections(self, universe, correct_values_group): + def test_rmsd_group_selections(self, universe, correct_values_group, client_RMSD): RMSD = MDAnalysis.analysis.rms.RMSD(universe, groupselections=['backbone', 'name CA'] - ).run(step=49) + ).run(step=49, **client_RMSD) assert_almost_equal(RMSD.results.rmsd, correct_values_group, 4, err_msg="error: rmsd profile should match" "test values") def test_rmsd_backbone_and_group_selection(self, universe, - correct_values_backbone_group): + correct_values_backbone_group, + client_RMSD): RMSD = MDAnalysis.analysis.rms.RMSD( universe, reference=universe, select="backbone", groupselections=['backbone and resid 1:10', - 'backbone and resid 10:20']).run(step=49) + 'backbone and 
resid 10:20']).run(step=49, **client_RMSD) assert_almost_equal( RMSD.results.rmsd, correct_values_backbone_group, 4, err_msg="error: rmsd profile should match test values") @@ -349,7 +359,7 @@ def test_mass_mismatches(self, universe): RMSD = MDAnalysis.analysis.rms.RMSD(universe, reference=reference) - def test_ref_mobile_mass_mismapped(self, universe,correct_values_mass_add_ten): + def test_ref_mobile_mass_mismapped(self, universe,correct_values_mass_add_ten, client_RMSD): reference = MDAnalysis.Universe(PSF, DCD) universe.atoms.masses = universe.atoms.masses + 10 RMSD = MDAnalysis.analysis.rms.RMSD(universe, @@ -357,7 +367,7 @@ def test_ref_mobile_mass_mismapped(self, universe,correct_values_mass_add_ten): select='all', weights='mass', tol_mass=100) - RMSD.run(step=49) + RMSD.run(step=49, **client_RMSD) assert_almost_equal(RMSD.results.rmsd, correct_values_mass_add_ten, 4, err_msg="error: rmsd profile should match " "between true values and calculated values") @@ -370,9 +380,9 @@ def test_group_selections_unequal_len(self, universe): reference=reference, groupselections=['resname MET', 'type NH3']) - def test_rmsd_attr_warning(self, universe): + def test_rmsd_attr_warning(self, universe, client_RMSD): RMSD = MDAnalysis.analysis.rms.RMSD( - universe, select='name CA').run(stop=2) + universe, select='name CA').run(stop=2, **client_RMSD) wmsg = "The `rmsd` attribute was deprecated in MDAnalysis 2.0.0" with pytest.warns(DeprecationWarning, match=wmsg): @@ -384,22 +394,22 @@ class TestRMSF(object): def universe(self): return mda.Universe(GRO, XTC) - def test_rmsf(self, universe): + def test_rmsf(self, universe, client_RMSF): rmsfs = rms.RMSF(universe.select_atoms('name CA')) - rmsfs.run() + rmsfs.run(**client_RMSF) test_rmsfs = np.load(rmsfArray) assert_almost_equal(rmsfs.results.rmsf, test_rmsfs, 5, err_msg="error: rmsf profile should match test " "values") - def test_rmsf_single_frame(self, universe): - rmsfs = rms.RMSF(universe.select_atoms('name CA')).run(start=5, stop=6) + def test_rmsf_single_frame(self, universe, client_RMSF): + rmsfs = rms.RMSF(universe.select_atoms('name CA')).run(start=5, stop=6, **client_RMSF) assert_almost_equal(rmsfs.results.rmsf, 0, 5, err_msg="error: rmsfs should all be zero") - def test_rmsf_identical_frames(self, universe, tmpdir): + def test_rmsf_identical_frames(self, universe, tmpdir, client_RMSF): outfile = os.path.join(str(tmpdir), 'rmsf.xtc') @@ -410,13 +420,35 @@ def test_rmsf_identical_frames(self, universe, tmpdir): universe = mda.Universe(GRO, outfile) rmsfs = rms.RMSF(universe.select_atoms('name CA')) - rmsfs.run() + rmsfs.run(**client_RMSF) assert_almost_equal(rmsfs.results.rmsf, 0, 5, err_msg="error: rmsfs should all be 0") - def test_rmsf_attr_warning(self, universe): - rmsfs = rms.RMSF(universe.select_atoms('name CA')).run(stop=2) + def test_rmsf_attr_warning(self, universe, client_RMSF): + rmsfs = rms.RMSF(universe.select_atoms('name CA')).run(stop=2, **client_RMSF) wmsg = "The `rmsf` attribute was deprecated in MDAnalysis 2.0.0" with pytest.warns(DeprecationWarning, match=wmsg): assert_equal(rmsfs.rmsf, rmsfs.results.rmsf) + + +@pytest.mark.parametrize( + "classname,is_parallelizable", + [ + (MDAnalysis.analysis.rms.RMSD, True), + (MDAnalysis.analysis.rms.RMSF, False), + ] +) +def test_not_parallelizable(classname, is_parallelizable): + assert classname._analysis_algorithm_is_parallelizable == is_parallelizable + + +@pytest.mark.parametrize( + "classname,backends", + [ + (MDAnalysis.analysis.rms.RMSD, ('serial', 'multiprocessing', 'dask',)), 
+ (MDAnalysis.analysis.rms.RMSF, ('serial',)), + ] +) +def test_supported_backends(classname, backends): + assert classname.get_supported_backends() == backends diff --git a/testsuite/MDAnalysisTests/lib/test_util.py b/testsuite/MDAnalysisTests/lib/test_util.py index 4ff1832b7de..cd641133586 100644 --- a/testsuite/MDAnalysisTests/lib/test_util.py +++ b/testsuite/MDAnalysisTests/lib/test_util.py @@ -2232,3 +2232,11 @@ def test_which(): with pytest.warns(DeprecationWarning, match=wmsg): assert util.which('python') == shutil.which('python') + + +@pytest.mark.parametrize( + "modulename,is_installed", + [("math", True), ("sys", True), ("some_weird_module_name", False)], +) +def test_is_installed(modulename, is_installed): + assert util.is_installed(modulename) == is_installed