diff --git a/.circleci/ci-oldest-reqs.txt b/.circleci/ci-oldest-reqs.txt index b3abb266b..43842ea99 100644 --- a/.circleci/ci-oldest-reqs.txt +++ b/.circleci/ci-oldest-reqs.txt @@ -6,6 +6,7 @@ h5py==2.8.0 numpy==1.13.3 packaging==15.0 pandas==1.0.0 +pymongo==3.0.0 pytest-cov==2.10.1 pytest==6.2.1 tables==3.3.0 diff --git a/.circleci/config.yml b/.circleci/config.yml index 6a6a2568e..f8e80a666 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -46,6 +46,8 @@ jobs: linux-python-39: &linux-template docker: - image: circleci/python:3.9 + - image: circleci/mongo:latest + - image: circleci/redis:latest environment: BENCHMARKS: "RUN" @@ -138,16 +140,22 @@ jobs: <<: *linux-template docker: - image: circleci/python:3.8 + - image: circleci/mongo:latest + - image: circleci/redis:latest linux-python-37: <<: *linux-template docker: - image: circleci/python:3.7 + - image: circleci/mongo:latest + - image: circleci/redis:latest linux-python-36-oldest: <<: *linux-template docker: - image: circleci/python:3.6 + - image: circleci/mongo:latest + - image: circleci/redis:latest environment: BENCHMARKS: "SKIP" DEPENDENCIES: "OLDEST" @@ -180,7 +188,12 @@ jobs: ${PYTHON} -m pip install --progress-bar off -U pip>=20.3 ${PYTHON} -m pip install --progress-bar off -U codecov ${PYTHON} -m pip install --progress-bar off -U -r requirements/requirements-test.txt - ${PYTHON} -m pip install --progress-bar off -U -r requirements/requirements-test-optional.txt + + # For some reason Zarr doesn't install correctly on Windows (runs + # into pip SSL errors), so we skip that test. + grep -v zarr requirements-test-optional.txt > requirements-test-optional-windows.txt + ${PYTHON} -m pip install --progress-bar off -U -r requirements-test-optional-windows.txt + ${PYTHON} -m pip install --progress-bar off -U -e . - run: name: Run tests diff --git a/changelog.txt b/changelog.txt index 5c750fc41..0268a6510 100644 --- a/changelog.txt +++ b/changelog.txt @@ -13,6 +13,7 @@ next Added +++++ + - New ``SyncedCollection`` class and subclasses to replace ``JSONDict`` with more general support for different types of resources (such as MongoDB collections or Redis databases) and more complete support for different data types synchronized with files (#196, #234, #249, #316, #383, #397, #465, #484). This change introduces a minor-backwards incompatible change; for users making direct use of signac buffering, the ``force_write`` parameter is no longer respected. If the argument is passed, a warning will now be raised to indicate that it is ignored and will be removed in signac 2.0. - Unified querying for state point and document filters using 'sp' and 'doc' as prefixes (#332, #514). This change introduces a minor backwards-incompatible change to the ``Collection`` index schema ('statepoint'->'sp'), but this does not affect any APIs, only indexes saved to file using a previous version of signac. Indexing APIs will be removed in signac 2.0. diff --git a/doc/api.rst b/doc/api.rst index dc0920f40..b1ec535eb 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -206,3 +206,111 @@ signac.errors module :members: :undoc-members: :show-inheritance: + +synced\_collections package +=========================== + +Data Types +---------- + +synced\_collections.synced\_collection module ++++++++++++++++++++++++++++++++++++++++++++++ + +.. automodule:: signac.synced_collections.data_types.synced_collection + :members: + :private-members: + :show-inheritance: + +synced\_collections.synced\_dict module ++++++++++++++++++++++++++++++++++++++++++++++ + +.. 
automodule:: signac.synced_collections.data_types.synced_dict + :members: + :show-inheritance: + +synced\_collections.synced\_list module ++++++++++++++++++++++++++++++++++++++++++++++ + +.. automodule:: signac.synced_collections.data_types.synced_list + :members: + :show-inheritance: + +Backends +-------- + +synced\_collections.backends.collection\_json module ++++++++++++++++++++++++++++++++++++++++++++ + +.. automodule:: signac.synced_collections.backends.collection_json + :members: + :show-inheritance: + +synced\_collections.backends.collection\_mongodb module +++++++++++++++++++++++++++++++++++++++++++++++ + +.. automodule:: signac.synced_collections.backends.collection_mongodb + :members: + :show-inheritance: + +synced\_collections.backends.collection\_redis module +++++++++++++++++++++++++++++++++++++++++++++ + +.. automodule:: signac.synced_collections.backends.collection_redis + :members: + :show-inheritance: + +synced\_collections.backends.collection\_zarr module ++++++++++++++++++++++++++++++++++++++++++++ + +.. automodule:: signac.synced_collections.backends.collection_zarr + :members: + :show-inheritance: + +Buffers +------- + +synced\_collections.buffers.buffered\_collection module ++++++++++++++++++++++++++++++++++++++++++++++++ + +.. automodule:: signac.synced_collections.buffers.buffered_collection + :members: + :private-members: + :show-inheritance: + +synced\_collections.buffers.file\_buffered\_collection module ++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +.. automodule:: signac.synced_collections.buffers.file_buffered_collection + :members: + :show-inheritance: + +synced\_collections.buffers.serialized\_file\_buffered\_collection module ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +.. automodule:: signac.synced_collections.buffers.serialized_file_buffered_collection + :members: + :show-inheritance: + +synced\_collections.buffers.memory\_buffered\_collection module ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +.. automodule:: signac.synced_collections.buffers.memory_buffered_collection + :members: + :show-inheritance: + +Miscellaneous Modules +--------------------- + +synced\_collections.utils module +++++++++++++++++++++++++++++++++ + +.. automodule:: signac.synced_collections.utils + :members: + :show-inheritance: + +synced\_collections.validators module ++++++++++++++++++++++++++++++++++++++ + +.. 
automodule:: signac.synced_collections.validators + :members: + :show-inheritance: diff --git a/doc/conf.py b/doc/conf.py index a0653a95f..5dad8c15b 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -316,6 +316,9 @@ def __getattr__(cls, name): intersphinx_mapping = { "python": ("https://docs.python.org/3", None), "pymongo": ("https://pymongo.readthedocs.io/en/stable/", None), - "pandas": ("https://pandas.pydata.org/docs/", None), - "h5py": ("https://docs.h5py.org/en/stable/", None), + "pandas": ("https://pandas.pydata.org/pandas-docs/stable/", None), + "h5py": ("http://docs.h5py.org/en/stable/", None), + "zarr": ("https://zarr.readthedocs.io/en/stable", None), + "redis": ("https://redis-py.readthedocs.io/en/stable/", None), + "numcodecs": ("https://numcodecs.readthedocs.io/en/stable/", None), } diff --git a/requirements/requirements-test-optional.txt b/requirements/requirements-test-optional.txt index fdd0e97a1..a5a7ced7b 100644 --- a/requirements/requirements-test-optional.txt +++ b/requirements/requirements-test-optional.txt @@ -2,4 +2,7 @@ h5py==3.1.0; implementation_name=='cpython' numpy==1.20.0 pandas==1.2.1; implementation_name=='cpython' pymongo==3.11.2; implementation_name=='cpython' +redis==3.5.3 +ruamel.yaml==0.16.12 tables==3.6.1; implementation_name=='cpython' +zarr==2.4.0; platform_system!='Windows' diff --git a/setup.cfg b/setup.cfg index b3cad99c0..e5c0914d2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -20,7 +20,7 @@ ignore = E123,E126,E203,E226,E241,E704,W503,W504 match = ^((?!\.sync-zenodo-metadata|setup|benchmark|mpipool|connection|crypt|host|filesystems|indexing).)*\.py$ match-dir = ^((?!\.|tests|configobj|db).)*$ ignore-decorators = "deprecated" -ignore = D105, D107, D203, D204, D213 +add-ignore = D105, D107, D203, D204, D213 [mypy] ignore_missing_imports = True @@ -34,8 +34,8 @@ omit = */signac/common/configobj/*.py [tool:pytest] -filterwarnings = - ignore: .*[The indexing module | get_statepoint] is deprecated.*: DeprecationWarning +filterwarnings = + ignore: .*[The indexing module | get_statepoint | Use of.+as key] is deprecated.*: DeprecationWarning [bumpversion:file:setup.py] diff --git a/signac/__init__.py b/signac/__init__.py index 6bf487e2f..56d290b78 100644 --- a/signac/__init__.py +++ b/signac/__init__.py @@ -27,15 +27,21 @@ from .contrib import filesystems as fs from .contrib import get_job, get_project, index, index_files, init_project from .core.h5store import H5Store, H5StoreManager -from .core.jsondict import JSONDict -from .core.jsondict import buffer_reads_writes as buffered from .core.jsondict import flush_all as flush -from .core.jsondict import get_buffer_load, get_buffer_size -from .core.jsondict import in_buffered_mode as is_buffered from .db import get_database from .diff import diff_jobs +from .synced_collections.backends.collection_json import ( + BufferedJSONAttrDict as JSONDict, +) from .version import __version__ +# Alias some properties related to buffering into the signac namespace. 
+buffered = JSONDict.buffer_backend +is_buffered = JSONDict.backend_is_buffered +get_buffer_load = JSONDict.get_current_buffer_size +get_buffer_size = JSONDict.get_buffer_capacity +set_buffer_size = JSONDict.set_buffer_capacity + __all__ = [ "__version__", "contrib", @@ -69,6 +75,7 @@ "flush", "get_buffer_size", "get_buffer_load", + "set_buffer_size", "JSONDict", "H5Store", "H5StoreManager", diff --git a/signac/contrib/collection.py b/signac/contrib/collection.py index 51b89576e..a28e2dbcb 100644 --- a/signac/contrib/collection.py +++ b/signac/contrib/collection.py @@ -16,6 +16,7 @@ import argparse import io +import json import logging import operator import re @@ -24,7 +25,6 @@ from math import isclose from numbers import Number -from ..core import json from .filterparse import parse_filter_arg from .utility import _nested_dicts_to_dotted_keys, _to_hashable diff --git a/signac/contrib/errors.py b/signac/contrib/errors.py index 80c3da855..0391a2a96 100644 --- a/signac/contrib/errors.py +++ b/signac/contrib/errors.py @@ -28,8 +28,8 @@ class DestinationExistsError(Error, RuntimeError): Parameters ---------- - destination : - The destination object causing the error. + destination : str + The destination causing the error. """ diff --git a/signac/contrib/filterparse.py b/signac/contrib/filterparse.py index c21506492..d52025dc7 100644 --- a/signac/contrib/filterparse.py +++ b/signac/contrib/filterparse.py @@ -3,11 +3,10 @@ # This software is licensed under the BSD 3-Clause License. """Parse the filter arguments.""" +import json import sys from collections.abc import Mapping -from ..core import json - def _print_err(msg=None): """Print the provided message to stderr. diff --git a/signac/contrib/hashing.py b/signac/contrib/hashing.py index a2c5a58bd..79e3c082c 100644 --- a/signac/contrib/hashing.py +++ b/signac/contrib/hashing.py @@ -6,6 +6,8 @@ import hashlib import json +from ..synced_collections.utils import SyncedCollectionJSONEncoder + # We must use the standard library json for exact consistency in formatting @@ -27,7 +29,7 @@ def calc_id(spec): Encoded hash in hexadecimal format. 
""" - blob = json.dumps(spec, sort_keys=True) + blob = json.dumps(spec, cls=SyncedCollectionJSONEncoder, sort_keys=True) m = hashlib.md5() m.update(blob.encode()) return m.hexdigest() diff --git a/signac/contrib/import_export.py b/signac/contrib/import_export.py index 98008570d..372a02d59 100644 --- a/signac/contrib/import_export.py +++ b/signac/contrib/import_export.py @@ -4,6 +4,7 @@ """Provides features for importing and exporting data.""" import errno +import json import logging import os import re @@ -16,7 +17,6 @@ from tempfile import TemporaryDirectory from zipfile import ZIP_DEFLATED, ZipFile -from ..core import json from .errors import DestinationExistsError, StatepointParsingError from .utility import _dotted_dict_to_nested_dicts, _mkdir_p @@ -766,7 +766,7 @@ def _copy_to_job_workspace(src, job, copytree): raise DestinationExistsError(job) raise else: - job._init() + job.init() return dst diff --git a/signac/contrib/job.py b/signac/contrib/job.py index f4a0135a5..fa117c096 100644 --- a/signac/contrib/job.py +++ b/signac/contrib/job.py @@ -8,14 +8,19 @@ import os import shutil from copy import deepcopy +from json import JSONDecodeError +from typing import FrozenSet from deprecation import deprecated -from ..core import json -from ..core.attrdict import SyncedAttrDict from ..core.h5store import H5StoreManager -from ..core.jsondict import JSONDict from ..sync import sync_jobs +from ..synced_collections.backends.collection_json import ( + BufferedJSONAttrDict, + JSONAttrDict, + json_attr_dict_validator, +) +from ..synced_collections.errors import KeyTypeError from ..version import __version__ from .errors import DestinationExistsError, JobsCorruptedError from .hashing import calc_id @@ -24,31 +29,195 @@ logger = logging.getLogger(__name__) -class _sp_save_hook: - """Hook to handle job migration when state points are changed. +# Note: All children of _StatePointDict will be of its parent type because they +# share a backend and the SyncedCollection registry parses the classes in order +# of registration. _If_ we need more control over this, that process can be +# exposed more thoroughly and registration can be made explicit rather than +# implicit, but for now the existing behavior works fine. +class _StatePointDict(JSONAttrDict): + """A JSON-backed dictionary for storing job state points. + + There are three principal reasons for extending the base JSONAttrDict: + 1. Saving needs to trigger a job directory migration, and + 2. State points are assumed to not support external modification, so + they never need to load from disk _except_ the very first time a job + is opened by id and the state point is not present in the cache. + 3. It must be possible to load and/or save on demand during tasks like + job directory migrations. + """ - When a job's state point is changed, in addition - to the contents of the file being modified this hook - calls :meth:`~Job._reset_sp` to rehash the state - point, compute a new job id, and move the folder. + _PROTECTED_KEYS: FrozenSet[str] = JSONAttrDict._PROTECTED_KEYS.union(("_jobs",)) + _all_validators = (json_attr_dict_validator,) + + def __init__( + self, + jobs=None, + filename=None, + write_concern=False, + data=None, + parent=None, + *args, + **kwargs, + ): + # Multiple Python Job objects can share a single `_StatePointDict` + # instance because they are shallow copies referring to the same data + # on disk. 
We need to store these jobs in a shared list here so that + # shallow copies can point to the same place and trigger each other to + # update. This does not apply to independently created Job objects, + # even if they refer to the same disk data; this only applies to + # explicit shallow copies and unpickled objects within a session. + self._jobs = list(jobs) + super().__init__( + filename=filename, + write_concern=write_concern, + data=data, + parent=parent, + *args, + **kwargs, + ) - Parameters - ---------- - jobs : iterable of `Jobs` - List of jobs(instance of `Job`). + def _load(self): + # State points never load from disk automatically. They are either + # initialized with provided data (e.g. from the state point cache), or + # they load from disk the first time state point data is requested for + # a Job opened by id (in which case the state point must first be + # validated manually). + pass - """ + def _save(self): + # State point modification triggers job migration for all jobs sharing + # this state point (shallow copies of a single job). + new_id = calc_id(self) - def __init__(self, *jobs): - self.jobs = list(jobs) + # All elements of the job list are shallow copies of each other, so any + # one of them is representative. + job = next(iter(self._jobs)) + old_id = job._id + if old_id == new_id: + return - def load(self): - pass + tmp_statepoint_file = self.filename + "~" + should_init = False + try: + # Move the state point to an intermediate location as a backup. + os.replace(self.filename, tmp_statepoint_file) + try: + new_workspace = os.path.join(job._project.workspace(), new_id) + os.replace(job.workspace(), new_workspace) + except OSError as error: + os.replace(tmp_statepoint_file, self.filename) # rollback + if error.errno in (errno.EEXIST, errno.ENOTEMPTY, errno.EACCES): + raise DestinationExistsError(new_id) + else: + raise + else: + should_init = True + except OSError as error: + # The most likely reason we got here is because the state point + # file move failed due to the job not being initialized so the file + # doesn't exist, which is OK. + if error.errno != errno.ENOENT: + raise + + # Update each job instance. + for job in self._jobs: + job._id = new_id + job._initialize_lazy_properties() + + # Remove the temporary state point file if it was created. Have to do it + # here because we need to get the updated job state point filename. + try: + os.remove(job._statepoint_filename + "~") + except OSError as error: + if error.errno != errno.ENOENT: + raise + + # Since all the jobs are equivalent, just grab the filename from the + # last one and init it. Also migrate the lock for multithreaded support. + old_lock_id = self._lock_id + self._filename = job._statepoint_filename + type(self)._locks[self._lock_id] = type(self)._locks.pop(old_lock_id) + + if should_init: + # Only initializing one job assumes that all changes in init are + # changes reflected in the underlying resource (the JSON file). + # This assumption is currently valid because all in-memory + # attributes are loaded lazily (and are handled by the call to + # _initialize_lazy_properties above), except for the key defining + # property of the job id (which is also updated above). If init + # ever changes to making modifications to the job object, we may + # need to call it for all jobs. + job.init() + + logger.info(f"Moved '{old_id}' -> '{new_id}'.") + + def save(self, force=False): + """Trigger a save to disk. + + Unlike normal JSONAttrDict objects, this class requires the ability to save + on command. 
Moreover, this save must be conditional on whether or not a + file is present to allow the user to observe state points in corrupted + data spaces and attempt to recover. + + Parameters + ---------- + force : bool + If True, save even if the file is present on disk. + """ + try: + # Open the file for writing only if it does not exist yet. + if force or not os.path.isfile(self._filename): + super()._save() + except Exception as error: + if not isinstance(error, OSError) or error.errno not in ( + errno.EEXIST, + errno.EACCES, + ): + # Attempt to delete the file on error, to prevent corruption. + # OSErrors that are EEXIST or EACCES don't need to delete the file. + try: + os.remove(self._filename) + except Exception: # ignore all errors here + pass + raise + + def load(self, job_id): + """Trigger a load from disk. + + Unlike normal JSONAttrDict objects, this class requires the ability to + load on command. These loads typically occur when the state point + must be validated against the data on disk; at all other times, the + in-memory data is assumed to be accurate to avoid unnecessary I/O. + + Parameters + ---------- + job_id : str + Job id used to validate contents on disk. - def save(self): - """Reset the state point for all the jobs.""" - for job in self.jobs: - job._reset_sp() + Returns + ------- + data : dict + Dictionary of state point data. + + Raises + ------ + :class:`~signac.errors.JobsCorruptedError` + If the data on disk is invalid or its hash does not match the job + id. + + """ + try: + data = self._load_from_resource() + except JSONDecodeError: + raise JobsCorruptedError([job_id]) + + if calc_id(data) != job_id: + raise JobsCorruptedError([job_id]) + + with self._suspend_sync: + self._update(data, _validate=False) + + return data class Job: @@ -73,9 +242,9 @@ class Job: """ FN_MANIFEST = "signac_statepoint.json" - """The job's manifest filename. + """The job's state point filename. - The job manifest is a human-readable file containing the job's state + The job state point is a human-readable file containing the job's state point that is stored in each job's workspace directory. """ @@ -87,38 +256,32 @@ class Job: def __init__(self, project, statepoint=None, _id=None): self._project = project + self._initialize_lazy_properties() if statepoint is None and _id is None: raise ValueError("Either statepoint or _id must be provided.") elif statepoint is not None: - # A state point was provided. - self._statepoint = SyncedAttrDict(statepoint, parent=_sp_save_hook(self)) - # If the id is provided, assume the job is already registered in - # the project cache and that the id is valid for the state point. - if _id is None: - # Validate the state point and recursively convert to supported types. - statepoint = self.statepoint() - # Compute the id from the state point if not provided. - self._id = calc_id(statepoint) - # Update the project's state point cache immediately if opened by state point - self._project._register(self.id, statepoint) - else: - self._id = _id + self._statepoint_requires_init = False + try: + self._id = calc_id(statepoint) if _id is None else _id + except TypeError: + raise KeyTypeError + self._statepoint = _StatePointDict( + jobs=[self], filename=self._statepoint_filename, data=statepoint + ) + + # Update the project's state point cache immediately if opened by state point + self._project._register(self.id, statepoint) else: # Only an id was provided. State point will be loaded lazily. 
- self._statepoint = None self._id = _id + self._statepoint_requires_init = True - # Prepare job working directory + def _initialize_lazy_properties(self): + """Initialize all properties that are designed to be loaded lazily.""" self._wd = None - - # Prepare job document self._document = None - - # Prepare job H5StoreManager self._stores = None - - # Prepare current working directory for context management self._cwd = [] @deprecated( @@ -181,9 +344,16 @@ def workspace(self): """ if self._wd is None: - self._wd = os.path.join(self._project.workspace(), self.id) + # We can rely on the project workspace to be well-formed, so just + # use string-concatenation with os.sep instead of os.path.join for speed. + self._wd = self._project.workspace() + os.sep + self.id return self._wd + @property + def _statepoint_filename(self): + """Get the path of the state point file for this job.""" + return self.workspace() + os.sep + self.FN_MANIFEST + @property def ws(self): """Alias for :meth:`~Job.workspace`.""" @@ -195,7 +365,8 @@ def reset_statepoint(self, new_statepoint): This method will change the job id if the state point has been altered. For more information, see - `Modifying the State Point `_. + `Modifying the State Point + `_. .. danger:: @@ -208,50 +379,18 @@ def reset_statepoint(self, new_statepoint): new_statepoint : dict The job's new state point. - """ # noqa: E501 - dst = self._project.open_job(new_statepoint) - if dst == self: - return - fn_manifest = os.path.join(self.workspace(), self.FN_MANIFEST) - fn_manifest_backup = fn_manifest + "~" - try: - os.replace(fn_manifest, fn_manifest_backup) - try: - os.replace(self.workspace(), dst.workspace()) - except OSError as error: - os.replace(fn_manifest_backup, fn_manifest) # rollback - if error.errno in (errno.EEXIST, errno.ENOTEMPTY, errno.EACCES): - raise DestinationExistsError(dst) - else: - raise - else: - dst.init() - except OSError as error: - if error.errno == errno.ENOENT: - pass # job is not initialized - else: - raise - # Update this instance - self.statepoint._data = dst.statepoint._data - self._id = dst._id - self._wd = None - self._document = None - self._stores = None - self._cwd = [] - logger.info(f"Moved '{self}' -> '{dst}'.") - - def _reset_sp(self, new_statepoint=None): - """Check for new state point requested to assign this job. - - Parameters - ---------- - new_statepoint : dict - The job's new state point (Default value = None). - """ - if new_statepoint is None: - new_statepoint = self.statepoint() - self.reset_statepoint(new_statepoint) + if self._statepoint_requires_init: + # Instantiate state point data lazily - no load is required, since + # we are provided with the new state point data. + self._statepoint = _StatePointDict( + jobs=[self], filename=self._statepoint_filename + ) + self._statepoint_requires_init = False + self.statepoint.reset(new_statepoint) + + # Update the project's state point cache when loaded lazily + self._project._register(self.id, new_statepoint) def update_statepoint(self, update, overwrite=False): """Change the state point of this job while preserving job data. @@ -295,39 +434,13 @@ def update_statepoint(self, update, overwrite=False): if not overwrite: for key, value in update.items(): if statepoint.get(key, value) != value: - raise KeyError(key) + raise KeyError( + f"Key {key} was provided but already exists in the " + "mapping with another value." + ) statepoint.update(update) self.reset_statepoint(statepoint) - def _read_manifest(self): - """Read and parse the manifest file, if it exists. 
- - Returns - ------- - manifest : dict - State point data. - - Raises - ------ - :class:`~signac.errors.JobsCorruptedError` - If an error occurs while parsing the state point manifest. - OSError - If an error occurs while reading the state point manifest. - - """ - fn_manifest = os.path.join(self.workspace(), self.FN_MANIFEST) - try: - with open(fn_manifest, "rb") as file: - manifest = json.loads(file.read().decode()) - except OSError as error: - if error.errno != errno.ENOENT: - raise error - except ValueError: - # This catches JSONDecodeError, a subclass of ValueError - raise JobsCorruptedError([self.id]) - else: - return manifest - @property def statepoint(self): """Get the job's state point. @@ -340,7 +453,8 @@ def statepoint(self): modifiable copy that will not modify the underlying JSON file, you can access a dict copy of the state point by calling it, e.g. ``sp_dict = job.statepoint()`` instead of ``sp = job.statepoint``. - For more information, see : :class:`~signac.JSONDict`. + For more information, see + :class:`~signac.synced_collections.backends.collection_json.JSONAttrDict`. See :ref:`signac statepoint ` for the command line equivalent. @@ -350,12 +464,16 @@ def statepoint(self): Returns the job's state point. """ - if self._statepoint is None: - # Load state point manifest lazily and assign to - # self._statepoint - statepoint = self._check_manifest() + if self._statepoint_requires_init: + # Load state point data lazily (on access). + self._statepoint = _StatePointDict( + jobs=[self], filename=self._statepoint_filename + ) + statepoint = self._statepoint.load(self.id) + # Update the project's state point cache when loaded lazily self._project._register(self.id, statepoint) + self._statepoint_requires_init = False return self._statepoint @@ -369,7 +487,7 @@ def statepoint(self, new_statepoint): The new state point to be assigned. """ - self._reset_sp(new_statepoint) + self.reset_statepoint(new_statepoint) @property def sp(self): @@ -387,9 +505,13 @@ def document(self): .. warning:: + Even deep copies of :attr:`~Job.document` will modify the same file, + so changes will still effectively be persisted between deep copies. If you need a deep copy that will not modify the underlying - persistent JSON file, use :attr:`~Job.document` instead of :attr:`~Job.doc`. - For more information, see :attr:`~Job.statepoint` or :class:`~signac.JSONDict`. + persistent JSON file, use the call operator to get an equivalent + plain dictionary: ``job.document()``. + For more information, see + :class:`~signac.JSONDict`. See :ref:`signac document ` for the command line equivalent. @@ -402,16 +524,16 @@ def document(self): if self._document is None: self.init() fn_doc = os.path.join(self.workspace(), self.FN_DOCUMENT) - self._document = JSONDict(filename=fn_doc, write_concern=True) + self._document = BufferedJSONAttrDict(filename=fn_doc, write_concern=True) return self._document @document.setter def document(self, new_doc): - """Assign new document to the this job. + """Assign new document data to this job. Parameters ---------- - new_doc : :class:`~signac.JSONDict` + new_doc : dict The job document handle. """ @@ -423,9 +545,18 @@ def doc(self): .. warning:: + Even deep copies of :attr:`~Job.doc` will modify the same file, so + changes will still effectively be persisted between deep copies. If you need a deep copy that will not modify the underlying - persistent JSON file, use :attr:`~Job.document` instead of :attr:`~Job.doc`. 
- For more information, see :attr:`~Job.statepoint` or :class:`~signac.JSONDict`. + persistent JSON file, use the call operator to get an equivalent + plain dictionary: ``job.doc()``. + + See :ref:`signac document ` for the command line equivalent. + + Returns + ------- + :class:`~signac.JSONDict` + The job document handle. """ return self.document @@ -507,96 +638,11 @@ def data(self, new_data): """ self.stores[self.KEY_DATA] = new_data - def _init(self, force=False): - """Contains all logic for job initialization. - - This method is called by :meth:`~.init` and is responsible - for actually creating the job workspace directory and - writing out the state point manifest file. - - Parameters - ---------- - force : bool - If ``True``, write the job manifest even if it - already exists. If ``False``, this method will - raise an Exception if the manifest exists - (Default value = False). - - """ - # Attempt early exit if the manifest exists and is valid - try: - statepoint = self._check_manifest() - except Exception: - # Any exception means this method cannot exit early. - - # Create the workspace directory if it does not exist. - try: - _mkdir_p(self.workspace()) - except OSError: - logger.error( - "Error occurred while trying to create " - "workspace directory for job '{}'.".format(self.id) - ) - raise - - fn_manifest = os.path.join(self.workspace(), self.FN_MANIFEST) - try: - # Prepare the data before file creation and writing. - statepoint = self.statepoint() - blob = json.dumps(statepoint, indent=2) - except JobsCorruptedError: - raise - - try: - # Open the file for writing only if it does not exist yet. - with open(fn_manifest, "w" if force else "x") as file: - file.write(blob) - except OSError as error: - if error.errno not in (errno.EEXIST, errno.EACCES): - raise - except Exception as error: - # Attempt to delete the file on error, to prevent corruption. - try: - os.remove(fn_manifest) - except Exception: # ignore all errors here - pass - raise error - else: - # Validate the output again after writing to disk - statepoint = self._check_manifest() - - # Update the project's state point cache if the manifest is valid - self._project._register(self.id, statepoint) - - def _check_manifest(self): - """Check whether the manifest file exists and is correct. - - If the manifest is valid, this sets the state point if it is not - already set. - - Returns - ------- - manifest : dict - State point data. - - Raises - ------ - :class:`~signac.errors.JobsCorruptedError` - If the manifest hash is not equal to the job id. - - """ - manifest = self._read_manifest() - if calc_id(manifest) != self.id: - raise JobsCorruptedError([self.id]) - if self._statepoint is None: - self._statepoint = SyncedAttrDict(manifest, parent=_sp_save_hook(self)) - return manifest - def init(self, force=False): """Initialize the job's workspace directory. - This function will do nothing if the directory and - the job manifest already exist. + This function will do nothing if the directory and the job state point + already exist and the state point is valid. Returns the calling job. @@ -605,20 +651,50 @@ def init(self, force=False): Parameters ---------- force : bool - Overwrite any existing state point's manifest - files, e.g., to repair them if they got corrupted (Default value = False). + Overwrite any existing state point files, e.g., to repair them if + they got corrupted (Default value = False). Returns ------- Job The job handle. 
+ Raises + ------ + OSError + If the workspace directory cannot be created or any other I/O error + occurs when attempting to save the state point file. + JobsCorruptedError + If the job state point on disk is corrupted. """ try: - self._init(force=force) + # Attempt early exit if the state point file exists and is valid. + try: + statepoint = self.statepoint.load(self.id) + except Exception: + # Any exception means this method cannot exit early. + + # Create the workspace directory if it does not exist. + try: + _mkdir_p(self.workspace()) + except OSError: + logger.error( + "Error occurred while trying to create " + "workspace directory for job '{}'.".format(self.id) + ) + raise + + # The state point save will not overwrite an existing file on + # disk unless force is True, so the subsequent load will catch + # when a preexisting invalid file was present. + self.statepoint.save(force=force) + statepoint = self.statepoint.load(self.id) + + # Update the project's state point cache if the saved file is valid. + self._project._register(self.id, statepoint) except Exception: logger.error( - f"State point manifest file of job '{self.id}' appears to be corrupted." + f"State point file of job '{self.id}' appears to be corrupted." ) raise return self @@ -839,7 +915,9 @@ def __exit__(self, err_type, err_value, tb): def __setstate__(self, state): self.__dict__.update(state) - self.statepoint._parent.jobs.append(self) + # We append to a list of jobs rather than replacing to support + # transparent id updates between shallow copies of a job. + self.statepoint._jobs.append(self) def __deepcopy__(self, memo): cls = self.__class__ diff --git a/signac/contrib/project.py b/signac/contrib/project.py index 757ddf9ea..953a78403 100644 --- a/signac/contrib/project.py +++ b/signac/contrib/project.py @@ -5,6 +5,7 @@ import errno import gzip +import json import logging import os import re @@ -24,10 +25,9 @@ from packaging import version from ..common.config import Config, get_config, load_config -from ..core import json from ..core.h5store import H5StoreManager -from ..core.jsondict import JSONDict from ..sync import sync_projects +from ..synced_collections.backends.collection_json import BufferedJSONAttrDict from ..version import SCHEMA_VERSION, __version__ from .collection import Collection from .errors import ( @@ -507,13 +507,13 @@ def document(self): Returns ------- - :class:`~signac.JSONDict` + :class:`~signac.synced_collections.backends.collection_json.BufferedJSONAttrDict` The project document. """ if self._document is None: fn_doc = os.path.join(self.root_directory(), self.FN_DOCUMENT) - self._document = JSONDict(filename=fn_doc, write_concern=True) + self._document = BufferedJSONAttrDict(filename=fn_doc, write_concern=True) return self._document @document.setter @@ -536,7 +536,7 @@ def doc(self): Returns ------- - :class:`~signac.JSONDict` + :class:`~signac.synced_collections.backends.collection_json.BufferedJSONAttrDict` The project document. """ diff --git a/signac/core/attrdict.py b/signac/core/attrdict.py index 584bf337b..5e4d3f037 100644 --- a/signac/core/attrdict.py +++ b/signac/core/attrdict.py @@ -5,6 +5,10 @@ from .synceddict import _SyncedDict +""" +THIS MODULE IS DEPRECATED! +""" + class SyncedAttrDict(_SyncedDict): """A synced dictionary where (nested) values can be accessed as attributes. 
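The hunks above retire the legacy ``core`` JSON machinery in favor of the new ``synced_collections`` JSON backend and re-export the buffering helpers from ``signac/__init__.py`` (``buffered``, ``is_buffered``, ``get_buffer_load``, ``get_buffer_size``, ``set_buffer_size``). The snippet below is a minimal sketch of how that user-facing buffering API could be exercised after this change; it is not part of the patch, the project name and state point are hypothetical, and the exact values returned by the buffer helpers are assumptions.

```python
# Sketch of buffered writes through the aliased helpers added in the
# signac/__init__.py hunk above. Assumes the aliases remain callable like
# the functions they replace; names below are illustrative only.
import signac

project = signac.init_project("buffering-demo")  # hypothetical project name
job = project.open_job({"temperature": 300}).init()

with signac.buffered():
    assert signac.is_buffered()
    for step in range(10):
        job.doc[f"step_{step}"] = step  # staged in the in-memory buffer
    print("buffer load:", signac.get_buffer_load(),
          "capacity:", signac.get_buffer_size())
# Exiting the context flushes the buffered changes back to the JSON files.
```

Note that, per the changelog entry above, passing ``force_write`` to the buffering context now only raises a warning and is otherwise ignored.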
diff --git a/signac/core/json.py b/signac/core/json.py index f972a4324..7a8e8f0b4 100644 --- a/signac/core/json.py +++ b/signac/core/json.py @@ -7,6 +7,10 @@ from json.decoder import JSONDecodeError from typing import Any, Dict, Optional +from deprecation import deprecated + +from ..version import __version__ + logger = logging.getLogger(__name__) try: @@ -16,7 +20,12 @@ except ImportError: NUMPY = False +""" +THIS MODULE IS DEPRECATED! +""" + +# this class is deprecated class CustomJSONEncoder(JSONEncoder): """Attempt to JSON-encode objects beyond the default supported types. @@ -25,6 +34,12 @@ class CustomJSONEncoder(JSONEncoder): `_as_dict()` method. """ + @deprecated( + deprecated_in="1.7", + removed_in="2.0", + current_version=__version__, + details="The core.json module bundled with signac is deprecated.", + ) def default(self, o: Any) -> Dict[str, Any]: if NUMPY: if isinstance(o, numpy.number): @@ -39,6 +54,12 @@ def default(self, o: Any) -> Dict[str, Any]: return super().default(o) +@deprecated( + deprecated_in="1.7", + removed_in="2.0", + current_version=__version__, + details="The core.json module bundled with signac is deprecated.", +) def dumps(o: Any, sort_keys: bool = False, indent: Optional[int] = None) -> str: """Convert a JSON-compatible mapping into a string.""" return CustomJSONEncoder(sort_keys=sort_keys, indent=indent).encode(o) diff --git a/signac/core/jsondict.py b/signac/core/jsondict.py index c97bb59af..e63501b1e 100644 --- a/signac/core/jsondict.py +++ b/signac/core/jsondict.py @@ -13,6 +13,9 @@ from copy import copy from tempfile import mkstemp +from deprecation import deprecated + +from ..version import __version__ from . import json from .attrdict import SyncedAttrDict from .errors import Error @@ -29,6 +32,10 @@ _JSONDICT_HASHES = {} _JSONDICT_META = {} +""" +THIS MODULE IS DEPRECATED! +""" + class BufferException(Error): """An exception occurred in buffered mode.""" @@ -91,6 +98,11 @@ def _store_in_buffer(filename, blob, store_hash=False): return True +@deprecated( + deprecated_in="1.7", + removed_in="2.0", + current_version=__version__, +) def flush_all(): """Execute all deferred JSONDict write operations.""" global _BUFFER_LOAD @@ -127,21 +139,41 @@ def flush_all(): _BUFFER_LOAD = 0 +@deprecated( + deprecated_in="1.7", + removed_in="2.0", + current_version=__version__, +) def get_buffer_size(): """Return the current maximum size of the read/write buffer.""" return _BUFFER_SIZE +@deprecated( + deprecated_in="1.7", + removed_in="2.0", + current_version=__version__, +) def get_buffer_load(): """Return the current actual size of the read/write buffer.""" return _BUFFER_LOAD +@deprecated( + deprecated_in="1.7", + removed_in="2.0", + current_version=__version__, +) def in_buffered_mode(): """Return true if in buffered read/write mode.""" return _BUFFERED_MODE > 0 +@deprecated( + deprecated_in="1.7", + removed_in="2.0", + current_version=__version__, +) @contextmanager def buffer_reads_writes(buffer_size=DEFAULT_BUFFER_SIZE, force_write=False): """Enter a global buffer mode for all JSONDict instances. @@ -253,6 +285,11 @@ class JSONDict(SyncedAttrDict): A parent instance of JSONDict or None. 
""" + @deprecated( + deprecated_in="1.7", + removed_in="2.0", + current_version=__version__, + ) def __init__(self, filename=None, write_concern=False, parent=None): if (filename is None) == (parent is None): raise ValueError( diff --git a/signac/core/synceddict.py b/signac/core/synceddict.py index f9ae74c02..bc81638d9 100644 --- a/signac/core/synceddict.py +++ b/signac/core/synceddict.py @@ -8,6 +8,10 @@ from copy import deepcopy from functools import wraps +from deprecation import deprecated + +from ..version import __version__ + try: import numpy @@ -18,8 +22,17 @@ logger = logging.getLogger(__name__) +""" +THIS MODULE IS DEPRECATED! +""" + class _SyncedList(list): + @deprecated( + deprecated_in="1.7", + removed_in="2.0", + current_version=__version__, + ) def __init__(self, iterable, parent): self._parent = parent super().__init__(iterable) @@ -79,6 +92,11 @@ class _SyncedDict(MutableMapping): VALID_KEY_TYPES = (str, int, bool, type(None)) + @deprecated( + deprecated_in="1.7", + removed_in="2.0", + current_version=__version__, + ) def __init__(self, initialdata=None, parent=None): self._suspend_sync_ = 1 self._parent = parent diff --git a/signac/errors.py b/signac/errors.py index 1ccaa72dd..daae72861 100644 --- a/signac/errors.py +++ b/signac/errors.py @@ -17,6 +17,7 @@ ) from .core.errors import Error from .core.jsondict import BufferedFileError, BufferException +from .synced_collections.errors import InvalidKeyError, KeyTypeError class SyncConflict(Error, RuntimeError): @@ -58,14 +59,6 @@ def __str__(self): return "The synchronization failed, because of a schema conflict." -class InvalidKeyError(ValueError): - """Raised when a user uses a non-conforming key.""" - - -class KeyTypeError(TypeError): - """Raised when a user uses a key of invalid type.""" - - __all__ = [ "AuthenticationError", "BufferException", @@ -80,6 +73,7 @@ class KeyTypeError(TypeError): "IncompatibleSchemaVersion", "InvalidKeyError", "JobsCorruptedError", + "KeyTypeError", "SchemaSyncConflict", "StatepointParsingError", "SyncConflict", diff --git a/signac/synced_collections/__init__.py b/signac/synced_collections/__init__.py new file mode 100644 index 000000000..ffb6b8075 --- /dev/null +++ b/signac/synced_collections/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) 2020 The Regents of the University of Michigan +# All rights reserved. +# This software is licensed under the BSD 3-Clause License. +"""Define a framework for synchronized objects implementing the Collection interface. + +Synchronization of standard Python data structures with a persistent data store is +important for a number of applications. While tools like `h5py` and `zarr` offer +dict-like interfaces to underlying files, these APIs serve to provide a familiar +wrapper around access patterns specific to these backends. Moreover, these formats +are primarily geared towards the provision of high-performance storage for large +array-like data. Storage of simpler data types, while possible, is generally +more difficult and requires additional work from the user. + +Synced collections fills this gap, introducing a new abstract base class that extends +`collections.abc.Collection` to add transparent synchronization protocols. The package +implements its own versions of standard data structures like dicts and lists, and +it offers support for storing these data structures into various data formats. 
The +synchronization mechanism is completely transparent to the user; for example, a +`JSONDict` initialized pointing to a particular file can be modified like a normal +dict, and all changes will be automatically persisted to a JSON file. +""" + +from .data_types import SyncedCollection, SyncedDict, SyncedList + +__all__ = ["SyncedCollection", "SyncedDict", "SyncedList"] diff --git a/signac/synced_collections/_caching.py b/signac/synced_collections/_caching.py new file mode 100644 index 000000000..1db0b0c41 --- /dev/null +++ b/signac/synced_collections/_caching.py @@ -0,0 +1,80 @@ +# Copyright (c) 2020 The Regents of the University of Michigan +# All rights reserved. +# This software is licensed under the BSD 3-Clause License. +"""Implement the caching feature to SyncedCollection API.""" +import logging +import pickle +import uuid +from collections.abc import MutableMapping + +logger = logging.getLogger(__name__) + + +def get_cache(): + """Return the cache. + + This method returns an instance of :class:`~RedisCache` if a Redis server + is available, or otherwise an instance of :class:`dict` for an in-memory + cache. + + Returns + ------- + cache + An instance of :class:`~_RedisCache` if redis-server is available, + otherwise a dict. + + """ + try: + import redis + + REDIS = True + except ImportError as error: + logger.debug(str(error)) + REDIS = False + if REDIS: + try: + # try to connect to server + cache = redis.Redis() + test_key = str(uuid.uuid4()) + cache.set(test_key, 0) + assert cache.get(test_key) == b"0" # Redis stores data as bytes + cache.delete(test_key) + logger.info("Using Redis cache.") + return _RedisCache(cache) + except (redis.exceptions.ConnectionError, AssertionError) as error: + logger.debug(str(error)) + logger.info("Redis not available.") + return {} + + +class _RedisCache(MutableMapping): + """Redis-based cache. + + Redis restricts the types of data it can handle to bytes, strings, or + numbers, and it always returns responses as bytes. The RedisCache is a + :class:`~collections.abc.MutableMapping` that provides a convenient wrapper + around instances of :class:`redis.Redis`, handling conversions to and from + the appropriate data types. + """ + + def __init__(self, client): + self._client = client + + def __setitem__(self, key, value): + self._client[key] = pickle.dumps(value) + + def __getitem__(self, key): + return pickle.loads(self._client[key]) + + def __delitem__(self, key): + self._client.delete(key) + + def __contains__(self, key): + return key in self._client + + def __iter__(self): + for key in self._client.keys(): + yield key.decode() + + def __len__(self): + return len(self._client.keys()) diff --git a/signac/synced_collections/backends/__init__.py b/signac/synced_collections/backends/__init__.py new file mode 100644 index 000000000..93336d695 --- /dev/null +++ b/signac/synced_collections/backends/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) 2020 The Regents of the University of Michigan +# All rights reserved. +# This software is licensed under the BSD 3-Clause License. +"""This subpackage defines supported backends. + +No backends are imported by default. Users should import desired backends as +needed. 
+""" +from typing import List as _List + +__all__: _List[str] = [] diff --git a/signac/synced_collections/backends/collection_json.py b/signac/synced_collections/backends/collection_json.py new file mode 100644 index 000000000..ce4caa65b --- /dev/null +++ b/signac/synced_collections/backends/collection_json.py @@ -0,0 +1,671 @@ +# Copyright (c) 2020 The Regents of the University of Michigan +# All rights reserved. +# This software is licensed under the BSD 3-Clause License. +"""Implements a JSON :class:`~.SyncedCollection` backend.""" + +import errno +import json +import os +import uuid +import warnings +from collections.abc import Mapping, Sequence +from typing import Callable, FrozenSet +from typing import Sequence as Sequence_t + +from .. import SyncedCollection, SyncedDict, SyncedList +from ..buffers.memory_buffered_collection import SharedMemoryFileBufferedCollection +from ..buffers.serialized_file_buffered_collection import ( + SerializedFileBufferedCollection, +) +from ..data_types.attr_dict import AttrDict +from ..errors import InvalidKeyError, KeyTypeError +from ..numpy_utils import ( + _is_atleast_1d_numpy_array, + _is_complex, + _is_numpy_scalar, + _numpy_cache_blocklist, +) +from ..utils import AbstractTypeResolver, SyncedCollectionJSONEncoder +from ..validators import json_format_validator, no_dot_in_key + +""" +There are many classes defined in this file. Most of the definitions are +trivial since logic is largely inherited, but the large number of classes +and the extensive docstrings can be intimidating and make the source +difficult to parse. Section headers like these are used to organize the +code to reduce this barrier. +""" + + +# TODO: This method should be removed in signac 2.0. +def _str_key(key): + VALID_KEY_TYPES = (str, int, bool, type(None)) + + if not isinstance(key, VALID_KEY_TYPES): + raise KeyTypeError( + f"Mapping keys must be str, int, bool or None, not {type(key).__name__}" + ) + elif not isinstance(key, str): + warnings.warn( + f"Use of {type(key).__name__} as key is deprecated " + "and will be removed in version 2.0", + DeprecationWarning, + ) + key = str(key) + return key + + +# TODO: This method should be removed in signac 2.0. +def _convert_key_to_str(data): + """Recursively convert non-string keys to strings in dicts. + + This method supports :class:`collections.abc.Sequence` or + :class:`collections.abc.Mapping` types as inputs, and recursively + searches for any entries in :class:`collections.abc.Mapping` types where + the key is not a string. This functionality is added for backwards + compatibility with legacy behavior in signac, which allowed integer keys + for dicts. These inputs were silently converted to string keys and stored + since JSON does not support integer keys. This behavior is deprecated and + will become an error in signac 2.0. + + Note for developers: this method is designed for use as a validator in the + synced collections framework, but due to the backwards compatibility requirement + it violates the general behavior of validators by modifying the data in place. + This behavior can be removed in signac 2.0 once non-str keys become an error. + """ + if isinstance(data, dict): + # Explicitly call `list(keys)` to get a fixed list of keys to avoid + # running into issues with iterating over a DictKeys view while + # modifying the dict at the same time. 
+ for key in list(data): + _convert_key_to_str(data[key]) + data[_str_key(key)] = data.pop(key) + elif isinstance(data, list): + for i, value in enumerate(data): + _convert_key_to_str(value) + + +_json_attr_dict_validator_type_resolver = AbstractTypeResolver( + { + # We identify >0d numpy arrays as sequences for validation purposes. + "SEQUENCE": lambda obj: (isinstance(obj, Sequence) and not isinstance(obj, str)) + or _is_atleast_1d_numpy_array(obj), + "NUMPY": lambda obj: _is_numpy_scalar(obj), + "BASE": lambda obj: isinstance(obj, (str, int, float, bool, type(None))), + "MAPPING": lambda obj: isinstance(obj, Mapping), + }, + cache_blocklist=_numpy_cache_blocklist, +) + + +def json_attr_dict_validator(data): + """Validate data for JSONAttrDict. + + This validator combines the logic from the following validators into one to + make validation more efficient: + + This validator combines the following logic: + - JSON format validation + - Ensuring no dots are present in string keys + - Converting non-str keys to strings. This is a backwards compatibility + layer that will be removed in signac 2.0. + + Parameters + ---------- + data + Data to validate. + + Raises + ------ + KeyTypeError + If key data type is not supported. + TypeError + If the data type of ``data`` is not supported. + + """ + switch_type = _json_attr_dict_validator_type_resolver.get_type(data) + + if switch_type == "BASE": + return + elif switch_type == "MAPPING": + # Explicitly call `list(keys)` to get a fixed list of keys to avoid + # running into issues with iterating over a DictKeys view while + # modifying the dict at the same time. Inside the loop, we: + # 1) validate the key, converting to string if necessary + # 2) pop and validate the value + # 3) reassign the value to the (possibly converted) key + for key in list(data): + json_attr_dict_validator(data[key]) + if isinstance(key, str): + if "." in key: + raise InvalidKeyError( + f"Mapping keys may not contain dots ('.'): {key}." + ) + elif isinstance(key, (int, bool, type(None))): + # TODO: Remove this branch in signac 2.0. + warnings.warn( + f"Use of {type(key).__name__} as key is deprecated " + "and will be removed in version 2.0.", + DeprecationWarning, + ) + data[str(key)] = data.pop(key) + else: + raise KeyTypeError( + f"Mapping keys must be str, int, bool or None, not {type(key).__name__}." + ) + elif switch_type == "SEQUENCE": + for value in data: + json_attr_dict_validator(value) + elif switch_type == "NUMPY": + if _is_numpy_scalar(data.item()): + raise TypeError("NumPy extended precision types are not JSON serializable.") + elif _is_complex(data): + raise TypeError("Complex numbers are not JSON serializable.") + else: + raise TypeError( + f"Object of type {type(data).__name__} is not JSON serializable." + ) + + +""" +Here we define the main JSONCollection class that encapsulates most of the +logic for reading from and writing to JSON files. The remaining classes in +this file inherit from these classes to add features like buffering or +attribute-based dictionary access, each with a different backend name for +correct resolution of nested SyncedCollection types. +""" + + +class JSONCollection(SyncedCollection): + r"""A :class:`~.SyncedCollection` that synchronizes with a JSON file. + + This collection implements synchronization by reading and writing the associated + JSON file in its entirety for every read/write operation. 
This backend is a good + choice for maximum accessibility and transparency since all data is immediately + accessible in the form of a text file with no additional tooling, but is + likely a poor choice for high performance applications. + + **Thread safety** + + The :class:`JSONCollection` is thread-safe. To make these collections safe, the + ``write_concern`` flag is ignored in multithreaded execution, and the + write is **always** performed via a write to temporary file followed by a + replacement of the original file. The file replacement operation uses + :func:`os.replace`, which is guaranteed to be atomic by the Python standard. + + Parameters + ---------- + filename : str + The filename of the associated JSON file on disk. + write_concern : bool, optional + Ensure file consistency by writing changes back to a temporary file + first, before replacing the original file (Default value = False). + \*args : + Positional arguments forwarded to parent constructors. + \*\*kwargs : + Keyword arguments forwarded to parent constructors. + + """ + + _backend = __name__ # type: ignore + _supports_threading = True + + # The order in which these validators are added is important, because + # validators are called in sequence and _convert_key_to_str will ensure that + # valid non-str keys are converted to strings before json_format_validator is + # called. This ordering is an implementation detail that we should not rely on + # in the future, however, the _convert_key_to_str validator will be removed in + # signac 2.0 so this is OK (that validator is modifying the data in place, + # which is unsupported behavior that will be removed in signac 2.0 as well). + _validators: Sequence_t[Callable] = (_convert_key_to_str, json_format_validator) + + def __init__(self, filename=None, write_concern=False, *args, **kwargs): + # The `_filename` attribute _must_ be defined prior to calling the + # superclass constructors because the filename defines the `_lock_id` + # used to uniquely identify thread locks for this collection. + self._filename = filename + super().__init__(*args, **kwargs) + self._write_concern = write_concern + + def _load_from_resource(self): + """Load the data from a JSON file. + + Returns + ------- + Collection or None + An equivalent unsynced collection satisfying :meth:`is_base_type` that + contains the data in the JSON file. Will return None if the file does + not exist. + + """ + try: + with open(self._filename, "rb") as file: + blob = file.read() + return json.loads(blob) + except OSError as error: + if error.errno == errno.ENOENT: + return None + else: + raise + + def _save_to_resource(self): + """Write the data to JSON file.""" + # Serialize data + blob = json.dumps(self, cls=SyncedCollectionJSONEncoder).encode() + # When write_concern flag is set, we write the data into dummy file and then + # replace that file with original file. We also enable this mode + # irrespective of the write_concern flag if we're running in + # multithreaded mode. 
+ if self._write_concern or type(self)._threading_support_is_active: + dirname, filename = os.path.split(self._filename) + fn_tmp = os.path.join(dirname, f"._{uuid.uuid4()}_{filename}") + with open(fn_tmp, "wb") as tmpfile: + tmpfile.write(blob) + os.replace(fn_tmp, self._filename) + else: + with open(self._filename, "wb") as file: + file.write(blob) + + @property + def filename(self): + """str: The name of the associated JSON file on disk.""" + return self._filename + + @property + def _lock_id(self): + return self._filename + + +# These are the common protected keys used by all JSONDict types. +_JSONDICT_PROTECTED_KEYS = frozenset( + ( + # These are all protected keys that are inherited from data type classes. + "_data", + "_name", + "_suspend_sync_", + "_load", + "_sync", + "_root", + "_validators", + "_all_validators", + "_load_and_save", + "_suspend_sync", + "_supports_threading", + "_LoadSaveType", + "registry", + # These keys are specific to the JSON backend. + "_filename", + "_write_concern", + ) +) + + +class JSONDict(JSONCollection, SyncedDict): + r"""A dict-like data structure that synchronizes with a persistent JSON file. + + Examples + -------- + >>> doc = JSONDict('data.json', write_concern=True) + >>> doc['foo'] = "bar" + >>> assert doc['foo'] == "bar" + >>> assert 'foo' in doc + >>> del doc['foo'] + >>> doc['foo'] = dict(bar=True) + >>> doc + {'foo': {'bar': True}} + + Parameters + ---------- + filename : str, optional + The filename of the associated JSON file on disk (Default value = None). + write_concern : bool, optional + Ensure file consistency by writing changes back to a temporary file + first, before replacing the original file (Default value = False). + data : :class:`collections.abc.Mapping`, optional + The initial data passed to :class:`JSONDict`. If ``None``, defaults to + ``{}`` (Default value = None). + parent : JSONCollection, optional + A parent instance of :class:`JSONCollection` or ``None``. If ``None``, + the collection owns its own data (Default value = None). + \*args : + Positional arguments forwarded to parent constructors. + \*\*kwargs : + Keyword arguments forwarded to parent constructors. + + Warnings + -------- + While the :class:`JSONDict` object behaves like a :class:`dict`, there are important + distinctions to remember. In particular, because operations are reflected + as changes to an underlying file, copying (even deep copying) a :class:`JSONDict` + instance may exhibit unexpected behavior. If a true copy is required, you + should use the call operator to get a dictionary representation, and if + necessary construct a new :class:`JSONDict` instance. + + """ + + _PROTECTED_KEYS: FrozenSet[str] = _JSONDICT_PROTECTED_KEYS + + def __init__( + self, + filename=None, + write_concern=False, + data=None, + parent=None, + *args, + **kwargs, + ): + super().__init__( + filename=filename, + write_concern=write_concern, + data=data, + parent=parent, + *args, + **kwargs, + ) + + +class JSONList(JSONCollection, SyncedList): + r"""A list-like data structure that synchronizes with a persistent JSON file. + + Only non-string sequences are supported by this class. + + Examples + -------- + >>> synced_list = JSONList('data.json', write_concern=True) + >>> synced_list.append("bar") + >>> assert synced_list[0] == "bar" + >>> assert len(synced_list) == 1 + >>> del synced_list[0] + + Parameters + ---------- + filename : str, optional + The filename of the associated JSON file on disk (Default value = None). 
+ write_concern : bool, optional + Ensure file consistency by writing changes back to a temporary file + first, before replacing the original file (Default value = None). + data : non-str :class:`collections.abc.Sequence`, optional + The initial data passed to :class:`JSONList `. If ``None``, defaults to + ``[]`` (Default value = None). + parent : JSONCollection, optional + A parent instance of :class:`JSONCollection` or ``None``. If ``None``, + the collection owns its own data (Default value = None). + \*args : + Positional arguments forwarded to parent constructors. + \*\*kwargs : + Keyword arguments forwarded to parent constructors. + + Warnings + -------- + While the :class:`JSONList` object behaves like a :class:`list`, there are important + distinctions to remember. In particular, because operations are reflected + as changes to an underlying file, copying (even deep copying) a :class:`JSONList` + instance may exhibit unexpected behavior. If a true copy is required, you + should use the call operator to get a dictionary representation, and if + necessary construct a new :class:`JSONList` instance. + + """ + + def __init__( + self, + filename=None, + write_concern=False, + data=None, + parent=None, + *args, + **kwargs, + ): + super().__init__( + filename=filename, + write_concern=write_concern, + data=data, + parent=parent, + *args, + **kwargs, + ) + + +""" +Here we define the BufferedJSONCollection class and its data type +subclasses, which augment the JSONCollection with a serialized in-memory +buffer for improved performance. +""" + + +class BufferedJSONCollection(SerializedFileBufferedCollection, JSONCollection): + """A :class:`JSONCollection` that supports I/O buffering. + + This class implements the buffer protocol defined by + :class:`~.BufferedCollection`. The concrete implementation of buffering + behavior is defined by the :class:`~.SerializedFileBufferedCollection`. + """ + + _backend = __name__ + ".buffered" # type: ignore + + +# These are the keys common to buffer backends. +_BUFFERED_PROTECTED_KEYS = frozenset( + ( + "buffered", + "_is_buffered", + "_buffer_lock", + "_buffer_context", + "_buffered_collections", + ) +) + + +class BufferedJSONDict(BufferedJSONCollection, SyncedDict): + """A buffered :class:`JSONDict`.""" + + _PROTECTED_KEYS: FrozenSet[str] = ( + _JSONDICT_PROTECTED_KEYS | _BUFFERED_PROTECTED_KEYS + ) + + def __init__( + self, + filename=None, + write_concern=False, + data=None, + parent=None, + *args, + **kwargs, + ): + super().__init__( + filename=filename, + write_concern=write_concern, + data=data, + parent=parent, + *args, + **kwargs, + ) + + +class BufferedJSONList(BufferedJSONCollection, SyncedList): + """A buffered :class:`JSONList`.""" + + def __init__( + self, + filename=None, + write_concern=False, + data=None, + parent=None, + *args, + **kwargs, + ): + super().__init__( + filename=filename, + write_concern=write_concern, + data=data, + parent=parent, + *args, + **kwargs, + ) + + +""" +Here we define the MemoryBufferedJSONCollection class and its data type +subclasses, which augment the JSONCollection with a serialized in-memory +buffer for improved performance. +""" + + +class MemoryBufferedJSONCollection(SharedMemoryFileBufferedCollection, JSONCollection): + """A :class:`JSONCollection` that supports I/O buffering. + + This class implements the buffer protocol defined by :class:`~.BufferedCollection`. + The concrete implementation of buffering behavior is defined by the + :class:`~.SharedMemoryFileBufferedCollection`. 
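The buffered JSON types are used identically; a minimal sketch using ``BufferedJSONDict`` (the filename and keys are illustrative, and a writable working directory is assumed):

.. code-block:: python

    from signac.synced_collections.backends.collection_json import BufferedJSONDict

    doc = BufferedJSONDict("data.json")
    doc["foo"] = "bar"  # unbuffered: written straight to data.json

    with doc.buffered:  # defer I/O for this collection only
        for i in range(100):
            doc[str(i)] = i  # staged in the buffer
    # The buffer is flushed back to data.json when the context exits.

    with BufferedJSONDict.buffer_backend():
        doc["baz"] = 0  # buffered for every collection of this backend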
+ """ + + _backend = __name__ + ".memory_buffered" # type: ignore + + +class MemoryBufferedJSONDict(MemoryBufferedJSONCollection, SyncedDict): + """A buffered :class:`JSONDict`.""" + + _PROTECTED_KEYS: FrozenSet[str] = ( + _JSONDICT_PROTECTED_KEYS | _BUFFERED_PROTECTED_KEYS + ) + + def __init__( + self, + filename=None, + write_concern=False, + data=None, + parent=None, + *args, + **kwargs, + ): + super().__init__( + filename=filename, + write_concern=write_concern, + data=data, + parent=parent, + *args, + **kwargs, + ) + + +class MemoryBufferedJSONList(MemoryBufferedJSONCollection, SyncedList): + """A buffered :class:`JSONList`.""" + + def __init__( + self, + filename=None, + write_concern=False, + data=None, + parent=None, + *args, + **kwargs, + ): + super().__init__( + filename=filename, + write_concern=write_concern, + data=data, + parent=parent, + *args, + **kwargs, + ) + + +""" +Here we define various extensions of the above classes that add +attribute-based access to dictionaries. Although list behavior is not +modified in any way by these, they still require separate classes with the +right backend so that nested classes are created appropriately. +""" + + +class JSONAttrDict(JSONDict, AttrDict): + r"""A dict-like data structure that synchronizes with a persistent JSON file. + + Unlike :class:`JSONAttrDict`, this class also supports attribute-based access to + dictionary contents, e.g. ``doc.foo == doc['foo']``. + + Examples + -------- + >>> doc = JSONAttrDict('data.json', write_concern=True) + >>> doc['foo'] = "bar" + >>> assert doc.foo == doc['foo'] == "bar" + >>> assert 'foo' in doc + >>> del doc['foo'] + >>> doc['foo'] = dict(bar=True) + >>> doc + {'foo': {'bar': True}} + >>> doc.foo.bar = False + >>> doc + {'foo': {'bar': False}} + + Parameters + ---------- + filename : str, optional + The filename of the associated JSON file on disk (Default value = None). + write_concern : bool, optional + Ensure file consistency by writing changes back to a temporary file + first, before replacing the original file (Default value = False). + data : :class:`collections.abc.Mapping`, optional + The initial data passed to :class:`JSONAttrDict`. If ``None``, defaults to + ``{}`` (Default value = None). + parent : JSONCollection, optional + A parent instance of :class:`JSONCollection` or ``None``. If ``None``, + the collection owns its own data (Default value = None). + \*args : + Positional arguments forwarded to parent constructors. + \*\*kwargs : + Keyword arguments forwarded to parent constructors. + + Warnings + -------- + While the :class:`JSONAttrDict` object behaves like a :class:`dict`, there are important + distinctions to remember. In particular, because operations are reflected + as changes to an underlying file, copying (even deep copying) a :class:`JSONAttrDict` + instance may exhibit unexpected behavior. If a true copy is required, you + should use the call operator to get a dictionary representation, and if + necessary construct a new :class:`JSONAttrDict` instance. + + """ + + _backend = __name__ + ".attr" # type: ignore + # Define the validators in case subclasses want to inherit the correct + # behavior, but define _all_validators for performance of this class. 
+ _validators = (no_dot_in_key,) + _all_validators = (json_attr_dict_validator,) + + +class JSONAttrList(JSONList): + """A :class:`JSONList` whose dict-like children will be of type :class:`JSONAttrDict`.""" + + _backend = __name__ + ".attr" # type: ignore + + +class BufferedJSONAttrDict(BufferedJSONDict, AttrDict): + """A buffered :class:`JSONAttrDict`.""" + + _backend = __name__ + ".buffered_attr" # type: ignore + # Define the validators in case subclasses want to inherit the correct + # behavior, but define _all_validators for performance of this class. + _validators = (no_dot_in_key,) + _all_validators = (json_attr_dict_validator,) + + +class BufferedJSONAttrList(BufferedJSONList): + """A :class:`BufferedJSONList` whose dict-like children will be of type :class:`BufferedJSONAttrDict`.""" # noqa: E501 + + _backend = __name__ + ".buffered_attr" # type: ignore + + +class MemoryBufferedJSONAttrDict(MemoryBufferedJSONDict, AttrDict): + """A buffered :class:`JSONAttrDict`.""" + + _backend = __name__ + ".memory_buffered_attr" # type: ignore + # Define the validators in case subclasses want to inherit the correct + # behavior, but define _all_validators for performance of this class. + _validators = (no_dot_in_key,) + _all_validators = (json_attr_dict_validator,) + + +class MemoryBufferedJSONAttrList(MemoryBufferedJSONList): + """A :class:`MemoryBufferedJSONList` whose dict-like children will be of type :class:`MemoryBufferedJSONAttrDict`.""" # noqa: E501 + + _backend = __name__ + ".memory_buffered_attr" # type: ignore diff --git a/signac/synced_collections/backends/collection_mongodb.py b/signac/synced_collections/backends/collection_mongodb.py new file mode 100644 index 000000000..78e5c7536 --- /dev/null +++ b/signac/synced_collections/backends/collection_mongodb.py @@ -0,0 +1,215 @@ +# Copyright (c) 2020 The Regents of the University of Michigan +# All rights reserved. +# This software is licensed under the BSD 3-Clause License. +"""Implements a MongoDB :class:`~.SyncedCollection` backend.""" +from .. import SyncedCollection, SyncedDict, SyncedList +from ..validators import json_format_validator, require_string_key + +try: + import bson + + MONGO = True +except ImportError: + MONGO = False + + +class MongoDBCollection(SyncedCollection): + r"""A :class:`~.SyncedCollection` that synchronizes with a MongoDB document. + + In MongoDB, a database is composed of multiple MongoDB **collections**, which + are analogous to tables in SQL databases but do not enforce a schema like + in relational databases. In turn, collections are composed of **documents**, + which are analogous to rows in a table but are much more flexible, storing + any valid JSON object in a JSON-like encoded format known as BSON + ("binary JSON"). + + Each :class:`~.MongoDBCollection` can be represented as a MongoDB document, + so this backend stores the :class:`~.MongoDBCollection` as a single + document within the collection provided by the user. The document is + identified by a unique key provided by the user. + + **Thread safety** + + The :class:`MongoDBCollection` is not thread-safe. + + Parameters + ---------- + collection : :class:`pymongo.collection.Collection` + The MongoDB client in which to store data. + uid : dict + The unique key-value mapping added to the data and stored in the document + so that it is uniquely identifiable in the MongoDB collection. The key + "data" is reserved and may not be part of this uid. + \*args : + Positional arguments forwarded to parent constructors. 
+ \*\*kwargs : + Keyword arguments forwarded to parent constructors. + + Warnings + -------- + The user is responsible for providing a unique id such that there are no + possible collisions between different :class:`~.MongoDBCollection` instances + stored in the same MongoDB collection. Failure to do so may result in data + corruption if multiple documents are found to be apparently associated with + a given ``uid``. + + """ + + _backend = __name__ # type: ignore + + # MongoDB uses BSON, which is not exactly JSON but is close enough that + # JSON-validation is reasonably appropriate. we could generalize this to do + # proper BSON validation if we find that the discrepancies (for instance, the + # supported integer data types differ) are too severe. + _validators = (json_format_validator,) + + def __init__(self, collection=None, uid=None, parent=None, *args, **kwargs): + super().__init__(parent=parent, **kwargs) + if not MONGO: + raise RuntimeError( + "The PyMongo package must be installed to use the MongoDBCollection." + ) + + self._collection = collection + if uid is not None and "data" in uid: + raise ValueError("The key 'data' may not be part of the uid.") + self._uid = uid + + def _load_from_resource(self): + """Load the data from a MongoDB document. + + Returns + ------- + Collection or None + An equivalent unsynced collection satisfying :meth:`~.is_base_type` that + contains the data in the MongoDB database. Will return None if no data + was found in the database. + + """ + blob = self._collection.find_one(self._uid) + return blob["data"] if blob is not None else None + + def _save_to_resource(self): + """Write the data to a MongoDB document.""" + data = self._to_base() + data_to_insert = {**self._uid, "data": data} + try: + self._collection.replace_one(self._uid, data_to_insert, True) + except bson.errors.InvalidDocument as err: + raise TypeError(str(err)) + + @property + def collection(self): + """pymongo.collection.Collection: Get the collection being synced to.""" + return self._collection + + @property + def uid(self): # noqa: D401 + """dict: Get the unique mapping used to identify this collection.""" + return self._uid + + def __deepcopy__(self, memo): + # The underlying MongoDB collection cannot be deepcopied. + raise TypeError("MongoDBCollection does not support deepcopying.") + + +class MongoDBDict(MongoDBCollection, SyncedDict): + r"""A dict-like data structure that synchronizes with a document in a MongoDB collection. + + Examples + -------- + >>> doc = MongoDBDict('data') + >>> doc['foo'] = "bar" + >>> assert doc['foo'] == "bar" + >>> assert 'foo' in doc + >>> del doc['foo'] + >>> doc['foo'] = dict(bar=True) + >>> doc + {'foo': {'bar': True}} + + Parameters + ---------- + collection : pymongo.collection.Collection, optional + A :class:`pymongo.collection.Collection` instance (Default value = None). + uid : dict, optional + The unique key-value mapping identifying the collection (Default value = None). + data : non-str :class:`collections.abc.Mapping`, optional + The initial data passed to :class:`MongoDBDict`. If ``None``, defaults to + ``{}`` (Default value = None). + parent : MongoDBCollection, optional + A parent instance of :class:`MongoDBCollection` or ``None``. If ``None``, + the collection owns its own data (Default value = None). + \*args : + Positional arguments forwarded to parent constructors. + \*\*kwargs : + Keyword arguments forwarded to parent constructors. 
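A minimal connection sketch (it assumes a MongoDB server reachable on localhost and uses illustrative database, collection, and uid names):

.. code-block:: python

    from pymongo import MongoClient

    from signac.synced_collections.backends.collection_mongodb import MongoDBDict

    mongo_collection = MongoClient()["my_database"]["my_collection"]
    doc = MongoDBDict(collection=mongo_collection, uid={"name": "experiment_1"})
    doc["temperature"] = 300  # synchronized to the document identified by the uid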
+ + Warnings + -------- + While the :class:`MongoDBDict` object behaves like a :class:`dict`, there are + important distinctions to remember. In particular, because operations are + reflected as changes to an underlying database, copying a + :class:`MongoDBDict` instance may exhibit unexpected behavior. If a true + copy is required, you should use the call operator to get a dictionary + representation, and if necessary construct a new :class:`MongoDBDict` + instance. + + """ + + _validators = (require_string_key,) + + def __init__( + self, collection=None, uid=None, data=None, parent=None, *args, **kwargs + ): + super().__init__( + collection=collection, uid=uid, data=data, parent=parent, *args, **kwargs + ) + + +class MongoDBList(MongoDBCollection, SyncedList): + r"""A list-like data structure that synchronizes with a document in a MongoDB collection. + + Only non-string sequences are supported by this class. + + Examples + -------- + >>> synced_list = MongoDBList('data') + >>> synced_list.append("bar") + >>> assert synced_list[0] == "bar" + >>> assert len(synced_list) == 1 + >>> del synced_list[0] + + Parameters + ---------- + collection : pymongo.collection.Collection, optional + A :class:`pymongo.collection.Collection` instance (Default value = None). + uid : dict, optional + The unique key-value mapping identifying the collection (Default value = None). + data : non-str :class:`collections.abc.Sequence`, optional + The initial data passed to :class:`MongoDBList`. If ``None``, defaults to + ``[]`` (Default value = None). + parent : MongoDBCollection, optional + A parent instance of :class:`MongoDBCollection` or ``None``. If ``None``, + the collection owns its own data (Default value = None). + \*args : + Positional arguments forwarded to parent constructors. + \*\*kwargs : + Keyword arguments forwarded to parent constructors. + + Warnings + -------- + While the :class:`MongoDBList` object behaves like a :class:`list`, there are important + distinctions to remember. In particular, because operations are reflected + as changes to an underlying database, copying a :class:`MongoDBList` instance may + exhibit unexpected behavior. If a true copy is required, you should use the + call operator to get a dictionary representation, and if necessary + construct a new :class:`MongoDBList` instance. + + """ + + def __init__( + self, collection=None, uid=None, data=None, parent=None, *args, **kwargs + ): + super().__init__( + collection=collection, uid=uid, data=data, parent=parent, *args, **kwargs + ) diff --git a/signac/synced_collections/backends/collection_redis.py b/signac/synced_collections/backends/collection_redis.py new file mode 100644 index 000000000..fe43ad20a --- /dev/null +++ b/signac/synced_collections/backends/collection_redis.py @@ -0,0 +1,174 @@ +# Copyright (c) 2020 The Regents of the University of Michigan +# All rights reserved. +# This software is licensed under the BSD 3-Clause License. +"""Implements a Redis :class:`~.SyncedCollection` backend.""" +import json + +from .. import SyncedCollection, SyncedDict, SyncedList +from ..utils import SyncedCollectionJSONEncoder +from ..validators import json_format_validator, require_string_key + + +class RedisCollection(SyncedCollection): + r"""A :class:`~.SyncedCollection` that synchronizes with a Redis database. + + This backend stores data in Redis by associating it with the provided key. + + **Thread safety** + + The :class:`RedisCollection` is not thread-safe. 
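A minimal connection sketch (it assumes a Redis server reachable on localhost and uses an illustrative key):

.. code-block:: python

    import redis

    from signac.synced_collections.backends.collection_redis import RedisDict

    client = redis.Redis()
    doc = RedisDict(client=client, key="experiment_1")
    doc["temperature"] = 300  # stored in Redis under the key "experiment_1"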
+ + Parameters + ---------- + client : redis.Redis + The Redis client used to persist data. + key : str + The key associated with this collection in the Redis database. + \*args : + Positional arguments forwarded to parent constructors. + \*\*kwargs : + Keyword arguments forwarded to parent constructors. + + """ + + _backend = __name__ # type: ignore + + # Redis collection relies on JSON-serialization for the data. + _validators = (json_format_validator,) + + def __init__(self, client=None, key=None, *args, **kwargs): + super().__init__(**kwargs) + self._client = client + self._key = key + + def _load_from_resource(self): + """Load the data from a Redis database. + + Returns + ------- + Collection or None + An equivalent unsynced collection satisfying :meth:`~.is_base_type` that + contains the data in the Redis database. Will return None if no data + was found in the Redis database. + + """ + blob = self._client.get(self._key) + return None if blob is None else json.loads(blob) + + def _save_to_resource(self): + """Write the data to a Redis database.""" + self._client.set( + self._key, json.dumps(self, cls=SyncedCollectionJSONEncoder).encode() + ) + + @property + def client(self): + """`redis.Redis`: The Redis client used to store the data.""" + return self._client + + @property + def key(self): + """str: The key associated with this collection stored in Redis.""" + return self._key + + def __deepcopy__(self, memo): + # The underlying Redis client cannot be deepcopied. + raise TypeError("RedisCollection does not support deepcopying.") + + +class RedisDict(RedisCollection, SyncedDict): + r"""A dict-like data structure that synchronizes with a persistent Redis database. + + Examples + -------- + >>> doc = RedisDict('data') + >>> doc['foo'] = "bar" + >>> assert doc['foo'] == "bar" + >>> assert 'foo' in doc + >>> del doc['foo'] + >>> doc['foo'] = dict(bar=True) + >>> doc + {'foo': {'bar': True}} + + Parameters + ---------- + client : redis.Redis, optional + A redis client (Default value = None). + key : str, optional + The key of the collection (Default value = None). + data : :class:`collections.abc.Mapping`, optional + The initial data passed to :class:`RedisDict`. If ``None``, defaults to + ``{}`` (Default value = None). + parent : RedisCollection, optional + A parent instance of :class:`RedisCollection` or ``None``. If ``None``, + the collection owns its own data (Default value = None). + \*args : + Positional arguments forwarded to parent constructors. + \*\*kwargs : + Keyword arguments forwarded to parent constructors. + + Warnings + -------- + While the :class:`RedisDict` object behaves like a :class:`dict`, there are important + distinctions to remember. In particular, because operations are reflected + as changes to an underlying database, copying a :class:`RedisDict` instance may + exhibit unexpected behavior. If a true copy is required, you should use the + call operator to get a dictionary representation, and if necessary + construct a new :class:`RedisDict` instance. + + """ + + _validators = (require_string_key,) + + def __init__(self, client=None, key=None, data=None, parent=None, *args, **kwargs): + super().__init__( + client=client, key=key, data=data, parent=parent, *args, **kwargs + ) + + +class RedisList(RedisCollection, SyncedList): + r"""A list-like data structure that synchronizes with a persistent Redis database. + + Only non-string sequences are supported by this class. 
+ + Examples + -------- + >>> synced_list = RedisList('data') + >>> synced_list.append("bar") + >>> assert synced_list[0] == "bar" + >>> assert len(synced_list) == 1 + >>> del synced_list[0] + + + Parameters + ---------- + client : redis.Redis, optional + A Redis client (Default value = None). + key : str, optional + The key of the collection (Default value = None). + data : non-str :class:`collections.abc.Sequence`, optional + The initial data passed to :class:`RedisList`. If ``None``, defaults to + ``[]`` (Default value = None). + parent : RedisCollection, optional + A parent instance of :class:`RedisCollection` or ``None``. If ``None``, + the collection owns its own data (Default value = None). + \*args : + Positional arguments forwarded to parent constructors. + \*\*kwargs : + Keyword arguments forwarded to parent constructors. + + Warnings + -------- + While the :class:`RedisList` object behaves like a :class:`list`, there are + important distinctions to remember. In particular, because operations are + reflected as changes to an underlying database, copying a + :class:`RedisList` instance may exhibit unexpected behavior. If a true copy + is required, you should use the call operator to get a dictionary + representation, and if necessary construct a new :class:`RedisList` + instance. + """ + + def __init__(self, client=None, key=None, data=None, parent=None, *args, **kwargs): + super().__init__( + client=client, key=key, data=data, parent=parent, *args, **kwargs + ) diff --git a/signac/synced_collections/backends/collection_zarr.py b/signac/synced_collections/backends/collection_zarr.py new file mode 100644 index 000000000..53b2f3edc --- /dev/null +++ b/signac/synced_collections/backends/collection_zarr.py @@ -0,0 +1,217 @@ +# Copyright (c) 2020 The Regents of the University of Michigan +# All rights reserved. +# This software is licensed under the BSD 3-Clause License. +"""Implements a Zarr :class:`~.SyncedCollection` backend.""" +from copy import deepcopy + +from .. import SyncedCollection, SyncedDict, SyncedList +from ..validators import require_string_key + +try: + import numcodecs + + ZARR = True +except ImportError: + ZARR = False + + +class ZarrCollection(SyncedCollection): + r"""A :class:`~.SyncedCollection` that synchronizes with a Zarr group. + + Since Zarr is designed for storage of array-like data, this backend implements + synchronization by storing the collection in a 1-element object array. The user + provides the group within which to store the data and the name of the data in + the group. + + **Thread safety** + + The :class:`ZarrCollection` is not thread-safe. + + Parameters + ---------- + group : zarr.hierarchy.Group + The Zarr group in which to store data. + name : str + The name under which this collection is stored in the Zarr group. + codec : numcodecs.abc.Codec + The encoding mechanism for the data. If not provided, defaults to JSON + encoding (Default value: None). + \*args : + Positional arguments forwarded to parent constructors. + \*\*kwargs : + Keyword arguments forwarded to parent constructors. + + """ + + _backend = __name__ # type: ignore + + def __init__(self, group=None, name=None, codec=None, *args, **kwargs): + if not ZARR: + raise RuntimeError( + "The Zarr package must be installed to use the ZarrCollection." + ) + + super().__init__(**kwargs) + self._group = group + self._name = name + self._object_codec = numcodecs.JSON() if codec is None else codec + + def _load_from_resource(self): + """Load the data from the Zarr group. 
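A minimal usage sketch of this backend (it assumes the ``zarr`` package is installed and uses an in-memory group with illustrative names):

.. code-block:: python

    import zarr

    from signac.synced_collections.backends.collection_zarr import ZarrDict

    group = zarr.group()  # an in-memory Zarr group
    doc = ZarrDict(group=group, name="experiment_1")
    doc["temperature"] = 300  # stored as a one-element object array named "experiment_1"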
+ + Returns + ------- + Collection or None + An equivalent unsynced collection satisfying :meth:`~.is_base_type` that + contains the data in the Zarr group. Will return None if associated + data is not found in the Zarr group. + + """ + try: + return self._group[self._name][0] + except KeyError: + return None + + def _save_to_resource(self): + """Write the data to Zarr group.""" + data = self._to_base() + dataset = self._group.require_dataset( + self._name, + overwrite=True, + shape=1, + dtype="object", + object_codec=self._object_codec, + ) + dataset[0] = data + + def __deepcopy__(self, memo): + if self._root is not None: + return type(self)( + group=None, + name=None, + data=self._to_base(), + parent=deepcopy(self._root, memo), + ) + else: + return type(self)( + group=deepcopy(self._group, memo), + name=self._name, + data=None, + parent=None, + ) + + @property + def codec(self): + """numcodecs.abc.Codec: The encoding method used for the data.""" + return self._object_codec + + @codec.setter + def codec(self, new_codec): + self._object_codec = new_codec + + @property + def group(self): + """zarr.hierarchy.Group: The Zarr group storing the data.""" + return self._group + + @property + def name(self): + """str: The name of this data in the Zarr group.""" + return self._name + + +class ZarrDict(ZarrCollection, SyncedDict): + r"""A dict-like data structure that synchronizes with a Zarr group. + + Examples + -------- + >>> doc = ZarrDict('data') + >>> doc['foo'] = "bar" + >>> assert doc['foo'] == "bar" + >>> assert 'foo' in doc + >>> del doc['foo'] + >>> doc['foo'] = dict(bar=True) + >>> doc + {'foo': {'bar': True}} + + Parameters + ---------- + group : zarr.hierarchy.Group, optional + The group in which to store data (Default value = None). + name : str, optional + The name of the collection (Default value = None). + data : :class:`collections.abc.Mapping`, optional + The initial data passed to :class:`ZarrDict`. If ``None``, defaults to + ``{}`` (Default value = None). + parent : ZarrCollection, optional + A parent instance of :class:`ZarrCollection` or ``None``. If ``None``, + the collection owns its own data (Default value = None). + \*args : + Positional arguments forwarded to parent constructors. + \*\*kwargs : + Keyword arguments forwarded to parent constructors. + + Warnings + -------- + While the :class:`ZarrDict` object behaves like a :class:`dict`, there are important + distinctions to remember. In particular, because operations are reflected + as changes to an underlying database, copying (even deep copying) a + :class:`ZarrDict` instance may exhibit unexpected behavior. If a true copy is + required, you should use the call operator to get a dictionary + representation, and if necessary construct a new :class:`ZarrDict` instance. + + """ + + _validators = (require_string_key,) + + def __init__(self, group=None, name=None, data=None, parent=None, *args, **kwargs): + super().__init__( + group=group, name=name, data=data, parent=parent, *args, **kwargs + ) + + +class ZarrList(ZarrCollection, SyncedList): + r"""A list-like data structure that synchronizes with a Zarr group. + + Only non-string sequences are supported by this class. + + Examples + -------- + >>> synced_list = ZarrList('data') + >>> synced_list.append("bar") + >>> assert synced_list[0] == "bar" + >>> assert len(synced_list) == 1 + >>> del synced_list[0] + + Parameters + ---------- + group : zarr.hierarchy.Group, optional + The group in which to store data (Default value = None). 
+ name : str, optional + The name of the collection (Default value = None). + data : non-str :class:`collections.abc.Sequence`, optional + The initial data passed to :class:`ZarrList`. If ``None``, defaults to + ``[]`` (Default value = None). + parent : ZarrCollection, optional + A parent instance of :class:`ZarrCollection` or ``None``. If ``None``, + the collection owns its own data (Default value = None). + \*args : + Positional arguments forwarded to parent constructors. + \*\*kwargs : + Keyword arguments forwarded to parent constructors. + + Warnings + -------- + While the :class:`ZarrList` object behaves like a :class:`list`, there are important + distinctions to remember. In particular, because operations are reflected + as changes to an underlying database, copying (even deep copying) a + :class:`ZarrList` instance may exhibit unexpected behavior. If a true copy is + required, you should use the call operator to get a dictionary + representation, and if necessary construct a new :class:`ZarrList` instance. + + """ + + def __init__(self, group=None, name=None, data=None, parent=None, *args, **kwargs): + super().__init__( + group=group, name=name, data=data, parent=parent, *args, **kwargs + ) diff --git a/signac/synced_collections/buffers/__init__.py b/signac/synced_collections/buffers/__init__.py new file mode 100644 index 000000000..4749ea527 --- /dev/null +++ b/signac/synced_collections/buffers/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2020 The Regents of the University of Michigan +# All rights reserved. +# This software is licensed under the BSD 3-Clause License. +"""Defines the buffering protocol for synced collections. + +In addition to defining the buffering protocol for synced collections in +:class:`~.BufferedCollection`, this subpackage also defines a number of +supported buffering implementations. No buffers are imported by default. Users +should import desired buffers as needed. +""" +from typing import List as _List + +__all__: _List[str] = [] diff --git a/signac/synced_collections/buffers/buffered_collection.py b/signac/synced_collections/buffers/buffered_collection.py new file mode 100644 index 000000000..5300af49a --- /dev/null +++ b/signac/synced_collections/buffers/buffered_collection.py @@ -0,0 +1,181 @@ +# Copyright (c) 2020 The Regents of the University of Michigan +# All rights reserved. +# This software is licensed under the BSD 3-Clause License. +"""Defines a buffering protocol for :class:`~.SyncedCollection` objects. + +Depending on the choice of backend, synchronization may be an expensive process. +In that case, it can be helpful to allow many in-memory modifications to occur +before any synchronization is attempted. Since many collections could be pointing +to the same underlying resource, maintaining proper data coherency across different +instances requires careful consideration of how the data is stored. The appropriate +buffering methods can differ for different backends; as a result, the basic +interface simply lays out the API for buffering and leaves implementation +details for specific backends to handle. Judicious use of buffering can +dramatically speed up code paths that might otherwise involve, for instance, +heavy I/O. The specific buffering mechanism must be implemented by each backend +since it depends on the nature of the underlying data format. + +All buffered collections expose a local context manager for buffering. 
In +addition, each backend exposes a context manager +:meth:`BufferedCollection.buffer_backend` that indicates to all buffered +collections of that backend that they should enter buffered mode. These context +managers may be nested freely, and buffer flushes will occur when all such +managers have been exited. + +.. code-block:: + + with collection1.buffered: + with type(collection1).buffer_backend: + collection2['foo'] = 1 + collection1['bar'] = 1 + # collection2 will flush when this context exits. + + # This operation will write straight to the backend. + collection2['bar'] = 2 + + # collection1 will flush when this context exits. +""" + +import logging +from inspect import isabstract + +from .. import SyncedCollection +from ..utils import _CounterFuncContext + +logger = logging.getLogger(__name__) + + +class BufferedCollection(SyncedCollection): + """A :class:`~.SyncedCollection` defining an interface for buffering. + + **The default behavior of this class is not to buffer.** This class simply + defines an appropriate interface for buffering behavior so that client code + can rely on these methods existing, e.g. to be able to do things like ``with + collection.buffered...``. This feature allows client code to indicate to the + collection when it is safe to buffer reads and writes, which usually means + guaranteeing that the synchronization destination (e.g. an underlying file + or database entry) will not be modified by other processes concurrently + with the set of operations within the buffered block. However, in the + default case the result of this will be a no-op and all data will be + immediately synchronized with the backend. + + The BufferedCollection overrides the :meth:`~._load` and + :meth:`~._save` methods to check whether buffering is enabled + or not. If not, the behavior is identical to the parent class. When in buffered + mode, however, the BufferedCollection introduces two additional hooks that + can be overridden by subclasses to control how the collection behaves while buffered: + + - :meth:`~._load_from_buffer`: Loads data while in buffered mode and returns + it in an object satisfying + :meth:`~signac.synced_collections.data_types.synced_collection.SyncedCollection.is_base_type`. + The default behavior is to simply call + :meth:`~._load_from_resource`. + - :meth:`~._save_to_buffer`: Stores data while in buffered mode. The default behavior + is to simply call + :meth:`~._save_to_resource`. + + **Thread safety** + + Whether or not buffering is thread safe depends on the buffering method used. In + general, both the buffering logic and the data type operations must be + thread safe for the resulting collection type to be thread safe. + + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.buffered = _CounterFuncContext(self._flush) + + @classmethod + def __init_subclass__(cls): + """Register subclasses for the purpose of global buffering. + + Each subclass has its own means of buffering and must be flushed. + """ + super().__init_subclass__() + if not isabstract(cls): + cls._buffer_context = _CounterFuncContext(cls._flush_buffer) + + @classmethod + def buffer_backend(cls, *args, **kwargs): + """Enter context to buffer all operations for this backend.""" + return cls._buffer_context + + @classmethod + def backend_is_buffered(cls): + """Check if this backend is currently buffered.""" + return bool(cls._buffer_context) + + def _save(self): + """Synchronize data with the backend but buffer if needed. 
+ + This method is identical to the SyncedCollection implementation for + `sync` except that it determines whether data is actually synchronized + or instead written to a temporary buffer based on the buffering mode. + """ + if not self._suspend_sync: + if self._root is None: + if self._is_buffered: + self._save_to_buffer() + else: + self._save_to_resource() + else: + self._root._save() + + def _load(self): + """Load data from the backend but buffer if needed. + + This method is identical to the :class:`~.SyncedCollection` + implementation except that it determines whether data is actually + synchronized or instead read from a temporary buffer based on the + buffering mode. + """ + if not self._suspend_sync: + if self._root is None: + if self._is_buffered: + data = self._load_from_buffer() + else: + data = self._load_from_resource() + with self._suspend_sync: + self._update(data) + else: + self._root._load() + + def _save_to_buffer(self): + """Store data in buffer. + + By default, this method simply calls :meth:`~._save_to_resource`. Subclasses + must implement specific buffering strategies. + """ + self._save_to_resource() + + def _load_from_buffer(self): + """Store data in buffer. + + By default, this method simply calls :meth:`~._load_from_resource`. Subclasses + must implement specific buffering strategies. + + Returns + ------- + Collection + An equivalent unsynced collection satisfying + :meth:`~signac.synced_collections.data_types.synced_collection.SyncedCollection.is_base_type` that + contains the buffered data. By default, the buffered data is just the + data in the resource. + + """ # noqa: E501 + self._load_from_resource() + + @property + def _is_buffered(self): + """Check if we should write to the buffer or not.""" + return self.buffered or type(self)._buffer_context + + def _flush(self): + """Flush data associated with this instance from the buffer.""" + pass + + @classmethod + def _flush_buffer(self): + """Flush all data in this class's buffer.""" + pass diff --git a/signac/synced_collections/buffers/file_buffered_collection.py b/signac/synced_collections/buffers/file_buffered_collection.py new file mode 100644 index 000000000..e7519f0be --- /dev/null +++ b/signac/synced_collections/buffers/file_buffered_collection.py @@ -0,0 +1,372 @@ +# Copyright (c) 2020 The Regents of the University of Michigan +# All rights reserved. +# This software is licensed under the BSD 3-Clause License. +"""A standardized buffering implementation for file-based backends. + +All file-based backends can use a similar buffering protocol. In particular, +integrity checks can be performed by checking for whether the file has been +modified since it was originally loaded into the buffer. However, various +specific components are abstract and must be implemented by child classes. +""" + +import errno +import os +import warnings +from abc import abstractmethod +from threading import RLock +from typing import Dict, Tuple, Union + +from ..data_types.synced_collection import _LoadAndSave +from ..errors import BufferedError, MetadataError +from ..utils import _CounterFuncContext, _NullContext +from .buffered_collection import BufferedCollection + + +class _FileBufferedContext(_CounterFuncContext): + """Extend the usual buffering context to support setting the buffer size. + + This context allows the buffer_backend method to temporarily set the buffer + size within the scope of this context. 
+ """ + + def __init__(self, cls): + super().__init__(cls._flush_buffer) + self._buffer_capacity = None + self._original_buffer_capacitys = [] + self._cls = cls + + def __call__(self, buffer_capacity=None): + self._buffer_capacity = buffer_capacity + return self + + def __enter__(self): + super().__enter__() + if self._buffer_capacity is not None: + self._original_buffer_capacitys.append(self._cls.get_buffer_capacity()) + self._cls.set_buffer_capacity(self._buffer_capacity) + else: + self._original_buffer_capacitys.append(None) + self._buffer_capacity = None + + def __exit__(self, exc_type, exc_val, exc_tb): + super().__exit__(exc_type, exc_val, exc_tb) + original_buffer_capacity = self._original_buffer_capacitys.pop() + if original_buffer_capacity is not None: + self._cls.set_buffer_capacity(original_buffer_capacity) + + +class _BufferedLoadAndSave(_LoadAndSave): + """Wrap base loading and saving with an extra thread lock. + + Writes to buffered collections will also modify the buffer, so they must + acquire the buffer lock in addition to the default behavior. + """ + + def __enter__(self): + self._collection._buffer_lock.__enter__() + super().__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + try: + super().__exit__(exc_type, exc_val, exc_tb) + finally: + self._collection._buffer_lock.__exit__(exc_type, exc_val, exc_tb) + + +class FileBufferedCollection(BufferedCollection): + """A :class:`~.SyncedCollection` that can buffer file I/O. + + This class provides a standardized buffering protocol for all file-based + backends. All file-based backends can use the same set of integrity checks + prior to a buffer flush to ensure that no conflicting modifications are + made. Specifically, they can check whether the file has been modified on + disk since it was originally loaded to the buffer. This class provides the + basic infrastructure for that and defines standard methods that can be used + by all classes. Subclasses must define the appropriate storage mechanism. + + .. note:: + Important notes for developers: + - This class should be inherited before any other collections. This + requirement is due to the extensive use of multiple inheritance. + Since this class is designed to be combined with other + :class:`~.SyncedCollection` types without making those types aware + of buffering behavior, it transparently hooks into the + initialization process, but this is dependent on its constructor + being called before those of other classes. + - All subclasses must define a class level ``_BUFFER_CAPACITY`` + variable that is used to determine the maximum allowable buffer + size. + + Parameters + ---------- + filename : str, optional + The filename of the associated JSON file on disk (Default value = None). + + Warnings + -------- + Although it can be done safely, in general modifying two different collections + pointing to the same underlying resource while both are in different buffering + modes is unsupported and can lead to undefined behavior. This class makes a + best effort at performing safe modifications, but it is possible to construct + nested buffered contexts for different objects that can lead to an invalid + buffer state, or even situations where there is no obvious indicator of what + is the canonical source of truth. In general, if you need multiple objects + pointing to the same resource, it is **strongly** recommeneded to work with + both of them in identical buffering states at all times. 
+ + """ + + _LoadSaveType = _BufferedLoadAndSave + + def __init__(self, parent=None, filename=None, *args, **kwargs): + super().__init__(parent=parent, filename=filename, *args, **kwargs) + self._filename = filename + + @classmethod + def __init_subclass__(cls): + """Prepare subclasses.""" + super().__init_subclass__() + cls._CURRENT_BUFFER_SIZE = 0 + + # This dict is the actual data buffer, mapping filenames to their + # cached data and metadata. + cls._buffer: Dict[str, Dict[str, Union[bytes, str, Tuple[int, float]]]] = {} + + # Buffered contexts may be nested, and when leaving a buffered context + # we only want to flush collections that are no longer buffered. To + # accomplish this, we maintain a list of buffered collections so that + # we can perform per-instance flushes that account for their current + # buffering state. + cls._buffered_collections: Dict[int, BufferedCollection] = {} + + cls._buffer_context = _FileBufferedContext(cls) + + @classmethod + def enable_multithreading(cls): + """Enable safety checks and thread locks required for thread safety. + + This method adds managed buffer-related thread safety in addition to + what the parent method does. + + """ + super().enable_multithreading() + cls._BUFFER_LOCK = RLock() + + @classmethod + def disable_multithreading(cls): + """Disable all safety checks and thread locks required for thread safety. + + This method adds managed buffer-related thread safety in addition to + what the parent method does. + + """ + super().disable_multithreading() + cls._BUFFER_LOCK = _NullContext() + + @property + def _buffer_lock(self): + """Acquire the buffer lock.""" + return type(self)._BUFFER_LOCK + + def _get_file_metadata(self): + """Return metadata of file. + + Returns + ------- + Tuple[int, float] or None + The size and last modification time of the associated file. If the + file does not exist, returns :code:`None`. + + """ + try: + metadata = os.stat(self._filename) + return metadata.st_size, metadata.st_mtime_ns + except OSError as error: + if error.errno != errno.ENOENT: + raise + # A return value of None indicates that the file does not + # exist. Since any non-``None`` value will return `False` when + # compared to ``None``, returning ``None`` provides a + # reasonable value to compare against for metadata-based + # validation checks. + return None + + @classmethod + def get_buffer_capacity(cls): + """Get the current buffer capacity. + + Returns + ------- + int + The amount of data that can be stored before a flush is triggered + in the appropriate units for a particular buffering implementation. + + """ + return cls._BUFFER_CAPACITY + + @classmethod + def set_buffer_capacity(cls, new_capacity): + """Update the buffer capacity. + + Parameters + ---------- + new_capacity : int + The new capacity of the buffer in the appropriate units for a particular + buffering implementation. + + """ + cls._BUFFER_CAPACITY = new_capacity + with cls._BUFFER_LOCK: + if new_capacity < cls._CURRENT_BUFFER_SIZE: + cls._flush_buffer(force=True) + + @classmethod + def get_current_buffer_size(cls): + """Get the total amount of data currently stored in the buffer. + + Returns + ------- + int + The size of all data contained in the buffer in the appropriate + units for a particular buffering implementation. + + """ + return cls._CURRENT_BUFFER_SIZE + + def _load_from_buffer(self): + """Read data from buffer. + + See :meth:`~._initialize_data_in_buffer` for details on the data stored + in the buffer and the integrity checks performed. 
+ + Returns + ------- + collections.abc.Collection + A collection of the same base type as the :class:`~.SyncedCollection` this + method is called for, corresponding to data loaded from the + underlying file. + + """ + with self._buffer_lock: + if self._filename not in type(self)._buffer: + # The first time this method is called, if nothing is in the buffer + # for this file then we cannot guarantee that the _data attribute + # is valid either since the resource could have been modified + # between when _data was last updated and when this load is being + # called. As a result, we have to load from the resource here to be + # safe. + data = self._load_from_resource() + with self._thread_lock, self._suspend_sync: + self._update(data) + self._initialize_data_in_buffer() + + # This storage can be safely updated every time on every thread. + type(self)._buffered_collections[id(self)] = self + + @abstractmethod + def _initialize_data_in_buffer(self): + """Create the initial entry for the data in the cache. + + This method should be called the first time that a collection's data is + accessed in buffered mode. This method stores the encoded data in the + cache, along with the metadata of the underlying file and any other + information that may be used for validation later. This information + depends on the implementation of the buffer in subclasses. + """ + pass + + @classmethod + def _flush_buffer(cls, force=False, retain_in_force=False): + """Flush the data in the file buffer. + + Parameters + ---------- + force : bool + If True, force a flush even in buffered mode (defaults to False). This + parameter is used when the buffer is filled to capacity. + retain_in_force : bool + If True, when forcing a flush a collection is retained in the buffer. + This feature is useful if only some subset of the buffer's contents + are relevant to size restrictions. For intance, since only modified + items will have to be written back out to disk, a buffer protocol may + not care to count unmodified collections towards the total. + + Returns + ------- + dict + Mapping of filename and errors occured during flushing data. + + Raises + ------ + BufferedError + If there are any issues with flushing the data. + + """ + issues = {} + + # We need to use the list of buffered objects rather than directly + # looping over the local cache so that each collection can + # independently decide whether or not to flush based on whether it's + # still buffered (if buffered contexts are nested). + remaining_collections = {} + while True: + # This is the only part that needs to be locked; once items are + # removed from the buffer they can be safely handled on separate + # threads. + with cls._BUFFER_LOCK: + try: + ( + col_id, + collection, + ) = cls._buffered_collections.popitem() + except KeyError: + break + + if collection._is_buffered and not force: + # If force is false, then the only way for the collection to + # still be buffered is if there are nested buffered contexts. + # In that case, flush_buffer was called due to the exit of an + # inner buffered context, and we shouldn't do anything with + # this object, so we just put it back in the list *and* skip + # the flush. + remaining_collections[col_id] = collection + continue + elif force and retain_in_force: + # If force is true, the collection must still be buffered. + # In that case, the retain_in_force parameter controls whether + # we we want to put it back in the remaining_collections list + # after flushing any writes. 
+ remaining_collections[col_id] = collection + + try: + collection._flush(force=force) + except (OSError, MetadataError) as err: + issues[collection._filename] = err + if not issues: + cls._buffered_collections = remaining_collections + else: + raise BufferedError(issues) + + # TODO: The buffer_size argument should be changed to buffer_capacity in + # signac 2.0 for consistency with the new names in synced collections. + @classmethod + def buffer_backend(cls, buffer_size=None, force_write=None, *args, **kwargs): + """Enter context to buffer all operations for this backend. + + Parameters + ---------- + buffer_size : int + The capacity of the buffer to use within this context (resets after + the context is exited). + force_write : bool + This argument does nothing and is only present for compatibility + with signac 1.x. + """ + if force_write is not None: + warnings.warn( + DeprecationWarning( + "The force_write parameter is deprecated and will be removed in " + "signac 2.0. This functionality is no longer supported." + ) + ) + return cls._buffer_context(buffer_size) diff --git a/signac/synced_collections/buffers/memory_buffered_collection.py b/signac/synced_collections/buffers/memory_buffered_collection.py new file mode 100644 index 000000000..16ca88e4d --- /dev/null +++ b/signac/synced_collections/buffers/memory_buffered_collection.py @@ -0,0 +1,278 @@ +# Copyright (c) 2020 The Regents of the University of Michigan +# All rights reserved. +# This software is licensed under the BSD 3-Clause License. +"""A standardized buffering implementation for file-based backends. + +The buffering method implemented here involves a single buffer of references to +in-memory objects containing data. These objects are the base types of a given +:class:`~.SyncedCollection` type, e.g. a dict for all dict-like collections, +and are the underlying data stores for those types. This buffering method +exploits the fact that all mutable collection types in Python are references, +so modifying one such collection results in modifying all of them, thereby +removing any need for more complicated synchronization protocols. +""" + +from ..errors import MetadataError +from .file_buffered_collection import FileBufferedCollection + + +class SharedMemoryFileBufferedCollection(FileBufferedCollection): + """A :class:`~.SyncedCollection` that defers all I/O when buffered. + + This class extends the :class:`~.FileBufferedCollection` and implements a + concrete storage mechanism in which collections store a reference to their + data in a buffer. This method takes advantage of the reference-based semantics + of built-in Python mutable data types like dicts and lists. All collections + referencing the same file are pointed to the same underlying data store in + buffered mode, allowing all changes in one to be transparently reflected in + the others. To further improve performance, the buffer size is determined + only based on the number of modified collections stored, not the total number. + As a result, the maximum capacity is only reached when a large number of + modified collections are stored, and unmodified collections are only removed + from the buffer when a buffered context is exited (rather than when buffer + capacity is exhausted). See the Warnings section for more information. + + The buffer size and capacity for this class is measured in the total number + of collections stored in the buffer that have undergone any modifications + since their initial load from disk. 
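Because buffered instances share a single in-memory store per file, changes made through one collection are immediately visible through another pointing to the same file. A short sketch (the filename is illustrative):

.. code-block:: python

    from signac.synced_collections.backends.collection_json import MemoryBufferedJSONDict

    d1 = MemoryBufferedJSONDict("shared.json")
    d2 = MemoryBufferedJSONDict("shared.json")

    with MemoryBufferedJSONDict.buffer_backend():
        d1["a"] = 1
        assert d2["a"] == 1  # both collections reference the same buffered data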
A sequence of read-only operations will + load data into the buffer, but the apparent buffer size will be zero. + + .. note:: + Important note for subclasses: This class should be inherited before + any other collections. This requirement is due to the extensive use of + multiple inheritance: since this class is designed to be combined with + other :class:`~.SyncedCollection` types without making those types aware + of buffering behavior, it transparently hooks into the initialization + process, but this is dependent on its constructor being called before + those of other classes. + + **Thread safety** + + This buffering method is thread safe. This thread safety is independent of the + safety of an individual collection backend; the backend must support thread + safe writes to the underlying resource in order for a buffered version using + this class to be thread safe for general use. The thread safety guaranteed + by this class only concerns buffer reads, writes, and flushes. All these + operations are serialized because there is no way to prevent one collection + from triggering a flush while another still thinks its data is in the cache; + however, this shouldn't be terribly performance-limiting since in buffered + mode we're avoiding I/O anyway and that's the only thing that can be effectively + parallelized here. + + Parameters + ---------- + filename : str, optional + The filename of the associated JSON file on disk (Default value = None). + + Warnings + -------- + - Although it can be done safely, in general modifying two different collections + pointing to the same underlying resource while both are in different buffering + modes is unsupported and can lead to undefined behavior. This class makes a + best effort at performing safe modifications, but it is possible to construct + nested buffered contexts for different objects that can lead to an invalid + buffer state, or even situations where there is no obvious indicator of what + is the canonical source of truth. In general, if you need multiple objects + pointing to the same resource, it is **strongly** recommeneded to work with + both of them in identical buffering states at all times. + - This buffering method has no upper bound on the buffer size if all + operations on buffered objects are read-only operations. If a strict upper bound + is required, for instance due to strict virtual memory limits on a given system, + use of the :class:`~.SerializedFileBufferedCollection` will allow limiting + the total memory usage of the process. + + """ + + _BUFFER_CAPACITY = 1000 # The number of collections to store in the buffer. + + def _flush(self, force=False): + """Save buffered changes to the underlying file. + + Parameters + ---------- + force : bool + If True, force a flush even in buffered mode (defaults to False). This + parameter is used when the buffer is filled to capacity. + + Raises + ------ + MetadataError + If any file is detected to have changed on disk since it was + originally loaded into the buffer and modified. + + """ + # Different files in the buffer can be safely flushed simultaneously, + # but a given file can only be flushed on one thread at once. + if not self._is_buffered or force: + try: + cached_data = type(self)._buffer[self._filename] + except KeyError: + # If we got to this point, it means that another collection + # pointing to the same underlying resource flushed the buffer. + # If so, then the data in this instance is still pointing to + # that object's data store. 
If this was a force flush, then + # the data store is still the cached data, so we're fine. If + # this wasn't a force flush, then we have to reload this + # object's data so that it will stop sharing data with the + # other instance. + if not force: + self._data.clear() + self._update(self._load_from_resource()) + else: + # If the contents have not been changed since the initial read, + # we don't need to rewrite it. + try: + # Validate that the file hasn't been changed by + # something else. + if cached_data["modified"]: + if cached_data["metadata"] != self._get_file_metadata(): + raise MetadataError(self._filename, cached_data["contents"]) + self._save_to_resource() + finally: + # Whether or not an error was raised, the cache must be + # cleared to ensure a valid final buffer state, unless + # we're force flushing in which case we never delete, but + # take note that the data is no longer modified relative to + # its representation on disk. + if cached_data["modified"]: + type(self)._CURRENT_BUFFER_SIZE -= 1 + if not force: + del type(self)._buffer[self._filename] + else: + # Have to update the metadata on a force flush because + # we could modify this item again later, leading to + # another (possibly forced) flush afterwards that will + # appear invalid if the metadata isn't updated to the + # metadata after the current flush. + cached_data["metadata"] = self._get_file_metadata() + cached_data["modified"] = False + else: + # If this object is still buffered _and_ this wasn't a force flush, + # that implies a nesting of buffered contexts in which another + # collection pointing to the same data flushed the buffer. This + # object's data will still be pointing to that one, though, so the + # safest choice is to reinitialize its data from scratch. + with self._suspend_sync: + self._data = { + key: self._from_base(data=value, parent=self) + for key, value in self._to_base().items() + } + + def _load(self): + """Load data from the backend but buffer if needed. + + Override the base buffered method to skip the _update and to let + _load_from_buffer happen "in place." + """ + if not self._suspend_sync: + if self._root is None: + if self._is_buffered: + self._load_from_buffer() + else: + data = self._load_from_resource() + with self._suspend_sync: + self._update(data) + else: + self._root._load() + + def _save_to_buffer(self): + """Store data in buffer. + + See :meth:`~._initialize_data_in_buffer` for details on the data stored + in the buffer and the integrity checks performed. + """ + type(self)._buffered_collections[id(self)] = self + + # Since one object could write to the buffer and trigger a flush while + # another object was found in the buffer and attempts to proceed + # normally, we have to serialize this whole block. In theory we might + # be safe without it because the only operations that should reach this + # point without already being locked are destructive operations (clear, + # reset) that don't use the :meth:`_load_and_save` context, and for + # those the writes will be automatically serialized because Python + # dicts are thread-safe because of the GIL. However, it's best not to + # depend on the thread-safety of built-in containers. + with self._buffer_lock: + if self._filename in type(self)._buffer: + # Always track all instances pointing to the same data. + + # If all we had to do is set the flag, it could be done without any + # check, but we also need to increment the number of modified + # items, so we may as well do the update conditionally as well. 
+ if not type(self)._buffer[self._filename]["modified"]: + type(self)._buffer[self._filename]["modified"] = True + type(self)._CURRENT_BUFFER_SIZE += 1 + else: + self._initialize_data_in_buffer(modified=True) + type(self)._CURRENT_BUFFER_SIZE += 1 + + if type(self)._CURRENT_BUFFER_SIZE > type(self)._BUFFER_CAPACITY: + type(self)._flush_buffer(force=True) + + def _load_from_buffer(self): + """Read data from buffer. + + See :meth:`~._initialize_data_in_buffer` for details on the data stored + in the buffer and the integrity checks performed. + + Returns + ------- + Collection + A collection of the same base type as the SyncedCollection this + method is called for, corresponding to data loaded from the + underlying file. + + """ + super()._load_from_buffer() + + # Set local data to the version in the buffer. + self._data = type(self)._buffer[self._filename]["contents"] + + def _initialize_data_in_buffer(self, modified=False): + """Create the initial entry for the data in the cache. + + Stores the following information: + - The metadata provided by :meth:`~._get_file_metadata`. Used to + check if a file has been modified on disk since it was loaded + into the buffer. + - A flag indicating whether any operation that saves to the buffer + has occurred, e.g. a ``__setitem__`` call. This flag is used to + determine what collections need to be saved to disk when + flushing. + + Parameters + ---------- + modified : bool + Whether or not the data has been modified from the version on disk + (Default value = False). + + """ + metadata = self._get_file_metadata() + type(self)._buffer[self._filename] = { + "contents": self._data, + "metadata": metadata, + "modified": modified, + } + + @classmethod + def _flush_buffer(cls, force=False): + """Flush the data in the file buffer. + + Parameters + ---------- + force : bool + If True, force a flush even in buffered mode (defaults to False). This + parameter is used when the buffer is filled to capacity. + + Returns + ------- + dict + Mapping from filenames to errors that occured while flushing data. + + Raises + ------ + BufferedError + If there are any issues with flushing the data. + + """ + return super()._flush_buffer(force=force, retain_in_force=True) diff --git a/signac/synced_collections/buffers/serialized_file_buffered_collection.py b/signac/synced_collections/buffers/serialized_file_buffered_collection.py new file mode 100644 index 000000000..daabf61a5 --- /dev/null +++ b/signac/synced_collections/buffers/serialized_file_buffered_collection.py @@ -0,0 +1,309 @@ +# Copyright (c) 2020 The Regents of the University of Michigan +# All rights reserved. +# This software is licensed under the BSD 3-Clause License. +"""Buffering for file-based backends using a serialized buffer. + +The buffering method implemented here involves a single buffer of serialized +data. All collections in buffered mode encode their data into this buffer on save +and decode from it on load. +""" + +import hashlib +import json + +from ..errors import MetadataError +from ..utils import SyncedCollectionJSONEncoder +from .file_buffered_collection import FileBufferedCollection + + +class SerializedFileBufferedCollection(FileBufferedCollection): + """A :class:`~.FileBufferedCollection` based on a serialized data store. + + This class extends the :class:`~.FileBufferedCollection` and implements a + concrete storage mechanism in which data is encoded (by default, into JSON) + and stored into a buffer. 
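As a rough, purely illustrative sketch of the idea (the helper below is hypothetical and not part of this class), each buffer entry pairs the encoded contents with a content hash and the file metadata so that a flush can cheaply detect whether anything actually changed:

import hashlib
import json


def _sketch_buffer_entry(data, metadata):
    # Hypothetical helper mirroring the entry layout used by this buffering
    # scheme: encoded contents, an md5 hash of those contents, and file
    # metadata used for integrity checks when the buffer is flushed.
    blob = json.dumps(data).encode()
    return {
        "contents": blob,
        "hash": hashlib.md5(blob).hexdigest(),
        "metadata": metadata,
    }


entry = _sketch_buffer_entry({"a": 1}, metadata=(1024, 1.6e9))
# On flush, the in-memory data is re-encoded and the hashes are compared;
# only a mismatch triggers a write to the underlying file.
needs_write = hashlib.md5(json.dumps({"a": 2}).encode()).hexdigest() != entry["hash"]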
This buffer functions as a central data store for + all collections and is a synchronization point for various collections + pointing to the same underlying file. This serialization method may be a + bottleneck in some applications; see the Warnings section for more information. + + The buffer size and capacity for this class are measured in the total number + of bytes stored in the buffer that correspond to file data. This is *not* + the total size of the buffer, which also contains additional information + like the hash of the data and the file metadata (which are used for + integrity checks), but it is the relevant metric for users. + + .. note:: + Important note for subclasses: This class should be inherited before + any other collections. This requirement is due to the extensive use of + multiple inheritance: since this class is designed to be combined with + other :class:`~.SyncedCollection` types without making those types aware + of buffering behavior, it transparently hooks into the initialization + process, but this is dependent on its constructor being called before + those of other classes. + + **Thread safety** + + This buffering method is thread safe. This thread safety is independent of the + safety of an individual collection backend; the backend must support thread + safe writes to the underlying resource in order for a buffered version using + this class to be thread safe for general use. The thread safety guaranteed + by this class only concerns buffer reads, writes, and flushes. All these + operations are serialized because there is no way to prevent one collection + from triggering a flush while another still thinks its data is in the cache. + + Parameters + ---------- + filename : str, optional + The filename of the associated JSON file on disk (Default value = None). + + Warnings + -------- + - Although it can be done safely, in general modifying two different collections + pointing to the same underlying resource while both are in different buffering + modes is unsupported and can lead to undefined behavior. This class makes a + best effort at performing safe modifications, but it is possible to construct + nested buffered contexts for different objects that can lead to an invalid + buffer state, or even situations where there is no obvious indicator of what + is the canonical source of truth. In general, if you need multiple objects + pointing to the same resource, it is **strongly** recommended to work with + both of them in identical buffering states at all times. + - The overhead of this buffering method is quite high due to the constant + encoding and decoding of data. For performance-critical applications where + memory is not highly constrained and virtual memory limits are absent, the + :class:`~.SharedMemoryFileBufferedCollection` may be more appropriate. + - Due to the possibility of read operations triggering a flush, the + contents of the buffer may be invalidated on loads as well. To prevent this, + even nominally read-only operations are serialized. As a result, although + this class is thread safe, it will effectively serialize all operations and + will therefore not be performant. + + """ + + _BUFFER_CAPACITY = 32 * 2 ** 20 # 32 MB + + def _flush(self, force=False): + """Save buffered changes to the underlying file. + + Parameters + ---------- + force : bool + If True, force a flush even in buffered mode (defaults to False). This + parameter is used when the buffer is filled to capacity.
+ + Raises + ------ + MetadataError + If any file is detected to have changed on disk since it was + originally loaded into the buffer and modified. + + """ + # Different files in the buffer can be safely flushed simultaneously, + # but a given file can only be flushed on one thread at once. + with self._buffer_lock: + if not self._is_buffered or force: + try: + cached_data = type(self)._buffer[self._filename] + except KeyError: + # There are valid reasons for nothing to be in the cache (the + # object was never actually accessed during global buffering, + # multiple collections pointing to the same file, etc). + return + else: + blob = self._encode(self._data) + + # If the contents have not been changed since the initial read, + # we don't need to rewrite it. + try: + if self._hash(blob) != cached_data["hash"]: + # Validate that the file hasn't been changed by + # something else. + if cached_data["metadata"] != self._get_file_metadata(): + raise MetadataError( + self._filename, cached_data["contents"] + ) + self._update(self._decode(cached_data["contents"])) + self._save_to_resource() + finally: + # Whether or not an error was raised, the cache must be + # cleared to ensure a valid final buffer state. + del type(self)._buffer[self._filename] + data_size = len(cached_data["contents"]) + type(self)._CURRENT_BUFFER_SIZE -= data_size + + @staticmethod + def _hash(blob): + """Calculate and return the md5 hash value for the file data. + + Parameters + ---------- + blob : bytes + Byte literal to be hashed. + + Returns + ------- + str + The md5 hash of the input bytes. + + """ + if blob is not None: + m = hashlib.md5() + m.update(blob) + return m.hexdigest() + + @staticmethod + def _encode(data): + """Encode the data into a serializable form. + + This method assumes JSON-serializable data, but subclasses can override + this hook method to change the encoding behavior as needed. + + Parameters + ---------- + data : collections.abc.Collection + Any collection type that can be encoded. + + Returns + ------- + bytes + The underlying encoded data. + + """ + return json.dumps(data, cls=SyncedCollectionJSONEncoder).encode() + + @staticmethod + def _decode(blob): + """Decode serialized data. + + This method assumes JSON-serializable data, but subclasses can override + this hook method to change the encoding behavior as needed. + + Parameters + ---------- + blob : bytes + Byte literal to be decoded. + + Returns + ------- + data : collections.abc.Collection + The decoded data in the appropriate base collection type. + + """ + return json.loads(blob.decode()) + + def _save_to_buffer(self): + """Store data in buffer. + + See :meth:`~._initialize_data_in_buffer` for details on the data stored + in the buffer and the integrity checks performed. + """ + type(self)._buffered_collections[id(self)] = self + + # Since one object could write to the buffer and trigger a flush while + # another object was found in the buffer and attempts to proceed + # normally, we have to serialize this whole block. In theory we might + # be safe without it because the only operations that should reach this + # point without already being locked are destructive operations (clear, + # reset) that don't use the :meth:`_load_and_save` context, and for + # those the writes will be automatically serialized because Python + # dicts are thread-safe because of the GIL. However, it's best not to + # depend on the thread-safety of built-in containers. 
+ with self._buffer_lock: + if self._filename in type(self)._buffer: + # Always track all instances pointing to the same data. + blob = self._encode(self._data) + cached_data = type(self)._buffer[self._filename] + buffer_size_change = len(blob) - len(cached_data["contents"]) + type(self)._CURRENT_BUFFER_SIZE += buffer_size_change + cached_data["contents"] = blob + else: + # The only methods that could safely call sync without a load are + # destructive operations like `reset` or `clear` that completely + # wipe out previously existing data. Therefore, the safest choice + # for ensuring consistency of the buffer is to modify the stored + # hash (which is used for the consistency check) with the hash of + # the current data on disk. _initialize_data_in_buffer always uses + # the current metadata, so the only extra work here is to modify + # the hash after it's called (since it uses self._to_base() to get + # the data to initialize the cache with). + self._initialize_data_in_buffer() + disk_data = self._load_from_resource() + type(self)._buffer[self._filename]["hash"] = self._hash( + self._encode(disk_data) + ) + + if type(self)._CURRENT_BUFFER_SIZE > type(self)._BUFFER_CAPACITY: + type(self)._flush_buffer(force=True) + + def _load_from_buffer(self): + """Read data from buffer. + + See :meth:`~._initialize_data_in_buffer` for details on the data stored + in the buffer and the integrity checks performed. + + Returns + ------- + Collection + A collection of the same base type as the SyncedCollection this + method is called for, corresponding to data loaded from the + underlying file. + + """ + with self._buffer_lock: + super()._load_from_buffer() + + # Load from buffer. This has to happen inside the locked context + # because otherwise the data could be flushed from the buffer by + # another thread. + blob = type(self)._buffer[self._filename]["contents"] + + if type(self)._CURRENT_BUFFER_SIZE > type(self)._BUFFER_CAPACITY: + type(self)._flush_buffer(force=True) + return self._decode(blob) + + def _initialize_data_in_buffer(self): + """Create the initial entry for the data in the cache. + + Stores the following information: + - The hash of the data as initially stored in the cache. This hash + is used to quickly determine whether data has changed when flushing. + - The metadata provided by :meth:`~._get_file_metadata`. Used to + check if a file has been modified on disk since it was loaded + into the buffer. + + This method also increments the current buffer size, which in this class + is the total number of bytes of data in the buffer. + """ + blob = self._encode(self._data) + metadata = self._get_file_metadata() + + type(self)._buffer[self._filename] = { + "contents": blob, + "hash": self._hash(blob), + "metadata": metadata, + } + type(self)._CURRENT_BUFFER_SIZE += len( + type(self)._buffer[self._filename]["contents"] + ) + + @classmethod + def _flush_buffer(cls, force=False): + """Flush the data in the file buffer. + + Parameters + ---------- + force : bool + If True, force a flush even in buffered mode (defaults to False). This + parameter is used when the buffer is filled to capacity. + + Returns + ------- + dict + Mapping of filename and errors occured during flushing data. + + Raises + ------ + BufferedError + If there are any issues with flushing the data. 
+ + """ + return super()._flush_buffer(force=force, retain_in_force=False) diff --git a/signac/synced_collections/data_types/__init__.py b/signac/synced_collections/data_types/__init__.py new file mode 100644 index 000000000..9f7bf6422 --- /dev/null +++ b/signac/synced_collections/data_types/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2020 The Regents of the University of Michigan +# All rights reserved. +# This software is licensed under the BSD 3-Clause License. +"""This subpackage defines various synced data types.""" + +from .synced_collection import SyncedCollection +from .synced_dict import SyncedDict +from .synced_list import SyncedList + +__all__ = ["SyncedCollection", "SyncedDict", "SyncedList"] diff --git a/signac/synced_collections/data_types/attr_dict.py b/signac/synced_collections/data_types/attr_dict.py new file mode 100644 index 000000000..cc4be826b --- /dev/null +++ b/signac/synced_collections/data_types/attr_dict.py @@ -0,0 +1,65 @@ +# Copyright (c) 2020 The Regents of the University of Michigan +# All rights reserved. +# This software is licensed under the BSD 3-Clause License. +"""Implements the :class:`AttrDict`. + +This simple mixin class implements overloads for __setattr__, __getattr__, and +__delattr__. While we do not want to offer this API generally for all +SyncedDict objects, some applications may want to add this feature, so this +simple mixin can be combined via inheritance without causing much difficulty. +""" + +from typing import FrozenSet + + +class AttrDict: + """A class that redirects attribute access methods to __getitem__. + + Although this class is called an :class:`AttrDict`, it does not directly + inherit from any dict-like class or offer any relevant APIs. Its only purpose + is to be used as a mixin with other dict-like classes to add attribute-based + access to dictionary contents. + + Subclasses that inherit from this class must define the ``_PROTECTED_KEYS`` + class variable, which indicates known attributes of the object. This indication + is necessary because otherwise accessing ``obj.data`` is ambiguous as to + whether it is a reference to a special ``data`` attribute or whether it is + equivalent to ``obj['data']``. Without this variable, a user could mask + internal variables inaccessible via normal attribute access by adding dictionary + keys with the same name. + + Examples + -------- + >>> assert dictionary['foo'] == dictionary.foo + + """ + + _PROTECTED_KEYS: FrozenSet[str] = frozenset() + + def __getattr__(self, name): + if name.startswith("__"): + raise AttributeError(f"{type(self)} has no attribute '{name}'") + try: + return self.__getitem__(name) + except KeyError as e: + raise AttributeError(e) + + def __setattr__(self, key, value): + # This logic assumes that __setitem__ will not be called until after + # the object has been fully instantiated. We may want to add a try + # except in the else clause in case someone subclasses these and tries + # to use d['foo'] inside a constructor prior to _data being defined. + # The order of these checks assumes that setting protected keys will be + # much more common than setting dunder attributes. + if key in self._PROTECTED_KEYS or key.startswith("__"): + super().__setattr__(key, value) + else: + self.__setitem__(key, value) + + def __delattr__(self, key): + # The order of these checks assumes that deleting protected keys will be + # much more common than deleting dunder attributes. 
+ if key in self._PROTECTED_KEYS or key.startswith("__"): + super().__delattr__(key) + else: + self.__delitem__(key) diff --git a/signac/synced_collections/data_types/synced_collection.py b/signac/synced_collections/data_types/synced_collection.py new file mode 100644 index 000000000..40a4290fd --- /dev/null +++ b/signac/synced_collections/data_types/synced_collection.py @@ -0,0 +1,511 @@ +# Copyright (c) 2020 The Regents of the University of Michigan +# All rights reserved. +# This software is licensed under the BSD 3-Clause License. +"""Implement the SyncedCollection class.""" +from abc import abstractmethod +from collections import defaultdict +from collections.abc import Collection +from inspect import isabstract +from threading import RLock +from typing import Any, DefaultDict, List + +from ..numpy_utils import _convert_numpy +from ..utils import AbstractTypeResolver, _CounterContext, _NullContext + +# Identifies types of SyncedCollection, which are the base type for this class. +_sc_resolver = AbstractTypeResolver( + { + "SYNCEDCOLLECTION": lambda obj: isinstance(obj, SyncedCollection), + } +) + +_collection_resolver = AbstractTypeResolver( + { + "COLLECTION": lambda obj: isinstance(obj, Collection), + } +) + + +class _LoadAndSave: + """A context manager for :class:`SyncedCollection` to wrap saving and loading. + + Any write operation on a synced collection must be preceded by a load and + followed by a save. Moreover, additional logic may be required to handle + other aspects of the synchronization, particularly the acquisition of thread + locks. This class abstracts this concept, making it easy for subclasses to + customize the behavior if needed (for instance, to introduce additional locks). + """ + + def __init__(self, collection): + self._collection = collection + + def __enter__(self): + self._collection._thread_lock.__enter__() + self._collection._load() + + def __exit__(self, exc_type, exc_val, exc_tb): + try: + self._collection._save() + finally: + self._collection._thread_lock.__exit__(exc_type, exc_val, exc_tb) + + +class SyncedCollection(Collection): + """An abstract :class:`~collections.abc.Collection` type that is synced with a backend. + + This class extends :py:class:`collections.abc.Collection` and adds a number of abstract + internal methods that must be implemented by its subclasses. These methods can be + split into two groups of functions that are designed to be implemented by + separate subtrees in the inheritance hierarchy that can then be composed: + + **Concrete Collection Types** + + These subclasses should implement the APIs for specific types of + collections. For instance, a list-like :class:`SyncedCollection` + should implement the standard methods for sequences. In addition, they + must implement the following abstract methods defined by the + :class:`SyncedCollection`: + + - :meth:`~.is_base_type`: Determines whether an object satisfies the + semantics of the collection object a given :class:`SyncedCollection` + is designed to mimic. + - :meth:`~._to_base`: Converts a :class:`SyncedCollection` to its + natural base type (e.g. a `list`). + - :meth:`~._update`: Updates the :class:`SyncedCollection` to match the + contents of the provided :py:class:`collections.abc.Collection`. + After calling ``sc._update(c)``, we must have that ``sc == c``; however, + since such updates are frequent when loading and saving data to a + resource, :meth:`_update` should be implemented to minimize new object + creation wherever possible. 
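As a rough illustration of this first group of methods (a sketch only: the class below is hypothetical, is not part of this patch, and ignores the validation and numpy handling that the real classes perform), a minimal list-like type might implement these hooks roughly as follows:

from signac.synced_collections.data_types.synced_collection import SyncedCollection


class _SketchList(SyncedCollection):
    # Hypothetical and still abstract (no backend); only the data-type hooks
    # are shown.

    @classmethod
    def is_base_type(cls, data):
        # A list-like synced type mimics plain lists (the real SyncedList
        # accepts any non-string sequence).
        return isinstance(data, list)

    def _to_base(self):
        # Convert back to a plain list, recursing into nested synced objects.
        return [
            value() if isinstance(value, SyncedCollection) else value
            for value in self._data
        ]

    def _update(self, data):
        # Make the in-memory state match ``data``, reusing existing entries
        # where possible rather than rebuilding the whole structure.
        for i, value in enumerate(data):
            if i < len(self._data):
                self._data[i] = value
            else:
                self._data.append(value)
        del self._data[len(data) :]

Because a class like this is still abstract, it only becomes usable once combined with a backend that provides _load_from_resource, _save_to_resource, and _backend, as described next.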
+ + **Backend** + + These subclasses encode the process by which in-memory data is + converted into a representation suitable for a particular backend. For + instance, a JSON backend should know how to save a Python object into a + JSON-encoded file and then read that object back. + + - :meth:`~._load_from_resource`: Loads data from the underlying + resource and returns it in an object satisfying :meth:`~.is_base_type`. + - :meth:`~._save_to_resource`: Stores data to the underlying resource. + - :attr:`~._backend`: A unique string identifier for the resource backend. + + Since these functionalities are effectively completely orthogonal, members of + a given group should be interchangeable. For instance, a dict-like SyncedCollection + can be combined equally easily with JSON, MongoDB, or SQL backends. + + **Validation** + + Due to the restrictions of a particular backend or the needs of a particular + application, synced collections may need to restrict the data that they can + store. Validators provide a standardized mechanism for this. A validator is + a callable that parses any data added to a :class:`SyncedCollection` and + raises an `Exception` if any invalid data is provided. Validators cannot + modify the data and should have no side effects. They are purely provided + as a mechanism to reject invalid data. For example, a JSON validator would + raise Exceptions if it detected non-string keys in a dict. + + Since :class:`SyncedCollection` is designed for extensive usage of + inheritance, validators may be inherited by subclasses. There are two attributes + that subclasses of :class:`SyncedCollection` can define to control the + validators used: + - ``_validators``: A list of callables that will be inherited by all + subclasses. + - ``_all_validators``: A list of callables that will be used to + validate this class, and this class alone. + + When a :class:`SyncedCollection` subclass is initialized (note that this + is at *class* definition time, not when instances are created), its + :meth:`_register_validators` method will be called. If this class defines + an ``_all_validators`` attribute, this set of validators will be used by all + instances of this class. Otherwise, :meth:`_register_validators` will traverse + the MRO and collect the ``_validators`` attributes from all parents of a class, + and store these in the ``_all_validators`` attribute for the class. + + + .. note:: + + Typically, a synced collection will be initialized with resource information, + and data will be pulled from that resource. However, initializing with + both data and resource information is a valid use case. In this case, the + initial data will be validated by the standard validators, however, it + will not be checked against the contents stored in the synced resource and + is assumed to be consistent. This constructor pattern can be useful to + avoid unnecessary resource accesses. + + **Thread safety** + + Whether or not :class:`SyncedCollection` objects are thread-safe depends on the + implementation of the backend. Thread-safety of SyncedCollection objects + is predicated on backends providing an atomic write operation. All concrete + collection types use mutexes to guard against concurrent write operations, + while allowing read operations to happen freely. The validity of this mode + of access depends on the write operations of a SyncedCollection being + atomic, specifically the `:meth:`~._save_to_resource` method. 
Whether or not + a particular subclass of :class:`SyncedCollection` is thread-safe should be + indicated by that subclass setting the ``_supports_threading`` class variable + to ``True``. This variable is set to ``False`` by :class:`SyncedCollection`, + so subclasses must explicitly opt in to threading support by setting this + variable to ``True``. + + Backends that support multithreaded execution will have multithreaded + support turned on by default. This support can be enabled or disabled using + the :meth:`enable_multithreading` and :meth:`disable_multithreading` + methods. :meth:`enable_multithreading` will raise a ``ValueError`` if called + on a class that does not support multithreading. + + Parameters + ---------- + parent : SyncedCollection, optional + If provided, the parent collection within which this collection is + nested; otherwise, the collection owns its own data. Every + :class:`SyncedCollection` either owns its own data or has a parent + (Default value = None). + + """ + + registry: DefaultDict[str, List[Any]] = defaultdict(list) + # Backends that support threading should modify this flag. + _supports_threading: bool = False + _LoadSaveType = _LoadAndSave + + def __init__(self, parent=None, *args, **kwargs): + # Nested collections need to know their root collection, which is + # responsible for all synchronization, and therefore all the associated + # context managers are also stored from the root. + + if parent is not None: + root = parent._root if parent._root is not None else parent + self._root = root + self._suspend_sync = root._suspend_sync + self._load_and_save = root._load_and_save + else: + self._root = None + self._suspend_sync = _CounterContext() + self._load_and_save = self._LoadSaveType(self) + + if self._supports_threading: + self._locks[self._lock_id] = RLock() + + @classmethod + def _register_validators(cls): + """Register all inherited validators to this class. + + This method is called by __init_subclass__ when subclasses are created + to control what validators will be applied to data added to instances of + that class. By default, the ``_all_validators`` class variable defined + on the class itself determines the validation rules for that class. If + that variable is not defined, then all parents of the class are searched, + and a list of validators is constructed by concatenating the ``_validators`` + class variable for each parent class that defines it. + """ + # Must explicitly look in cls.__dict__ so that the attribute is not + # inherited from a parent class. + if "_all_validators" not in cls.__dict__: + validators = [] + # Classes inherit the validators of their parent classes. + for base_cls in cls.__mro__: + if hasattr(base_cls, "_validators"): + validators.extend( + [v for v in base_cls._validators if v not in validators] + ) + cls._all_validators = validators + + @classmethod + def __init_subclass__(cls): + """Register and enable validation in subclasses. + + All subclasses are given a ``_validators`` list so that separate sets of + validators can be registered to different types of synced collections. Concrete + subclasses (those that have all methods implemented, i.e. that are associated + with both a specific backend and a concrete data type) are also recorded in + an internal registry that is used to convert data from some collection-like + object into a :class:`SyncedCollection`. A minimal example of such a concrete + subclass is sketched below.
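For example (a hedged sketch, not part of this patch: the backend below is hypothetical and elides error handling), combining the dict data type with a trivial in-memory backend is enough for the class to be registered and validated automatically when the class body is executed:

from signac.synced_collections.data_types.synced_dict import SyncedDict

_STORE = {}  # stand-in for a real resource such as a file or database


def _require_string_keys(data):
    # A validator: raise on invalid data, never modify it.
    if isinstance(data, dict) and any(not isinstance(key, str) for key in data):
        raise TypeError("All keys must be strings.")


class _MemoryDict(SyncedDict):
    _backend = "memory-example"  # unique identifier for this hypothetical backend
    _validators = (_require_string_keys,)

    def __init__(self, name=None, **kwargs):
        self._name = name
        super().__init__(**kwargs)

    def _load_from_resource(self):
        # Returning None signals that no data exists yet for this name.
        return _STORE.get(self._name)

    def _save_to_resource(self):
        _STORE[self._name] = self._to_base()


doc = _MemoryDict(name="doc")
doc["foo"] = 1  # validated, then synced to _STORE on save
assert _STORE["doc"] == {"foo": 1}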
+ """ + # The Python data model promises that __init_subclass__ will be called + # after the class namespace is fully defined, so at this point we know + # whether we have a concrete subclass or not. + if not isabstract(cls): + SyncedCollection.registry[cls._backend].append(cls) + + cls._register_validators() + + # Monkey-patch subclasses that support locking. + if cls._supports_threading: + cls._locks = {} + cls.enable_multithreading() + else: + cls.disable_multithreading() + + @classmethod + def enable_multithreading(cls): + """Enable safety checks and thread locks required for thread safety. + + Support for multithreaded execution can be disabled by calling + :meth:`~.disable_multithreading`; calling this method reverses that. + + """ + if cls._supports_threading: + + @property + def _thread_lock(self): + """Get the lock specific to this collection. + + Since locks support the context manager protocol, this method + can typically be invoked directly as part of a ``with`` statement. + """ + return type(self)._locks[self._lock_id] + + cls._thread_lock = _thread_lock + cls._threading_support_is_active = True + else: + raise ValueError("This class does not support multithreaded execution.") + + @classmethod + def disable_multithreading(cls): + """Disable all safety checks and thread locks required for thread safety. + + The mutex locks required to enable multithreading introduce nontrivial performance + costs, so they can be disabled for classes that support it. + + """ + cls._thread_lock = _NullContext() + cls._threading_support_is_active = False + + @property + def _lock_id(self): + raise NotImplementedError( + "Backends must implement the _lock_id property to support multithreaded " + "execution. This property should return a hashable unique identifier for " + "all collections that will be used to maintain a resource-specific " + "set of locks." + ) + + @property + @abstractmethod + def _backend(self): + """str: The backend associated with a given collection. + + This property is abstract to enforce that subclasses implement it. + Since it's only internal, subclasses can safely override it with just a + raw attribute; this property just serves as a way to enforce the + abstract API for subclasses. + """ + pass + + @classmethod + def _from_base(cls, data, **kwargs): + r"""Dynamically resolve the type of object to the corresponding synced collection. + + This method assumes that ``data`` has already been validated. This assumption + can always be met, since this method should only be called internally by + other methods that modify the internal collection data. While this requirement + does require that all calling methods be responsible for validation, it + confers significant performance benefits because it can instruct any invoked + class constructors not to validate, which is especially important for nested + collections. + + Parameters + ---------- + data : Collection + Data to be converted from base type. + \*\*kwargs + Any keyword arguments to pass to the collection constructor. + + Returns + ------- + Collection + Synced object of corresponding base type. + + Notes + ----- + This method relies on the internal registry of subclasses populated by + :meth:`~.__init_subclass__` and the :meth:`is_base_type` method to + determine the subclass with the appropriate backend and data type. Once + an appropriate type is determined, that class's constructor is called. 
+ Since this method relies on the constructor and other methods, it can + be concretely implemented here rather than requiring subclass + implementations. + + """ + if _collection_resolver.get_type(data) == "COLLECTION": + for base_cls in SyncedCollection.registry[cls._backend]: + if base_cls.is_base_type(data): + return base_cls(data=data, _validate=False, **kwargs) + return _convert_numpy(data) + + @abstractmethod + def _to_base(self): + """Dynamically resolve the synced collection to the corresponding base type. + + This method should not load the data from the underlying resource; it + should simply convert the current in-memory representation of a synced + collection to its naturally corresponding unsynced collection type. + + Returns + ------- + Collection + An equivalent unsynced collection satisfying :meth:`is_base_type`. + + """ + pass + + @classmethod + @abstractmethod + def is_base_type(cls, data): + """Check whether data is of the same base type (such as list or dict) as this class. + + Parameters + ---------- + data : Any + The input data to test. + + Returns + ------- + bool + Whether or not the object can be converted into this synced collection type. + + """ + pass + + @abstractmethod + def _load_from_resource(self): + """Load data from the underlying backend. + + This method must be implemented for each backend. Backends may choose + to return ``None``, signaling that no modification should be performed + on the data in memory. This mode is useful for backends where the underlying + resource (e.g. a file) may not initially exist, but can be transparently + created on save. + + Returns + ------- + Collection or None + An equivalent unsynced collection satisfying :meth:`is_base_type` that + contains the data in the underlying resource (e.g. a file). + + """ + pass + + @abstractmethod + def _save_to_resource(self): + """Save data to the backend. + + This method must be implemented for each backend. + """ + pass + + def _save(self): + """Save the data to the backend. + + This method encodes the recursive logic required to handle the saving of + nested collections. For a collection contained within another collection, + only the root is ever responsible for storing the data. This method + handles the appropriate recursive calls, then farms out the actual writing + to the abstract method :meth:`~._save_to_resource`. + """ + if not self._suspend_sync: + if self._root is None: + self._save_to_resource() + else: + self._root._save() + + @abstractmethod + def _update(self, data): + """Update the in-memory representation to match the provided data. + + The purpose of this method is to update the SyncedCollection to match + the data in the underlying resource. The result of calling this method + should be that ``self == data``. The reason that this method is + necessary is that SyncedCollections can be nested, and nested + collections must also be instances of SyncedCollection so that + synchronization occurs even when nested structures are modified. + Recreating the full nested structure every time data is reloaded from + file is highly inefficient, so this method performs an in-place update + that only changes entries that need to be changed. + + Parameters + ---------- + data : Collection + A collection satisfying :meth:`is_base_type`. + + """ + pass + + def _load(self): + """Load the data from the backend. + + This method encodes the recursive logic required to handle the loading of + nested collections.
For a collection contained within another collection, + only the root is ever responsible for loading the data. This method + handles the appropriate recursive calls, then farms out the actual reading + to the abstract method :meth:`~._load_from_resource`. + """ + if not self._suspend_sync: + if self._root is None: + data = self._load_from_resource() + with self._suspend_sync: + self._update(data) + else: + self._root._load() + + def _validate(self, data): + """Validate the input data. + + Parameters + ---------- + data : Collection + An collection satisfying :meth:`is_base_type`. + + """ + for validator in self._all_validators: + validator(data) + + # The following methods share a common implementation for + # all data structures and regardless of backend. + + def __getitem__(self, key): + self._load() + return self._data[key] + + def __delitem__(self, key): + with self._load_and_save: + del self._data[key] + + def __iter__(self): + self._load() + return iter(self._data) + + def __len__(self): + self._load() + return len(self._data) + + def __call__(self): + """Get an equivalent but unsynced object of the base data type. + + Returns + ------- + Collection + An equivalent unsynced collection satisfying :meth:`is_base_type`. + + """ + self._load() + return self._to_base() + + def __eq__(self, other): + self._load() + if isinstance(other, type(self)): + return self() == other() + else: + return self() == other + + def __repr__(self): + self._load() + return repr(self._data) + + def __str__(self): + self._load() + return str(self._data) diff --git a/signac/synced_collections/data_types/synced_dict.py b/signac/synced_collections/data_types/synced_dict.py new file mode 100644 index 000000000..f2c8dae39 --- /dev/null +++ b/signac/synced_collections/data_types/synced_dict.py @@ -0,0 +1,272 @@ +# Copyright (c) 2020 The Regents of the University of Michigan +# All rights reserved. +# This software is licensed under the BSD 3-Clause License. +"""Implements the :class:`SyncedDict`. + +This implements a dict-like data structure that also conforms to the +:class:`~.SyncedCollection` API and can be combined with any backend type to +give a dict-like API to a synchronized data structure. +""" + +from collections.abc import Mapping, MutableMapping + +from ..utils import AbstractTypeResolver +from .synced_collection import SyncedCollection, _sc_resolver + +# Identifies mappings, which are the base type for this class. +_mapping_resolver = AbstractTypeResolver( + { + "MAPPING": lambda obj: isinstance(obj, Mapping), + } +) + + +class SyncedDict(SyncedCollection, MutableMapping): + r"""Implement the dict data structure along with values access through attributes named as keys. + + The SyncedDict inherits from :class:`~.SyncedCollection` + and :class:`~collections.abc.MutableMapping`. Therefore, it behaves like a + :class:`dict`. + + Parameters + ---------- + data : Mapping, optional + The initial data to populate the dict. If ``None``, defaults to + ``{}`` (Default value = None). + \*args : + Positional arguments forwarded to parent constructors. + \*\*kwargs : + Keyword arguments forwarded to parent constructors. + + Warning + ------- + While the :class:`SyncedDict` object behaves like a :class:`dict`, + there are important distinctions to remember. In particular, because + operations are reflected as changes to an underlying backend, copying (even + deep copying) a :class:`SyncedDict` instance may exhibit unexpected + behavior. 
If a true copy is required, you should use the `_to_base()` + method to get a :class:`dict` representation, and if necessary construct a + new :class:`SyncedDict`. + """ + + # The _validate parameter is an optimization for internal use only. This + # argument will be passed by _from_base whenever an unsynced collection is + # being recursively converted, ensuring that validation only happens once. + def __init__(self, data=None, _validate=True, *args, **kwargs): + super().__init__(*args, **kwargs) + if data is None: + self._data = {} + else: + if _validate: + self._validate(data) + with self._suspend_sync: + self._data = { + key: self._from_base(data=value, parent=self) + for key, value in data.items() + } + + def _to_base(self): + """Convert the SyncedDict object to a :class:`dict`. + + Returns + ------- + dict + An equivalent raw :class:`dict`. + + """ + converted = {} + for key, value in self._data.items(): + switch_type = _sc_resolver.get_type(value) + if switch_type == "SYNCEDCOLLECTION": + converted[key] = value._to_base() + else: + converted[key] = value + return converted + + @classmethod + def is_base_type(cls, data): + """Check whether the data is an instance of mapping. + + Parameters + ---------- + data : any + Data to be checked. + + Returns + ------- + bool + + """ + return _mapping_resolver.get_type(data) == "MAPPING" + + def _update(self, data=None, _validate=False): + """Update the in-memory representation to match the provided data. + + The purpose of this method is to update the SyncedCollection to match + the data in the underlying resource. The result of calling this method + should be that ``self == data``. The reason that this method is + necessary is that SyncedCollections can be nested, and nested + collections must also be instances of SyncedCollection so that + synchronization occurs even when nested structures are modified. + Recreating the full nested structure every time data is reloaded from + file is highly inefficient, so this method performs an in-place update + that only changes entries that need to be changed. + + Parameters + ---------- + data : collections.abc.Mapping + The data to be assigned to this dict. If ``None``, the data is left + unchanged (Default value = None). + _validate : bool + If True, the data will not be validated (Default value = False). + + """ + if data is None: + # If no data is passed, take no action. + pass + elif _mapping_resolver.get_type(data) == "MAPPING": + with self._suspend_sync: + for key, new_value in data.items(): + try: + # The most common usage of SyncedCollections is with a + # single object referencing an underlying resource at a + # time, so we should almost always find that elements + # of data are already contained in self._data, so EAFP + # is the best choice for performance. + existing = self._data[key] + except KeyError: + # If the item wasn't present at all, we can simply + # assign it. + if not _validate: + self._validate({key: new_value}) + self._data[key] = self._from_base(new_value, parent=self) + else: + if new_value == existing: + continue + if _sc_resolver.get_type(existing) == "SYNCEDCOLLECTION": + try: + existing._update(new_value) + continue + except ValueError: + pass + + # Fall through if the new value is not identical to the + # existing value and + # 1) The existing value is not a SyncedCollection + # (in which case we would have tried to update it), OR + # 2) The existing value is a SyncedCollection, but + # the new value is not a compatible type for _update. 
+ if not _validate: + self._validate({key: new_value}) + self._data[key] = self._from_base(new_value, parent=self) + + to_remove = [key for key in self._data if key not in data] + for key in to_remove: + del self._data[key] + else: + raise ValueError( + "Unsupported type: {}. The data must be a mapping or None.".format( + type(data) + ) + ) + + def __setitem__(self, key, value): + # TODO: Remove in signac 2.0, currently we're constructing a dict to + # allow in-place modification by _convert_key_to_str, but validators + # should not have side effects once that backwards compatibility layer + # is removed, so we can validate a temporary dict {key: value} and + # directly set using those rather than looping over data. + + data = {key: value} + self._validate(data) + with self._load_and_save, self._suspend_sync: + for key, value in data.items(): + self._data[key] = self._from_base(value, parent=self) + + def reset(self, data): + """Update the instance with new data. + + Parameters + ---------- + data : mapping + Data to update the instance. + + Raises + ------ + ValueError + If the data is not a mapping. + + """ + if _mapping_resolver.get_type(data) == "MAPPING": + self._update(data) + with self._thread_lock: + self._save() + else: + raise ValueError( + "Unsupported type: {}. The data must be a mapping or None.".format( + type(data) + ) + ) + + def keys(self): # noqa: D102 + self._load() + return self._data.keys() + + def values(self): # noqa: D102 + self._load() + return self._to_base().values() + + def items(self): # noqa: D102 + self._load() + return self._to_base().items() + + def get(self, key, default=None): # noqa: D102 + self._load() + return self._data.get(key, default) + + def pop(self, key, default=None): # noqa: D102 + with self._load_and_save: + ret = self._data.pop(key, default) + return ret + + def popitem(self): # noqa: D102 + with self._load_and_save: + ret = self._data.popitem() + return ret + + def clear(self): # noqa: D102 + self._data = {} + with self._thread_lock: + self._save() + + def update(self, other=None, **kwargs): # noqa: D102 + if other is not None: + # Convert sequence of key, value pairs to dict before validation + if _mapping_resolver.get_type(other) != "MAPPING": + other = dict(other) + else: + other = {} + + with self._load_and_save: + # The order here is important to ensure that the promised sequence of + # overrides is obeyed: kwargs > other > existing data. + self._update({**self._data, **other, **kwargs}) + + def setdefault(self, key, default=None): # noqa: D102 + with self._load_and_save: + if key in self._data: + ret = self._data[key] + else: + ret = self._from_base(default, parent=self) + # TODO: Remove in signac 2.0, currently we're constructing a dict + # to allow in-place modification by _convert_key_to_str, but + # validators should not have side effects once that backwards + # compatibility layer is removed, so we can validate a temporary + # dict {key: value} and directly set using those rather than + # looping over data. + data = {key: ret} + self._validate(data) + with self._suspend_sync: + for key, value in data.items(): + self._data[key] = value + return ret diff --git a/signac/synced_collections/data_types/synced_list.py b/signac/synced_collections/data_types/synced_list.py new file mode 100644 index 000000000..cb036b8a2 --- /dev/null +++ b/signac/synced_collections/data_types/synced_list.py @@ -0,0 +1,267 @@ +# Copyright (c) 2020 The Regents of the University of Michigan +# All rights reserved. 
+# This software is licensed under the BSD 3-Clause License. +"""Implements the :class:`SyncedList`. + +This implements a list-like data structure that also conforms to the +:class:`~.SyncedCollection` API and can be combined with any backend type to +give a list-like API to a synchronized data structure. +""" + +from collections.abc import MutableSequence, Sequence + +from ..numpy_utils import ( + _convert_numpy, + _is_atleast_1d_numpy_array, + _numpy_cache_blocklist, +) +from ..utils import AbstractTypeResolver +from .synced_collection import SyncedCollection, _sc_resolver + +# Identifies sequences, which are the base type for this class. +_sequence_resolver = AbstractTypeResolver( + { + "SEQUENCE": ( + lambda obj: (isinstance(obj, Sequence) and not isinstance(obj, str)) + or _is_atleast_1d_numpy_array(obj) + ), + }, + cache_blocklist=_numpy_cache_blocklist, +) + + +class SyncedList(SyncedCollection, MutableSequence): + r"""Implementation of list data structure. + + The SyncedList inherits from :class:`~synced_collection.SyncedCollection` + and :class:`~collections.abc.MutableSequence`. Therefore, it behaves similar + to a :class:`list`. + + Parameters + ---------- + data : Sequence, optional + The initial data to populate the list. If ``None``, defaults to + ``[]`` (Default value = None). + \*args : + Positional arguments forwarded to parent constructors. + \*\*kwargs : + Keyword arguments forwarded to parent constructors. + + Warnings + -------- + While the :class:`SyncedList` object behaves like a :class:`list`, there + are important distinctions to remember. In particular, because operations + are reflected as changes to an underlying backend, copying (even deep + copying) a :class:`SyncedList` instance may exhibit unexpected behavior. If + a true copy is required, you should use the `_to_base()` method to get a + :class:`list` representation, and if necessary construct a new + :class:`SyncedList`. + + """ + + # The _validate parameter is an optimization for internal use only. This + # argument will be passed by _from_base whenever an unsynced collection is + # being recursively converted, ensuring that validation only happens once. + def __init__(self, data=None, _validate=True, *args, **kwargs): + super().__init__(*args, **kwargs) + if data is None: + self._data = [] + else: + if _validate: + self._validate(data) + data = _convert_numpy(data) + with self._suspend_sync: + self._data = [ + self._from_base(data=value, parent=self) for value in data + ] + + @classmethod + def is_base_type(cls, data): + """Check whether the data is an non-string Sequence. + + Parameters + ---------- + data : Any + Data to be checked + + Returns + ------- + bool + + """ + return _sequence_resolver.get_type(data) == "SEQUENCE" + + def _to_base(self): + """Convert the SyncedList object to a :class:`list`. + + Returns + ------- + list + An equivalent raw :class:`list`. + + """ + converted = [] + for value in self._data: + switch_type = _sc_resolver.get_type(value) + if switch_type == "SYNCEDCOLLECTION": + converted.append(value._to_base()) + else: + converted.append(value) + return converted + + def _update(self, data=None, _validate=False): + """Update the in-memory representation to match the provided data. + + The purpose of this method is to update the SyncedCollection to match + the data in the underlying resource. The result of calling this method + should be that ``self == data``. 
The reason that this method is + necessary is that SyncedCollections can be nested, and nested + collections must also be instances of SyncedCollection so that + synchronization occurs even when nested structures are modified. + Recreating the full nested structure every time data is reloaded from + file is highly inefficient, so this method performs an in-place update + that only changes entries that need to be changed. + + Parameters + ---------- + data : collections.abc.Sequence + The data to be assigned to this list. If ``None``, the data is left + unchanged (Default value = None). + _validate : bool + If True, the data will not be validated (Default value = False). + + """ + if data is None: + # If no data is passed, take no action. + pass + elif _sequence_resolver.get_type(data) == "SEQUENCE": + with self._suspend_sync: + # This loop is optimized based on common usage patterns: + # insertion and removal at the end of a list. Inserting or + # removing in the middle will result in extra conversion + # operations for all subsequent items. In the worst case, + # inserting at the beginning will require reconverting all + # elements of the data. + for i in range(min(len(self), len(data))): + if data[i] == self._data[i]: + continue + if _sc_resolver.get_type(self._data[i]) == "SYNCEDCOLLECTION": + try: + self._data[i]._update(data[i]) + continue + except ValueError: + pass + if not _validate: + self._validate(data[i]) + self._data[i] = self._from_base(data[i], parent=self) + + if len(self._data) > len(data): + self._data = self._data[: len(data)] + else: + new_data = data[len(self) :] + if not _validate: + self._validate(new_data) + self.extend(new_data) + else: + raise ValueError( + "Unsupported type: {}. The data must be a non-string sequence or None.".format( + type(data) + ) + ) + + def reset(self, data): + """Update the instance with new data. + + Parameters + ---------- + data : non-string Sequence + Data to update the instance. + + Raises + ------ + ValueError + If the data is not a non-string sequence. + + """ + data = _convert_numpy(data) + if _sequence_resolver.get_type(data) == "SEQUENCE": + self._update(data) + with self._thread_lock: + self._save() + else: + raise ValueError( + "Unsupported type: {}. The data must be a non-string sequence or None.".format( + type(data) + ) + ) + + def __setitem__(self, key, value): + self._validate(value) + with self._load_and_save, self._suspend_sync: + self._data[key] = self._from_base(data=value, parent=self) + + def __reversed__(self): + self._load() + return reversed(self._data) + + def __iadd__(self, iterable): + # Convert input to a list so that iterators work as well as iterables. 
+ iterable_data = list(iterable) + self._validate(iterable_data) + with self._load_and_save, self._suspend_sync: + self._data += [ + self._from_base(data=value, parent=self) for value in iterable_data + ] + return self + + def insert(self, index, item): # noqa: D102 + self._validate(item) + with self._load_and_save, self._suspend_sync: + self._data.insert(index, self._from_base(data=item, parent=self)) + + def append(self, item): # noqa: D102 + self._validate(item) + with self._load_and_save, self._suspend_sync: + self._data.append(self._from_base(data=item, parent=self)) + + def extend(self, iterable): # noqa: D102 + # Convert iterable to a list to ensure generators are exhausted only once + iterable_data = list(iterable) + self._validate(iterable_data) + with self._load_and_save, self._suspend_sync: + self._data.extend( + [self._from_base(data=value, parent=self) for value in iterable_data] + ) + + def remove(self, value): # noqa: D102 + with self._load_and_save, self._suspend_sync: + self._data.remove(self._from_base(data=value, parent=self)) + + def clear(self): # noqa: D102 + self._data = [] + with self._thread_lock: + self._save() + + def __lt__(self, other): + if isinstance(other, type(self)): + return self() < other() + else: + return self() < other + + def __le__(self, other): + if isinstance(other, type(self)): + return self() <= other() + else: + return self() <= other + + def __gt__(self, other): + if isinstance(other, type(self)): + return self() > other() + else: + return self() > other + + def __ge__(self, other): + if isinstance(other, type(self)): + return self() >= other() + else: + return self() >= other diff --git a/signac/synced_collections/errors.py b/signac/synced_collections/errors.py new file mode 100644 index 000000000..005e56597 --- /dev/null +++ b/signac/synced_collections/errors.py @@ -0,0 +1,49 @@ +# Copyright (c) 2017 The Regents of the University of Michigan +# All rights reserved. +# This software is licensed under the BSD 3-Clause License. +"""Errors raised by synced collections.""" + + +class BufferException(RuntimeError): + """Raised when any exception related to buffering occurs.""" + + +class BufferedError(BufferException): + """Raised when an error occurred while flushing one or more buffered files. + + Attributes + ---------- + files : dict + A dictionary of the names that caused issues during the flush operation, + mapped to a possible reason for the issue, or None if the reason + cannot be determined. + """ + + def __init__(self, files): + self.files = files + + def __str__(self): + return "{}({})".format(type(self).__name__, self.files) + + +class MetadataError(BufferException): + """Raised when metadata check fails. + + The buffered contents of the file can be accessed via the + ``buffer_contents`` attribute of this exception. + """ + + def __init__(self, filename, contents): + self.filename = filename + self.buffer_contents = contents + + def __str__(self): + return f"{self.filename} appears to have been externally modified." + + +class KeyTypeError(TypeError): + """Raised when a user uses a key of invalid type.""" + + +class InvalidKeyError(ValueError): + """Raised when a user uses a non-conforming key.""" diff --git a/signac/synced_collections/numpy_utils.py b/signac/synced_collections/numpy_utils.py new file mode 100644 index 000000000..bb346894d --- /dev/null +++ b/signac/synced_collections/numpy_utils.py @@ -0,0 +1,94 @@ +# Copyright (c) 2020 The Regents of the University of Michigan +# All rights reserved.
+# This software is licensed under the BSD 3-Clause License. +"""Define utilities for handling NumPy arrays.""" + +import warnings + +try: + import numpy + + NUMPY = True + _numpy_cache_blocklist = (numpy.ndarray,) +except ImportError: + NUMPY = False + _numpy_cache_blocklist = None # type: ignore + + +class NumpyConversionWarning(UserWarning): + """Warning raised when NumPy data is converted.""" + + +NUMPY_CONVERSION_WARNING = ( + "Any numpy types provided will be transparently converted to the " + "closest base Python equivalents." +) + + +def _convert_numpy(data): + """Convert numpy data types to the corresponding base data types. + + 0d numpy arrays and numpy scalars are converted to their corresponding + primitive types, while other numpy arrays are converted to lists. If ``data`` + is not a numpy data type, this function is a no-op. + """ + # Initially performing one isinstance check is faster since most of the + # time the inputs are not numpy arrays, preventing additional function + # calls inside the if statement. + if NUMPY and isinstance(data, (numpy.ndarray, numpy.number, numpy.bool_)): + if isinstance(data, numpy.ndarray): + # tolist will return a scalar for 0d arrays, so there's no need to + # special-case that check. 1-element 1d arrays should remain + # arrays, i.e. np.array([1]) should become [1], not 1. + warnings.warn(NUMPY_CONVERSION_WARNING, NumpyConversionWarning) + return data.tolist() + else: + warnings.warn(NUMPY_CONVERSION_WARNING, NumpyConversionWarning) + return data.item() + return data + + +def _is_atleast_1d_numpy_array(data): + """Check if an object is a nonscalar numpy array. + + The synced collections framework must differentiate 0d numpy arrays from + other arrays because they are mapped to scalars while >0d arrays are mapped + to (synced) lists. + + Returns + ------- + bool + Whether or not the input is a numpy array with at least 1 dimension. + """ + return NUMPY and isinstance(data, numpy.ndarray) and data.ndim > 0 + + +def _is_numpy_scalar(data): + """Check if an object is a numpy scalar. + + This function is designed for use in situations where _convert_numpy has + already been applied (if necessary), so 0d arrays are not considered + scalars here. + + Returns + ------- + bool + Whether or not the input is a numpy scalar type. + """ + return NUMPY and ( + (isinstance(data, (numpy.number, numpy.bool_))) + or (isinstance(data, numpy.ndarray) and data.ndim == 0) + ) + + +def _is_complex(data): + """Check if an object is complex. + + This function works for both numpy raw Python data types. + + Returns + ------- + bool + Whether or not the input is a complex number. + """ + return (NUMPY and numpy.iscomplex(data).any()) or (isinstance(data, complex)) diff --git a/signac/synced_collections/utils.py b/signac/synced_collections/utils.py new file mode 100644 index 000000000..0610bef20 --- /dev/null +++ b/signac/synced_collections/utils.py @@ -0,0 +1,238 @@ +# Copyright (c) 2020 The Regents of the University of Michigan +# All rights reserved. +# This software is licensed under the BSD 3-Clause License. +"""Define common utilities.""" + +from json import JSONEncoder +from typing import Any, Dict + +from .numpy_utils import _convert_numpy, _is_numpy_scalar + + +class AbstractTypeResolver: + """Mapping between recognized types and their abstract parents. + + Synced collections are heavily reliant on checking the types of objects to + determine the appropriate type of behavior in various scenarios. 
For maximum + generality, most of these checks use the ABCs defined in :py:mod:`collections.abc`. + The price of this flexibility is that `isinstance` checks with these classes + are very slow because the ``__instancecheck__`` hooks are implemented in pure + Python and require checking many different cases. + + Rather than attempting to directly optimize this behavior, this class provides + a workaround by which we can amortize the cost of type checks. Given a set + of types that must be resolved and a way to identify each of these (which + may be expensive), it maintains a local cache of all instances of a given + type that have previously been observed. This reduces the cost of type checking + to a simple ``dict`` lookup, except for the first time a new type is observed. + + Parameters + ---------- + abstract_type_identifiers : Mapping + A mapping from a string identifier for a group of types (e.g. ``"MAPPING"``) + to a callable that can be used to identify that type. Due to insertion order + guarantees of dictionaries in Python>=3.6 (officially 3.7), it may be beneficial + to order this dictionary with the most frequently occuring types first. + However, unless users have many different concrete types implementing + the same abstract interface (e.g. many Mapping types identified via + ``isinstance(obj, Mapping)``), any performance gain should be negligible + since the callables will only be executed once per type. + cache_blocklist : Sequence, optional + A sequence of string identifiers from ``abstract_type_identifiers`` that + should not be cached. If there are cases where objects of the same type + would be classified into separate groups based on the callables in + ``abstract_type_identifiers``, this argument allows users to specify that + this type should not be cached. This argument should be used sparingly + because performance will quickly degrade if many calls to + :meth:`get_type` are with types that cannot be cached. The identifiers + (keys in ``abstract_type_identifiers``) corresponding to elements of the + blocklist should be placed first in the ``abstract_type_identifiers`` + dictionary since they will never be cached and are therefore the most + likely callables to be used repeatedly (Default value = None). + + Attributes + ---------- + abstract_type_identifiers : Dict[str, Callable[Any, bool]] + A mapping from string identifiers for an abstract type to callables that + accepts an object and returns True if the object is of the key type and + False if not. + type_map : Dict[Type, str] + A mapping from concrete types to the corresponding named abstract type + from :attr:`~.abstract_type_identifiers`. + + """ + + def __init__(self, abstract_type_identifiers, cache_blocklist=None): + self.abstract_type_identifiers = abstract_type_identifiers + self.type_map = {} + self.cache_blocklist = cache_blocklist if cache_blocklist is not None else () + + def get_type(self, obj): + """Get the type string corresponding to this data type. + + Parameters + ---------- + obj : Any + Any object whose type to check + + Returns + ------- + str + The name of the type, where valid types are the keys of the dict + argument to the constructor. If the object's type cannot be identified, + will return ``None``. 
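A minimal usage sketch of the resolver described above (the group names and callables here are illustrative; they mirror, but are not copied from, the resolvers defined later in validators.py):

from collections.abc import Mapping, Sequence

from signac.synced_collections.utils import AbstractTypeResolver

# Illustrative resolver: classify objects as mappings or non-string sequences.
resolver = AbstractTypeResolver(
    {
        "MAPPING": lambda obj: isinstance(obj, Mapping),
        "SEQUENCE": lambda obj: isinstance(obj, Sequence) and not isinstance(obj, str),
    }
)

assert resolver.get_type({"a": 1}) == "MAPPING"
assert resolver.get_type([1, 2, 3]) == "SEQUENCE"
assert resolver.get_type("a string") is None  # not one of the registered groups
# Later lookups for dicts hit the per-type cache instead of calling the lambdas.
assert dict in resolver.type_map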
+ + """ + obj_type = type(obj) + enum_type = None + try: + enum_type = self.type_map[obj_type] + except KeyError: + for data_type, id_func in self.abstract_type_identifiers.items(): + if id_func(obj): + enum_type = data_type + break + if obj_type not in self.cache_blocklist: + self.type_map[obj_type] = enum_type + + return enum_type + + +def default(o: Any) -> Dict[str, Any]: # noqa: D102 + """Get a JSON-serializable version of compatible types. + + This function is suitable for use with JSON-serialization tools as a way to + serialize :class:`~.SyncedCollection` objects and NumPy arrays. It will + attempt to obtain a JSON-serializable representation of an object that is + otherwise not serializable by attempting to access its ``_data`` attribute. + + + Warnings + -------- + - JSON encoding of numpy arrays is not invertible; once encoded, reloading + the data will result in converting arrays to lists and numpy numbers into + ints or floats. + - This function assumes that the in-memory data for a SyncedCollection is + up-to-date. If the data has been changed on disk without updating the + collection, or if this function is used to serialize the data before any + method is invoked that would load the data from disk, the resulting + serialized data may be incorrect. + + """ + # NumPy converters return the data unchanged. + converted_o = _convert_numpy(o) + + # NumPy arrays will be converted to lists, then recursively parsed by the + # JSON encoder, so we only have to handle the case where we have a scalar + # type at the bottom level that can't be converted to a Python scalar. + if _is_numpy_scalar(converted_o): + raise ValueError( + "In order for a NumPy type to be JSON-encoded, it must have a corresponding " + "Python type. All other types, such as NumPy extended-precision types, must " + "be converted by the user. Note that the existence of a corresponding type " + "is a necessary but not sufficient condition because not all Python types " + "can be JSON-encoded. For instance, complex numpy values can be converted " + "to Python complex values, but these still cannot be JSON-encoded." + ) + + if converted_o is o: + try: + return o._data + except AttributeError as e: + raise TypeError from e + else: + return converted_o + + +class SyncedCollectionJSONEncoder(JSONEncoder): + """A :class:`json.JSONEncoder` that handles objects encodeable using :func:`~.default`. + + Warnings + -------- + - JSON encoding of numpy arrays is not invertible; once encoded, reloading + the data will result in converting arrays to lists and numpy numbers into + ints or floats. + - This class assumes that the in-memory data for a SyncedCollection is + up-to-date. If the data has been changed on disk without updating the + collection, or if this class is used to serialize the data before any + method of the collection is invoked that would load the data from disk, + the resulting serialized data may be incorrect. + + """ + + def default(self, o: Any) -> Dict[str, Any]: # noqa: D102 + try: + return default(o) + except TypeError: + # Call the super method, which raises a TypeError if it cannot + # encode the object. + return super().default(o) + + +class _NullContext: + """A nullary context manager. + + There are various cases where we sometimes want to perform a task within a + particular context, but at other times we wish to ignore that context. 
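As a sketch of how default() and SyncedCollectionJSONEncoder above are intended to be used (this assumes the JSONDict backend added elsewhere in this patch; the file name is arbitrary and is created in the working directory):

import json

from signac.synced_collections.backends.collection_json import JSONDict
from signac.synced_collections.utils import SyncedCollectionJSONEncoder

synced = JSONDict(filename="data.json")  # hypothetical file
synced["a"] = [1, 2, 3]

# The encoder unwraps SyncedCollection instances (via their _data attribute),
# so nested synced dicts and lists can be passed to json.dumps directly.
serialized = json.dumps(synced, cls=SyncedCollectionJSONEncoder)
assert json.loads(serialized) == {"a": [1, 2, 3]}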
The + most obvious example is a lock for threading: since + :class:`~.SyncedCollection` allows multithreading support to be enabled or + disabled, it is important to be able to write code that is agnostic to + whether or not a mutex must be acquired prior to executing a task. Locks + support the context manager protocol and are used in that manner throughout + the code base, so the most transparent way to disable buffering is to + create a nullary context manager that can be placed as a drop-in + replacement for the lock so that all other code can handle this in a + transparent manner. + """ + + def __enter__(self): + pass + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + def __call__(self): + """Allow usage of the context in a function-like manner.""" + return self + + +class _CounterContext: + """A context manager that maintains a total entry count. + + This class simply contains an internal counter that is incremented on every + entrance and decremented on every exit. It is also truthy and only evaluates + to True if the count is greater than 0. + """ + + def __init__(self): + self._count = 0 + + def __enter__(self): + self._count += 1 + + def __exit__(self, exc_type, exc_val, exc_tb): + self._count -= 1 + + def __bool__(self): + return self._count > 0 + + def __call__(self): + """Allow usage of the context in a function-like manner.""" + return self + + +class _CounterFuncContext(_CounterContext): + """A counter that performs some operation whenever the counter hits zero. + + This class maintains a counter, and also accepts an arbitrary nullary + callable to be executed anytime the context exits and the counter hits zero. + """ + + def __init__(self, func): + super().__init__() + self._func = func + + def __exit__(self, exc_type, exc_val, exc_tb): + super().__exit__(exc_type, exc_val, exc_tb) + if not self: + self._func() diff --git a/signac/synced_collections/validators.py b/signac/synced_collections/validators.py new file mode 100644 index 000000000..910c24740 --- /dev/null +++ b/signac/synced_collections/validators.py @@ -0,0 +1,150 @@ +# Copyright (c) 2020 The Regents of the University of Michigan +# All rights reserved. +# This software is licensed under the BSD 3-Clause License. +"""Validators for SyncedCollection API. + +A validator is any callable that raises Exceptions when called with invalid data. +Validators should act recursively for nested data structures and should not +return any values, only raise errors. This module implements built-in validators, +but client code is free to implement and add additioal validators to collection +types as needed. +""" + +from collections.abc import Mapping, Sequence + +from .errors import InvalidKeyError, KeyTypeError +from .numpy_utils import ( + _is_atleast_1d_numpy_array, + _is_complex, + _is_numpy_scalar, + _numpy_cache_blocklist, +) +from .utils import AbstractTypeResolver + +_no_dot_in_key_type_resolver = AbstractTypeResolver( + { + "MAPPING": lambda obj: isinstance(obj, Mapping), + "SEQUENCE": lambda obj: isinstance(obj, Sequence) and not isinstance(obj, str), + } +) + + +def no_dot_in_key(data): + """Raise an exception if there is a dot (``.``) in a mapping's key. + + Parameters + ---------- + data + Data to validate. + + Raises + ------ + KeyTypeError + If key data type is not supported. + InvalidKeyError + If the key contains invalid characters or is otherwise malformed. 
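The context-manager helpers defined in utils.py above are easiest to understand in combination; below is a purely illustrative sketch (the TinyBufferedResource class and its names are made up, and the private helpers are used directly) of the pattern they support inside the synced collection classes:

from threading import RLock

from signac.synced_collections.utils import _CounterFuncContext, _NullContext


class TinyBufferedResource:
    def __init__(self, multithreaded=True):
        # Real lock when threading support is on, no-op stand-in otherwise.
        self._lock = RLock() if multithreaded else _NullContext()
        # Flush whenever the outermost buffering context exits.
        self.buffered = _CounterFuncContext(self._flush)
        self.flush_count = 0

    def _flush(self):
        self.flush_count += 1

    def write(self, message):
        with self._lock:  # works identically for an RLock and a _NullContext
            pass


resource = TinyBufferedResource(multithreaded=False)
with resource.buffered:
    with resource.buffered:
        resource.write("nested writes are buffered")
    assert resource.flush_count == 0  # inner exit: counter is still above zero
assert resource.flush_count == 1      # outer exit: counter hit zero, flush ran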
+ + """ + VALID_KEY_TYPES = (str, int, bool, type(None)) + + switch_type = _no_dot_in_key_type_resolver.get_type(data) + + if switch_type == "MAPPING": + for key, value in data.items(): + if isinstance(key, str): + if "." in key: + raise InvalidKeyError( + f"Mapping keys may not contain dots ('.'): {key}" + ) + # TODO: Make it an error to have a non-str key here in signac 2.0. + elif not isinstance(key, VALID_KEY_TYPES): + raise KeyTypeError( + f"Mapping keys must be str, int, bool or None, not {type(key).__name__}" + ) + no_dot_in_key(value) + elif switch_type == "SEQUENCE": + for value in data: + no_dot_in_key(value) + + +def require_string_key(data): + """Raise an exception if key in a mapping is not a string. + + Almost all supported backends require string keys. + + Parameters + ---------- + data + Data to validate. + + Raises + ------ + KeyTypeError + If key type is not a string. + + """ + # Reuse the type resolver here since it has the same groupings. + switch_type = _no_dot_in_key_type_resolver.get_type(data) + + if switch_type == "MAPPING": + for key, value in data.items(): + if not isinstance(key, str): + raise KeyTypeError( + f"Mapping keys must be str, not {type(key).__name__}" + ) + require_string_key(value) + elif switch_type == "NON_STR_SEQUENCE": + for value in data: + require_string_key(value) + + +_json_format_validator_type_resolver = AbstractTypeResolver( + { + # We identify >0d numpy arrays as sequences for validation purposes. + "SEQUENCE": lambda obj: (isinstance(obj, Sequence) and not isinstance(obj, str)) + or _is_atleast_1d_numpy_array(obj), + "NUMPY": lambda obj: _is_numpy_scalar(obj), + "BASE": lambda obj: isinstance(obj, (str, int, float, bool, type(None))), + "MAPPING": lambda obj: isinstance(obj, Mapping), + }, + cache_blocklist=_numpy_cache_blocklist, +) + + +def json_format_validator(data): + """Validate input data can be serialized to JSON. + + Parameters + ---------- + data + Data to validate. + + Raises + ------ + KeyTypeError + If key data type is not supported. + TypeError + If the data type of ``data`` is not supported. 
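A short sketch of how these validators behave when called directly (arbitrary example values; pytest is used here only to demonstrate the raised exception types):

import pytest

from signac.synced_collections.errors import InvalidKeyError, KeyTypeError
from signac.synced_collections.validators import json_format_validator, no_dot_in_key

# Valid data passes silently; validators return nothing.
no_dot_in_key({"a": {"b": [1, 2, {"c": None}]}})
json_format_validator({"a": [1.5, "text", True, None]})

# A dotted key anywhere in the nested structure is rejected.
with pytest.raises(InvalidKeyError):
    no_dot_in_key({"a": {"b.c": 1}})

# Keys of unsupported types are rejected as well.
with pytest.raises(KeyTypeError):
    no_dot_in_key({1.5: "float keys are not allowed"})

# Arbitrary objects cannot be JSON-encoded.
with pytest.raises(TypeError):
    json_format_validator({"a": object()})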
+ + """ + switch_type = _json_format_validator_type_resolver.get_type(data) + + if switch_type == "BASE": + return + elif switch_type == "MAPPING": + for key, value in data.items(): + if not isinstance(key, str): + raise KeyTypeError(f"Keys must be str, not {type(key).__name__}") + json_format_validator(value) + elif switch_type == "SEQUENCE": + for value in data: + json_format_validator(value) + elif switch_type == "NUMPY": + if _is_numpy_scalar(data.item()): + raise TypeError("NumPy extended precision types are not JSON serializable.") + elif _is_complex(data): + raise TypeError("Complex numbers are not JSON serializable.") + else: + raise TypeError( + f"Object of type {type(data).__name__} is not JSON serializable" + ) diff --git a/tests/conftest.py b/tests/conftest.py index d0deffbac..b370f15f3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ +import uuid from contextlib import contextmanager import pytest @@ -13,3 +14,8 @@ def deprecated_in_version(version_string): yield else: yield + + +@pytest.fixture +def testdata(): + return str(uuid.uuid4()) diff --git a/tests/test_buffered_mode.py b/tests/test_buffered_mode.py index 7a7d5f586..257a20708 100644 --- a/tests/test_buffered_mode.py +++ b/tests/test_buffered_mode.py @@ -13,7 +13,8 @@ from test_project import TestProjectBase import signac -from signac.errors import BufferedFileError, BufferException, Error +from signac.errors import BufferedFileError, Error +from signac.synced_collections.errors import BufferedError PYPY = "PyPy" in platform.python_implementation() @@ -72,6 +73,10 @@ def test_basic_and_nested(self): assert job.doc.a == 2 assert job.doc.a == 2 + # Remove this test in signac 2.0. + @pytest.mark.xfail( + reason="The new SyncedCollection does not implement force_write." + ) def test_buffered_mode_force_write(self): with signac.buffered(force_write=False): with signac.buffered(force_write=False): @@ -88,6 +93,10 @@ def test_buffered_mode_force_write(self): pass assert not signac.is_buffered() + # Remove this test in signac 2.0. + @pytest.mark.xfail( + reason="The new SyncedCollection does not implement force_write." + ) def test_buffered_mode_force_write_with_file_modification(self): job = self.project.open_job(dict(a=0)) job.init() @@ -114,8 +123,9 @@ def test_buffered_mode_force_write_with_file_modification(self): file.write(json.dumps({"a": x}).encode()) assert job.doc.a == (not x) - @pytest.mark.skipif( - not ABLE_TO_PREVENT_WRITE, reason="unable to trigger permission error" + # Remove this test in signac 2.0. + @pytest.mark.xfail( + reason="The new SyncedCollection does not implement force_write." 
) def test_force_write_mode_with_permission_error(self): job = self.project.open_job(dict(a=0)) @@ -145,25 +155,16 @@ def test_buffered_mode_change_buffer_size(self): assert signac.get_buffer_size() == 12 assert not signac.is_buffered() - with pytest.raises(TypeError): - with signac.buffered(buffer_size=True): - pass assert not signac.is_buffered() with signac.buffered(buffer_size=12): assert signac.buffered() assert signac.get_buffer_size() == 12 - with signac.buffered(buffer_size=12): + with signac.buffered(): assert signac.buffered() assert signac.get_buffer_size() == 12 assert not signac.is_buffered() - with pytest.raises(BufferException): - with signac.buffered(buffer_size=12): - assert signac.buffered() - assert signac.get_buffer_size() == 12 - with signac.buffered(buffer_size=14): - pass def test_integration(self): def routine(): @@ -240,7 +241,7 @@ def routine(): assert job2.doc.a == (not x) assert job.doc.a == (not x) - with pytest.raises(BufferedFileError) as cm: + with pytest.raises(BufferedError) as cm: with signac.buffered(): assert job.doc.a == (not x) job.doc.a = x diff --git a/tests/test_find_command_line_interface.py b/tests/test_find_command_line_interface.py index e47120cf6..8207f5fb0 100644 --- a/tests/test_find_command_line_interface.py +++ b/tests/test_find_command_line_interface.py @@ -7,11 +7,11 @@ from contextlib import contextmanager from io import StringIO from itertools import chain +from json import JSONDecodeError import pytest from signac.contrib.filterparse import parse_filter_arg, parse_simple -from signac.core.json import JSONDecodeError FILTERS = [ {"a": 0}, diff --git a/tests/test_job.py b/tests/test_job.py index f083f28e3..d1d6a383c 100644 --- a/tests/test_job.py +++ b/tests/test_job.py @@ -16,12 +16,12 @@ import signac.common.config import signac.contrib -from signac import Project # noqa: F401 -from signac.contrib.job import Job # noqa: F401 +from signac.contrib.errors import JobsCorruptedError +from signac.contrib.job import Job from signac.errors import DestinationExistsError, InvalidKeyError, KeyTypeError try: - import h5py # noqa + import h5py # noqa: F401 H5PY = True except ImportError: @@ -446,11 +446,14 @@ class A: assert str(key) in job.sp def test_invalid_sp_key_types(self): - job = self.open_job(dict(invalid_key=True)).init() - class A: pass + with pytest.raises(KeyTypeError): + self.open_job({A(): True}).init() + + job = self.open_job(dict(invalid_key=True)).init() + for key in (0.0, A(), (1, 2, 3)): with pytest.raises(KeyTypeError): job.sp[key] = "test" @@ -537,6 +540,9 @@ def test_chained_init(self): assert os.path.exists(os.path.join(job.workspace(), job.FN_MANIFEST)) def test_construction(self): + from signac import Project # noqa: F401 + + # The eval statement needs to have Project available job = self.open_job(test_token) job2 = eval(repr(job)) assert job == job2 @@ -627,8 +633,8 @@ def test_corrupt_workspace(self): job2 = self.open_job(test_token) try: logging.disable(logging.ERROR) - # Detects the corrupted manifest and overwrites with valid data - job2.init() + with pytest.raises(JobsCorruptedError): + job2.init() finally: logging.disable(logging.NOTSET) job2.init(force=True) @@ -1010,6 +1016,40 @@ def test_reset_statepoint_job(self): with pytest.raises(DestinationExistsError): src_job.reset_statepoint(dst) + @pytest.mark.skipif(not H5PY, reason="test requires the h5py package") + def test_reset_statepoint_job_lazy_access(self): + key = "move_job" + d = testdata() + src = test_token + dst = dict(test_token) + dst["dst"] = True + 
src_job = self.open_job(src) + src_job.document[key] = d + assert key in src_job.document + assert len(src_job.document) == 1 + src_job.data[key] = d + assert key in src_job.data + assert len(src_job.data) == 1 + # Clear the project's state point cache to force lazy load + self.project._sp_cache.clear() + src_job_by_id = self.open_job(id=src_job.id) + # Check that the state point will be instantiated lazily during the + # call to reset_statepoint + assert src_job_by_id._statepoint_requires_init + src_job_by_id.reset_statepoint(dst) + src_job = self.open_job(src) + dst_job = self.open_job(dst) + assert key in dst_job.document + assert len(dst_job.document) == 1 + assert key not in src_job.document + assert key in dst_job.data + assert len(dst_job.data) == 1 + assert key not in src_job.data + with pytest.raises(RuntimeError): + src_job.reset_statepoint(dst) + with pytest.raises(DestinationExistsError): + src_job.reset_statepoint(dst) + @pytest.mark.skipif(not H5PY, reason="test requires the h5py package") def test_reset_statepoint_project(self): key = "move_job" diff --git a/tests/test_numpy_integration.py b/tests/test_numpy_integration.py index 5e564d70a..65b14da4a 100644 --- a/tests/test_numpy_integration.py +++ b/tests/test_numpy_integration.py @@ -4,6 +4,8 @@ import pytest from test_project import TestProjectBase +from signac.synced_collections.numpy_utils import NumpyConversionWarning + try: import numpy # noqa import numpy.testing @@ -19,8 +21,10 @@ def test_store_number_in_sp_and_doc(self): for i in range(10): a = numpy.float32(i) if i % 2 else numpy.float64(i) b = numpy.float64(i) if i % 2 else numpy.float32(i) - job = self.project.open_job(dict(a=a)) - job.doc.b = b + with pytest.warns(NumpyConversionWarning): + job = self.project.open_job(dict(a=a)) + with pytest.warns(NumpyConversionWarning): + job.doc.b = b numpy.testing.assert_equal(job.doc.b, b) for i, job in enumerate(sorted(self.project, key=lambda job: job.sp.a)): assert job.sp.a == i @@ -28,7 +32,8 @@ def test_store_number_in_sp_and_doc(self): def test_store_array_in_sp(self): for i in range(10): - self.project.open_job(dict(a=numpy.array([i]))).init() + with pytest.warns(NumpyConversionWarning): + self.project.open_job(dict(a=numpy.array([i]))).init() for i, job in enumerate(sorted(self.project, key=lambda job: job.sp.a)): assert [i] == job.sp.a assert numpy.array([i]) == job.sp.a @@ -36,7 +41,8 @@ def test_store_array_in_sp(self): def test_store_array_in_doc(self): for i in range(10): job = self.project.open_job(dict(a=i)) - job.doc.array = numpy.ones(3) * i + with pytest.warns(NumpyConversionWarning): + job.doc.array = numpy.ones(3) * i numpy.testing.assert_equal(job.doc.array, numpy.ones(3) * i) for i, job in enumerate(sorted(self.project, key=lambda job: job.sp.a)): assert i == job.sp.a @@ -47,7 +53,8 @@ def test_store_zero_dim_array_in_sp(self): # Zero-dimensional arrays have size 1, and their tolist() method # returns a single value. value = 1.0 - job = self.project.open_job(dict(a=numpy.array(value))).init() + with pytest.warns(NumpyConversionWarning): + job = self.project.open_job(dict(a=numpy.array(value))).init() assert value == job.sp.a assert numpy.array(value) == job.sp.a @@ -56,7 +63,8 @@ def test_store_zero_dim_array_in_doc(self): # returns a single value. 
value = 1.0 job = self.project.open_job(dict(a=1)).init() - job.doc.array = numpy.array(value) + with pytest.warns(NumpyConversionWarning): + job.doc.array = numpy.array(value) numpy.testing.assert_equal(job.doc.array, numpy.array(value)) assert value == job.doc.array assert numpy.array(value) == job.doc.array diff --git a/tests/test_synced_collections/attr_dict_test.py b/tests/test_synced_collections/attr_dict_test.py new file mode 100644 index 000000000..54f19d353 --- /dev/null +++ b/tests/test_synced_collections/attr_dict_test.py @@ -0,0 +1,74 @@ +# Copyright (c) 2020 The Regents of the University of Michigan +# All rights reserved. +# This software is licensed under the BSD 3-Clause License. +"""Tests to be used for dictionaries supporting attr-based access.""" + +import pytest + +from signac.errors import InvalidKeyError, KeyTypeError + + +class AttrDictTest: + def test_attr_dict(self, synced_collection, testdata): + key = "test" + synced_collection[key] = testdata + assert len(synced_collection) == 1 + assert key in synced_collection + assert synced_collection[key] == testdata + assert synced_collection.get(key) == testdata + assert synced_collection.test == testdata + del synced_collection.test + assert len(synced_collection) == 0 + assert key not in synced_collection + key = "test2" + synced_collection.test2 = testdata + assert len(synced_collection) == 1 + assert key in synced_collection + assert synced_collection[key] == testdata + assert synced_collection.get(key) == testdata + assert synced_collection.test2 == testdata + with pytest.raises(AttributeError): + synced_collection.not_exist + + # deleting a protected attribute + synced_collection._load() + del synced_collection._root + # deleting _root will lead to recursion as _root is treated as key + # _load() will check for _root and __getattr__ will call __getitem__ + # which calls _load() + with pytest.raises(RecursionError): + synced_collection._load() + + def test_keys_with_dots(self, synced_collection): + with pytest.raises(InvalidKeyError): + synced_collection["a.b"] = None + with pytest.raises(KeyTypeError): + synced_collection[0.0] = None + + +class AttrListTest: + """Test that dicts contained in AttrList classes are AttrDicts.""" + + def test_attr_list(self, synced_collection, testdata): + synced_collection.append({}) + nested_synced_dict = synced_collection[0] + + key = "test" + nested_synced_dict[key] = testdata + assert len(nested_synced_dict) == 1 + assert key in nested_synced_dict + assert nested_synced_dict[key] == testdata + assert nested_synced_dict.get(key) == testdata + assert nested_synced_dict.test == testdata + del nested_synced_dict.test + assert len(nested_synced_dict) == 0 + assert key not in nested_synced_dict + key = "test2" + nested_synced_dict.test2 = testdata + assert len(nested_synced_dict) == 1 + assert key in nested_synced_dict + assert nested_synced_dict[key] == testdata + assert nested_synced_dict.get(key) == testdata + assert nested_synced_dict.test2 == testdata + with pytest.raises(AttributeError): + nested_synced_dict.not_exist diff --git a/tests/test_synced_collections/synced_collection_test.py b/tests/test_synced_collections/synced_collection_test.py new file mode 100644 index 000000000..2b269f8b4 --- /dev/null +++ b/tests/test_synced_collections/synced_collection_test.py @@ -0,0 +1,874 @@ +# Copyright (c) 2020 The Regents of the University of Michigan +# All rights reserved. +# This software is licensed under the BSD 3-Clause License. 
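The test updates above reflect that storing NumPy values on a job now emits NumpyConversionWarning while converting them to plain Python types; a rough usage sketch (assuming a fresh signac project can be initialized in the working directory; the project name and values are arbitrary):

import warnings

import numpy
import signac

from signac.synced_collections.numpy_utils import NumpyConversionWarning

project = signac.init_project("demo-project")  # hypothetical project name
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    job = project.open_job({"a": numpy.float64(1.0)}).init()
    job.doc.array = numpy.ones(3)

# Both assignments warn, and the stored values are plain Python types.
assert any(issubclass(w.category, NumpyConversionWarning) for w in caught)
assert job.sp.a == 1.0
assert list(job.doc.array) == [1.0, 1.0, 1.0]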
+import platform +from collections.abc import MutableMapping, MutableSequence +from copy import deepcopy +from typing import Any, Tuple, Type + +import pytest + +from signac.errors import KeyTypeError +from signac.synced_collections import SyncedCollection +from signac.synced_collections.numpy_utils import NumpyConversionWarning + +PYPY = "PyPy" in platform.python_implementation() + + +try: + import numpy + + NUMPY = True + + NUMPY_INT_TYPES: Tuple[Type, ...] = ( + numpy.bool_, + numpy.byte, + numpy.ubyte, + numpy.short, + numpy.ushort, + numpy.intc, + numpy.uintc, + numpy.int_, + numpy.uint, + numpy.longlong, + numpy.ulonglong, + numpy.int8, + numpy.int16, + numpy.int32, + numpy.int64, + numpy.uint8, + numpy.uint16, + numpy.uint32, + numpy.uint64, + numpy.intp, + numpy.uintp, + ) + + NUMPY_FLOAT_TYPES: Tuple[Type, ...] = ( + numpy.half, + numpy.float16, + numpy.single, + numpy.longdouble, + numpy.float32, + numpy.float64, + numpy.float128, + numpy.float_, + ) + + NUMPY_COMPLEX_TYPES: Tuple[Type, ...] = ( + numpy.csingle, + numpy.cdouble, + numpy.clongdouble, + numpy.complex64, + numpy.complex128, + numpy.complex_, + ) + NUMPY_SHAPES: Tuple[Any, ...] = (None, (1,), (2,), (2, 2)) + + # Older numpy versions don't have the new rngs. + try: + rng = numpy.random.default_rng() + random_sample = rng.random + randint = rng.integers + except AttributeError: + random_sample = numpy.random.random_sample + randint = numpy.random.randint + +except ImportError: + NUMPY = False + NUMPY_INT_TYPES = () + NUMPY_FLOAT_TYPES = () + NUMPY_COMPLEX_TYPES = () + NUMPY_SHAPES = () + + +class SyncedCollectionTest: + """The parent for all synced collection tests. + + This class defines the standard APIs that are expected of all test subclasses. + Following these protocols allows different backends and data types to share + most test by just defining the expected additional variables. + """ + + def store(self, synced_collection, data): + """Directly store data to the backend using its own API. + + This method should bypass the synced_collection, just using it to get + information on the underlying resource (e.g. a file) and then saving to + it directly. This is used, for instance, to test data integrity validation. + """ + raise NotImplementedError("All backend tests must implement the store method.") + + @pytest.fixture + def synced_collection(self): + """Generate a synced collection of the appropriate type. + + The type of the synced collection should be specified by the _collection_type + class variable. + """ + raise NotImplementedError( + "All backend tests must implement a synced_collection fixture " + "that returns an empty instance." + ) + + @pytest.fixture + def base_collection(self): + """Generate a collection of the base data type. + + This fixture should generate a base collection (e.g. a list or a dict) + that can be used to set the data of a synced collection of the corresponding + type for comparison in tests. + """ + raise NotImplementedError( + "All data type tests must implement a base_collection fixture that " + "returns an instance populated with test data." 
+ ) + + +class SyncedDictTest(SyncedCollectionTest): + @pytest.fixture + def base_collection(self): + return {"a": 0} + + def test_init(self, synced_collection): + assert len(synced_collection) == 0 + + def test_init_positional(self, synced_collection_positional): + assert len(synced_collection_positional) == 0 + + def test_isinstance(self, synced_collection): + assert isinstance(synced_collection, SyncedCollection) + assert isinstance(synced_collection, MutableMapping) + + def test_set_get(self, synced_collection, testdata): + key = "setget" + synced_collection.clear() + assert not bool(synced_collection) + assert len(synced_collection) == 0 + assert key not in synced_collection + synced_collection[key] = testdata + assert bool(synced_collection) + assert len(synced_collection) == 1 + assert key in synced_collection + assert synced_collection[key] == testdata + assert synced_collection.get(key) == testdata + + def test_set_get_explicit_nested(self, synced_collection, testdata): + key = "setgetexplicitnested" + synced_collection.setdefault("a", dict()) + child1 = synced_collection["a"] + child2 = synced_collection["a"] + assert child1 == child2 + assert isinstance(child1, type(child2)) + assert id(child1) == id(child2) + assert not child1 + assert not child2 + child1[key] = testdata + assert child1 + assert child2 + assert key in child1 + assert key in child2 + assert child1 == child2 + assert child1[key] == testdata + assert child2[key] == testdata + + def test_copy_value(self, synced_collection, testdata): + key = "copy_value" + key2 = "copy_value2" + assert key not in synced_collection + assert key2 not in synced_collection + synced_collection[key] = testdata + assert key in synced_collection + assert synced_collection[key] == testdata + assert key2 not in synced_collection + synced_collection[key2] = synced_collection[key] + assert key in synced_collection + assert synced_collection[key] == testdata + assert key2 in synced_collection + assert synced_collection[key2] == testdata + + def test_iter(self, synced_collection, testdata): + key1 = "iter1" + key2 = "iter2" + d = {key1: testdata, key2: testdata} + synced_collection.update(d) + assert key1 in synced_collection + assert key2 in synced_collection + for i, key in enumerate(synced_collection): + assert key in d + assert d[key] == synced_collection[key] + assert i == 1 + + def test_delete(self, synced_collection, testdata): + key = "delete" + synced_collection[key] = testdata + assert len(synced_collection) == 1 + assert synced_collection[key] == testdata + del synced_collection[key] + assert len(synced_collection) == 0 + with pytest.raises(KeyError): + synced_collection[key] + synced_collection[key] = testdata + assert len(synced_collection) == 1 + assert synced_collection[key] == testdata + del synced_collection["delete"] + assert len(synced_collection) == 0 + with pytest.raises(KeyError): + synced_collection[key] + + def test_update(self, synced_collection, testdata): + key = "update" + d = {key: testdata} + synced_collection.update(d) + assert len(synced_collection) == 1 + assert synced_collection[key] == d[key] + # upadte with no argument + synced_collection.update() + assert len(synced_collection) == 1 + assert synced_collection[key] == d[key] + # update using key as kwarg + synced_collection.update(update2=testdata) + assert len(synced_collection) == 2 + assert synced_collection["update2"] == testdata + # same key in other dict and as kwarg with different values + synced_collection.update({key: 1}, update=2) # here key is 
'update' + assert len(synced_collection) == 2 + assert synced_collection[key] == 2 + # update using list of key and value pair + synced_collection.update([("update2", 1), ("update3", 2)]) + assert len(synced_collection) == 3 + assert synced_collection["update2"] == 1 + assert synced_collection["update3"] == 2 + + def test_pop(self, synced_collection, testdata): + key = "pop" + synced_collection[key] = testdata + assert len(synced_collection) == 1 + assert synced_collection[key] == testdata + d1 = synced_collection.pop(key) + assert len(synced_collection) == 0 + assert testdata == d1 + with pytest.raises(KeyError): + synced_collection[key] + d2 = synced_collection.pop(key, "default") + assert len(synced_collection) == 0 + assert d2 == "default" + + def test_popitem(self, synced_collection, testdata): + key = "pop" + synced_collection[key] = testdata + assert len(synced_collection) == 1 + assert synced_collection[key] == testdata + key1, d1 = synced_collection.popitem() + assert len(synced_collection) == 0 + assert key == key1 + assert testdata == d1 + with pytest.raises(KeyError): + synced_collection[key] + + def test_values(self, synced_collection, testdata): + data = {"value1": testdata, "value_nested": {"value2": testdata}} + synced_collection.reset(data) + assert "value1" in synced_collection + assert "value_nested" in synced_collection + for val in synced_collection.values(): + assert not isinstance(val, SyncedCollection) + assert val in data.values() + + def test_items(self, synced_collection, testdata): + data = {"item1": testdata, "item_nested": {"item2": testdata}} + synced_collection.reset(data) + assert "item1" in synced_collection + assert "item_nested" in synced_collection + for key, val in synced_collection.items(): + assert synced_collection[key] == data[key] + assert not isinstance(val, type(synced_collection)) + assert (key, val) in data.items() + + def test_setdefault(self, synced_collection, testdata): + key = "setdefault" + ret = synced_collection.setdefault(key, testdata) + assert ret == testdata + assert key in synced_collection + assert synced_collection[key] == testdata + ret = synced_collection.setdefault(key, 1) + assert ret == testdata + assert synced_collection[key] == testdata + + def test_reset(self, synced_collection, testdata): + key = "reset" + synced_collection[key] = testdata + assert len(synced_collection) == 1 + assert synced_collection[key] == testdata + synced_collection.reset({"reset": "abc"}) + assert len(synced_collection) == 1 + assert synced_collection[key] == "abc" + + # invalid input + with pytest.raises(ValueError): + synced_collection.reset([0, 1]) + + def test_clear(self, synced_collection, testdata): + key = "clear" + synced_collection[key] = testdata + assert len(synced_collection) == 1 + assert synced_collection[key] == testdata + synced_collection.clear() + assert len(synced_collection) == 0 + + def test_repr(self, synced_collection): + repr(synced_collection) + p = eval(repr(synced_collection)) + assert repr(p) == repr(synced_collection) + assert p == synced_collection + + def test_str(self, synced_collection): + str(synced_collection) == str(synced_collection()) + + def test_call(self, synced_collection, testdata): + key = "call" + synced_collection[key] = testdata + assert len(synced_collection) == 1 + assert synced_collection[key] == testdata + assert isinstance(synced_collection(), dict) + assert not isinstance(synced_collection(), SyncedCollection) + + def recursive_convert(d): + return { + k: (recursive_convert(v) if isinstance(v, 
SyncedCollection) else v) + for k, v in d.items() + } + + assert synced_collection() == recursive_convert(synced_collection) + assert synced_collection() == {"call": testdata} + + def test_reopen(self, synced_collection, testdata): + key = "reopen" + synced_collection[key] = testdata + try: + synced_collection2 = deepcopy(synced_collection) + except TypeError: + # Ignore backends that don't support deepcopy. + return + synced_collection._save() + del synced_collection # possibly unsafe + synced_collection2._load() + assert len(synced_collection2) == 1 + assert synced_collection2[key] == testdata + + def test_update_recursive(self, synced_collection, testdata): + synced_collection["a"] = {"a": 1} + synced_collection["b"] = "test" + synced_collection["c"] = [0, 1, 2] + assert "a" in synced_collection + assert "b" in synced_collection + assert "c" in synced_collection + data = {"a": 1, "c": [0, 1, 3], "d": 1} + self.store(synced_collection, data) + assert synced_collection == data + + # Test multiple changes. kwargs should supersede the mapping + synced_collection.update({"a": 1, "b": 3}, a={"foo": "bar"}, d=[2, 3]) + assert synced_collection == { + "a": {"foo": "bar"}, + "b": 3, + "c": [0, 1, 3], + "d": [2, 3], + } + + # Test multiple changes using a sequence of key-value pairs. + synced_collection.update((("d", [{"bar": "baz"}, 1]), ("c", 1)), e=("a", "b")) + assert synced_collection == { + "a": {"foo": "bar"}, + "b": 3, + "c": 1, + "d": [{"bar": "baz"}, 1], + "e": ["a", "b"], + } + + # invalid data + data = [1, 2, 3] + self.store(synced_collection, data) + with pytest.raises(ValueError): + synced_collection._load() + + def test_copy_as_dict(self, synced_collection, testdata): + key = "copy" + synced_collection[key] = testdata + copy = dict(synced_collection) + del synced_collection + assert key in copy + assert copy[key] == testdata + + def test_nested_dict(self, synced_collection): + synced_collection["a"] = dict(a=dict()) + child1 = synced_collection["a"] + child2 = synced_collection["a"]["a"] + assert isinstance(child1, type(synced_collection)) + assert isinstance(child1, type(child2)) + + def test_nested_dict_with_list(self, synced_collection): + synced_collection["a"] = [1, 2, 3] + child1 = synced_collection["a"] + synced_collection["a"].append(dict(a=[1, 2, 3])) + child2 = synced_collection["a"][3] + child3 = synced_collection["a"][3]["a"] + assert isinstance(child2, type(synced_collection)) + assert isinstance(child1, type(child3)) + assert isinstance(child1, SyncedCollection) + assert isinstance(child3, SyncedCollection) + + def test_write_invalid_type(self, synced_collection, testdata): + class Foo: + pass + + key = "write_invalid_type" + synced_collection[key] = testdata + assert len(synced_collection) == 1 + assert synced_collection[key] == testdata + d2 = Foo() + with pytest.raises(TypeError): + synced_collection[key + "2"] = d2 + assert len(synced_collection) == 1 + assert synced_collection[key] == testdata + + def test_keys_str_type(self, synced_collection, testdata): + class MyStr(str): + pass + + for key in ("key", MyStr("key")): + synced_collection[key] = testdata + assert key in synced_collection + assert synced_collection[key] == testdata + + # TODO: This test should only be applied for backends where JSON-formatting + # is required. 
+ def test_keys_invalid_type(self, synced_collection, testdata): + class A: + pass + + for key in (A(), (1, 2, 3)): + with pytest.raises(KeyTypeError): + synced_collection[key] = testdata + for key in ([], {}): + with pytest.raises(TypeError): + synced_collection[key] = testdata + + def test_multithreaded(self, synced_collection): + """Test multithreaded runs of synced dicts.""" + if not type(synced_collection)._supports_threading: + return + + from concurrent.futures import ThreadPoolExecutor + from json.decoder import JSONDecodeError + from threading import current_thread + + def set_value(sd): + sd[current_thread().name] = current_thread().name + + num_threads = 50 + with ThreadPoolExecutor(max_workers=num_threads) as executor: + list(executor.map(set_value, [synced_collection] * num_threads * 10)) + + assert len(synced_collection) == num_threads + + # Now clear the data and try again with multithreading disabled. Unless + # we're very unlucky, some of these threads should overwrite each + # other. + type(synced_collection).disable_multithreading() + synced_collection.clear() + + try: + with ThreadPoolExecutor(max_workers=num_threads) as executor: + list(executor.map(set_value, [synced_collection] * num_threads * 10)) + except (RuntimeError, JSONDecodeError): + # This line may raise an exception, or it may successfully complete + # but not modify all the expected data. If it raises an exception, + # then the underlying data is likely to be invalid, so we must + # clear it. + synced_collection.clear() + else: + # PyPy is fast enough that threads will frequently complete without + # being preempted, so this check is frequently invalidated. + if not PYPY: + assert len(synced_collection) != num_threads + + # For good measure, try reenabling multithreading and test to be safe. + type(synced_collection).enable_multithreading() + synced_collection.clear() + + with ThreadPoolExecutor(max_workers=num_threads) as executor: + list(executor.map(set_value, [synced_collection] * num_threads * 10)) + + assert len(synced_collection) == num_threads + + @pytest.mark.skipif(not NUMPY, reason="This test requires the numpy package.") + @pytest.mark.parametrize("dtype", NUMPY_INT_TYPES) + @pytest.mark.parametrize("shape", NUMPY_SHAPES) + def test_set_get_numpy_int_data(self, synced_collection, dtype, shape): + """Test setting scalar int types, which should always work.""" + try: + max_value = numpy.iinfo(dtype).max + except ValueError: + max_value = 1 + value = randint(max_value, dtype=dtype, size=shape) + + with pytest.warns(NumpyConversionWarning): + synced_collection["numpy_dtype_val"] = value + raw_value = value.item() if shape is None else value.tolist() + assert synced_collection["numpy_dtype_val"] == raw_value + + @pytest.mark.skipif(not NUMPY, reason="This test requires the numpy package.") + @pytest.mark.parametrize("dtype", NUMPY_FLOAT_TYPES) + @pytest.mark.parametrize("shape", NUMPY_SHAPES) + def test_set_get_numpy_float_data(self, synced_collection, dtype, shape): + """Test setting scalar float types, which work if a raw Python analog exists.""" + value = dtype(random_sample(shape)) + + # If casting via item does not give a base Python type, the number + # should fail to set correctly. 
+ raw_value = value.item() if shape is None else value.tolist() + test_value = value.item(0) if isinstance(raw_value, list) else raw_value + has_corresponding_python_type = isinstance( + test_value, (numpy.number, numpy.bool_) + ) + + if has_corresponding_python_type: + with pytest.raises((ValueError, TypeError)), pytest.warns( + NumpyConversionWarning + ): + synced_collection["numpy_dtype_val"] = value + else: + with pytest.warns(NumpyConversionWarning): + synced_collection["numpy_dtype_val"] = value + assert synced_collection["numpy_dtype_val"] == raw_value + + @pytest.mark.skipif(not NUMPY, reason="This test requires the numpy package.") + @pytest.mark.parametrize("dtype", NUMPY_COMPLEX_TYPES) + @pytest.mark.parametrize("shape", NUMPY_SHAPES) + def test_set_get_numpy_complex_data(self, synced_collection, dtype, shape): + """Test setting scalar complex types, which should always fail.""" + # Note that the current behavior of this test is based on the fact that + # all backends rely on JSON-serialization (at least implicitly), even + # non-JSON backends. This test may have to be generalized if we add any + # backends that support other data, or if we want to test cases like + # ZarrCollection with a non-JSON codec (alternatives are supported, but + # not a priority to test here). + value = dtype(random_sample(shape)) + + with pytest.raises((ValueError, TypeError)), pytest.warns( + NumpyConversionWarning + ): + synced_collection["numpy_dtype_val"] = value + + +class SyncedListTest(SyncedCollectionTest): + @pytest.fixture + def base_collection(self): + return [0] + + def test_init(self, synced_collection): + assert len(synced_collection) == 0 + + def test_isinstance(self, synced_collection): + assert isinstance(synced_collection, MutableSequence) + assert isinstance(synced_collection, SyncedCollection) + + def test_set_get(self, synced_collection, testdata): + synced_collection.clear() + assert not bool(synced_collection) + assert len(synced_collection) == 0 + synced_collection.append(testdata) + assert bool(synced_collection) + assert len(synced_collection) == 1 + assert synced_collection[0] == testdata + synced_collection[0] = 1 + assert bool(synced_collection) + assert len(synced_collection) == 1 + assert synced_collection[0] == 1 + + def test_iter(self, synced_collection, testdata): + d = [testdata, 43] + synced_collection.extend(d) + for i in range(len(synced_collection)): + assert d[i] == synced_collection[i] + assert i == 1 + + def test_delete(self, synced_collection, testdata): + synced_collection.append(testdata) + assert len(synced_collection) == 1 + assert synced_collection[0] == testdata + del synced_collection[0] + assert len(synced_collection) == 0 + with pytest.raises(IndexError): + synced_collection[0] + + def test_extend(self, synced_collection, testdata): + d = [testdata] + synced_collection.extend(d) + assert len(synced_collection) == 1 + assert synced_collection[0] == d[0] + d1 = testdata + synced_collection += [d1] + assert len(synced_collection) == 2 + assert synced_collection[0] == d[0] + assert synced_collection[1] == d1 + + # Ensure generators are exhausted only once by extend + def data_generator(): + yield testdata + + synced_collection.extend(data_generator()) + assert len(synced_collection) == 3 + assert synced_collection[0] == d[0] + assert synced_collection[1] == d1 + assert synced_collection[2] == testdata + + # Ensure generators are exhausted only once by __iadd__ + def data_generator(): + yield testdata + + synced_collection += data_generator() + assert 
len(synced_collection) == 4 + assert synced_collection[0] == d[0] + assert synced_collection[1] == d1 + assert synced_collection[2] == testdata + assert synced_collection[3] == testdata + + def test_clear(self, synced_collection, testdata): + synced_collection.append(testdata) + assert len(synced_collection) == 1 + assert synced_collection[0] == testdata + synced_collection.clear() + assert len(synced_collection) == 0 + + def test_reset(self, synced_collection): + synced_collection.reset([1, 2, 3]) + assert len(synced_collection) == 3 + assert synced_collection == [1, 2, 3] + synced_collection.reset([3, 4]) + assert len(synced_collection) == 2 + assert synced_collection == [3, 4] + + # invalid inputs + with pytest.raises(ValueError): + synced_collection.reset({"a": 1}) + + with pytest.raises(ValueError): + synced_collection.reset(1) + + def test_insert(self, synced_collection, testdata): + synced_collection.reset([1, 2]) + assert len(synced_collection) == 2 + synced_collection.insert(1, testdata) + assert len(synced_collection) == 3 + assert synced_collection[1] == testdata + + def test_reversed(self, synced_collection): + data = [1, 2, 3] + synced_collection.reset([1, 2, 3]) + assert len(synced_collection) == 3 + assert synced_collection == data + for i, j in zip(reversed(synced_collection), reversed(data)): + assert i == j + + def test_remove(self, synced_collection): + synced_collection.reset([1, 2]) + assert len(synced_collection) == 2 + synced_collection.remove(1) + assert len(synced_collection) == 1 + assert synced_collection[0] == 2 + synced_collection.reset([1, 2, 1]) + synced_collection.remove(1) + assert len(synced_collection) == 2 + assert synced_collection[0] == 2 + assert synced_collection[1] == 1 + + def test_call(self, synced_collection): + synced_collection.reset([1, 2]) + assert len(synced_collection) == 2 + assert isinstance(synced_collection(), list) + assert not isinstance(synced_collection(), SyncedCollection) + assert synced_collection() == [1, 2] + + def test_update_recursive(self, synced_collection): + synced_collection.reset([{"a": 1}, "b", [1, 2, 3]]) + assert synced_collection == [{"a": 1}, "b", [1, 2, 3]] + data = ["a", "b", [1, 2, 4], "d"] + self.store(synced_collection, data) + assert synced_collection == data + data1 = ["a", "b"] + self.store(synced_collection, data1) + assert synced_collection == data1 + + # invalid data in file + data2 = {"a": 1} + self.store(synced_collection, data2) + with pytest.raises(ValueError): + synced_collection._load() + + def test_reopen(self, synced_collection, testdata): + try: + synced_collection2 = deepcopy(synced_collection) + except TypeError: + # Ignore backends that don't support deepcopy. 
+ return + synced_collection.append(testdata) + synced_collection._save() + del synced_collection # possibly unsafe + synced_collection2._load() + assert len(synced_collection2) == 1 + assert synced_collection2[0] == testdata + + def test_copy_as_list(self, synced_collection, testdata): + synced_collection.append(testdata) + assert synced_collection[0] == testdata + copy = list(synced_collection) + del synced_collection + assert copy[0] == testdata + + def test_repr(self, synced_collection): + repr(synced_collection) + p = eval(repr(synced_collection)) + assert repr(p) == repr(synced_collection) + assert p == synced_collection + + def test_str(self, synced_collection): + str(synced_collection) == str(synced_collection()) + + def test_nested_list(self, synced_collection): + synced_collection.reset([1, 2, 3]) + synced_collection.append([2, 4]) + child1 = synced_collection[3] + child2 = synced_collection[3] + assert child1 == child2 + assert isinstance(child1, type(child2)) + assert isinstance(child1, type(synced_collection)) + assert id(child1) == id(child2) + child1.append(1) + assert child2[2] == child1[2] + assert child1 == child2 + assert len(synced_collection) == 4 + assert isinstance(child1, type(child2)) + assert isinstance(child1, type(synced_collection)) + assert id(child1) == id(child2) + del child1[0] + assert child1 == child2 + assert len(synced_collection) == 4 + assert isinstance(child1, type(child2)) + assert isinstance(child1, type(synced_collection)) + assert id(child1) == id(child2) + + def test_nested_list_with_dict(self, synced_collection): + synced_collection.reset([{"a": [1, 2, 3, 4]}]) + child1 = synced_collection[0] + child2 = synced_collection[0]["a"] + assert isinstance(child2, SyncedCollection) + assert isinstance(child1, SyncedCollection) + + def test_multithreaded(self, synced_collection): + """Test multithreaded runs of synced lists.""" + if not type(synced_collection)._supports_threading: + return + + from concurrent.futures import ThreadPoolExecutor + + def append_value(sl): + sl.append(0) + + num_threads = 50 + num_elements = num_threads * 10 + with ThreadPoolExecutor(max_workers=num_threads) as executor: + list(executor.map(append_value, [synced_collection] * num_elements)) + + assert len(synced_collection) == num_elements + + @pytest.mark.skipif(not NUMPY, reason="This test requires the numpy package.") + @pytest.mark.parametrize("dtype", NUMPY_INT_TYPES) + @pytest.mark.parametrize("shape", NUMPY_SHAPES) + def test_set_get_numpy_int_data(self, synced_collection, dtype, shape): + """Test setting scalar int types, which should always work.""" + try: + max_value = numpy.iinfo(dtype).max + except ValueError: + max_value = 1 + value = randint(max_value, dtype=dtype, size=shape) + + with pytest.warns(NumpyConversionWarning): + synced_collection.append(value) + raw_value = value.item() if shape is None else value.tolist() + assert synced_collection[-1] == raw_value + + # Test assignment after append. + with pytest.warns(NumpyConversionWarning): + synced_collection[-1] = value + + @pytest.mark.skipif(not NUMPY, reason="This test requires the numpy package.") + @pytest.mark.parametrize("dtype", NUMPY_FLOAT_TYPES) + @pytest.mark.parametrize("shape", NUMPY_SHAPES) + def test_set_get_numpy_float_data(self, synced_collection, dtype, shape): + """Test setting scalar float types, which work if a raw Python analog exists.""" + value = dtype(random_sample(shape)) + + # If casting via item does not give a base Python type, the number + # should fail to set correctly. 
+ raw_value = value.item() if shape is None else value.tolist() + test_value = value.item(0) if isinstance(raw_value, list) else raw_value + should_fail = isinstance(test_value, (numpy.number, numpy.bool_)) + + if should_fail: + with pytest.raises((ValueError, TypeError)), pytest.warns( + NumpyConversionWarning + ): + synced_collection.append(value) + else: + with pytest.warns(NumpyConversionWarning): + synced_collection.append(value) + assert synced_collection[-1] == raw_value + + # Test assignment after append. + with pytest.warns(NumpyConversionWarning): + synced_collection[-1] = value + + @pytest.mark.skipif(not NUMPY, reason="This test requires the numpy package.") + @pytest.mark.parametrize("dtype", NUMPY_COMPLEX_TYPES) + @pytest.mark.parametrize("shape", NUMPY_SHAPES) + def test_set_get_numpy_complex_data(self, synced_collection, dtype, shape): + """Test setting scalar complex types, which should always fail.""" + # Note that the current behavior of this test is based on the fact that + # all backends rely on JSON-serialization (at least implicitly), even + # non-JSON backends. This test may have to be generalized if we add any + # backends that support other data, or if we want to test cases like + # ZarrCollection with a non-JSON codec (alternatives are supported, but + # not a priority to test here). + value = dtype(random_sample(shape)) + + with pytest.raises((ValueError, TypeError)), pytest.warns( + NumpyConversionWarning + ): + synced_collection.append(value) + + with pytest.raises((ValueError, TypeError)), pytest.warns( + NumpyConversionWarning + ): + synced_collection[-1] = value + + @pytest.mark.parametrize("dtype", NUMPY_INT_TYPES) + @pytest.mark.parametrize("shape", NUMPY_SHAPES) + def test_reset_numpy_int_data(self, synced_collection, dtype, shape): + """Test setting scalar int types, which should always work.""" + try: + max_value = numpy.iinfo(dtype).max + except ValueError: + max_value = 1 + value = randint(max_value, dtype=dtype, size=shape) + + if shape is None: + with pytest.raises((ValueError, TypeError)), pytest.warns( + NumpyConversionWarning + ): + synced_collection.reset(value) + else: + with pytest.warns(NumpyConversionWarning): + synced_collection.reset(value) + assert synced_collection == value.tolist() diff --git a/tests/test_synced_collections/test_json_buffered_collection.py b/tests/test_synced_collections/test_json_buffered_collection.py new file mode 100644 index 000000000..b75f4afeb --- /dev/null +++ b/tests/test_synced_collections/test_json_buffered_collection.py @@ -0,0 +1,782 @@ +# Copyright (c) 2020 The Regents of the University of Michigan +# All rights reserved. +# This software is licensed under the BSD 3-Clause License. 
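The tests in this new module exercise two buffering entry points on the JSON backends imported below; as a rough sketch of the API they target (the file location is arbitrary; behavior as exercised by the tests that follow):

import os
import tempfile

from signac.synced_collections.backends.collection_json import BufferedJSONDict

tmp_dir = tempfile.mkdtemp()
data = BufferedJSONDict(filename=os.path.join(tmp_dir, "data.json"))

# Per-instance buffering: changes are held in memory until the context exits.
with data.buffered:
    data["a"] = 1

# Backend-wide buffering: every BufferedJSONDict instance shares the buffer.
with BufferedJSONDict.buffer_backend():
    data["b"] = 2

assert data() == {"a": 1, "b": 2}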
+import itertools +import json +import os +import time +from concurrent.futures import ThreadPoolExecutor + +import pytest +from attr_dict_test import AttrDictTest, AttrListTest +from test_json_collection import JSONCollectionTest, TestJSONDict, TestJSONList + +from signac.synced_collections.backends.collection_json import ( + BufferedJSONAttrDict, + BufferedJSONAttrList, + BufferedJSONDict, + BufferedJSONList, + MemoryBufferedJSONAttrDict, + MemoryBufferedJSONAttrList, + MemoryBufferedJSONDict, + MemoryBufferedJSONList, +) +from signac.synced_collections.errors import BufferedError, MetadataError + + +class BufferedJSONCollectionTest(JSONCollectionTest): + def load(self, collection): + """Load the data corresponding to a SyncedCollection from disk.""" + with open(collection.filename) as f: + return json.load(f) + + +class TestBufferedJSONDict(BufferedJSONCollectionTest, TestJSONDict): + """Tests of buffering JSONDicts.""" + + _collection_type = BufferedJSONDict # type: ignore + + @pytest.fixture + def synced_collection2(self, tmpdir): + yield self._collection_type( + filename=os.path.join(tmpdir, "test2.json"), write_concern=False + ) + + def test_buffered(self, synced_collection, testdata): + """Test basic per-instance buffering behavior.""" + assert len(synced_collection) == 0 + synced_collection["buffered"] = testdata + assert "buffered" in synced_collection + assert synced_collection["buffered"] == testdata + with synced_collection.buffered: + assert "buffered" in synced_collection + assert synced_collection["buffered"] == testdata + synced_collection["buffered2"] = 1 + assert "buffered2" in synced_collection + assert synced_collection["buffered2"] == 1 + assert len(synced_collection) == 2 + assert "buffered2" in synced_collection + assert synced_collection["buffered2"] == 1 + with synced_collection.buffered: + del synced_collection["buffered"] + assert len(synced_collection) == 1 + assert "buffered" not in synced_collection + assert len(synced_collection) == 1 + assert "buffered" not in synced_collection + assert "buffered2" in synced_collection + assert synced_collection["buffered2"] == 1 + + # Explicitly check that the file has not been changed when buffering. + raw_dict = synced_collection() + with synced_collection.buffered: + synced_collection["buffered3"] = 1 + on_disk_dict = self.load(synced_collection) + assert "buffered3" not in on_disk_dict + assert on_disk_dict == raw_dict + + on_disk_dict = self.load(synced_collection) + assert "buffered3" in on_disk_dict + assert on_disk_dict == synced_collection + + def test_two_buffered(self, synced_collection, testdata): + """Test that a non-buffered copy is not modified.""" + synced_collection["buffered"] = testdata + synced_collection2 = self._collection_type(filename=synced_collection._filename) + + # Check that the non-buffered object is not modified. + with synced_collection.buffered: + synced_collection["buffered2"] = 1 + assert "buffered2" not in synced_collection2 + + def test_two_buffered_modify_unbuffered(self, synced_collection, testdata): + """Test that in-memory changes raise errors in buffered mode.""" + synced_collection["buffered"] = testdata + synced_collection2 = self._collection_type(filename=synced_collection._filename) + + # Check that the non-buffered object is not modified. 
+ with pytest.raises(MetadataError): + with synced_collection.buffered: + synced_collection["buffered2"] = 1 + synced_collection2["buffered2"] = 2 + assert synced_collection["buffered2"] == 1 + synced_collection["buffered2"] = 3 + assert synced_collection2["buffered2"] == 2 + synced_collection2["buffered2"] = 3 + assert synced_collection["buffered2"] == 3 + + def test_two_buffered_modify_unbuffered_first(self, synced_collection, testdata): + synced_collection["buffered"] = testdata + synced_collection2 = self._collection_type(filename=synced_collection._filename) + + # Check that the non-buffered object is not modified. + with synced_collection.buffered: + synced_collection2["buffered2"] = 1 + assert "buffered2" in synced_collection + synced_collection["buffered2"] = 3 + assert synced_collection == {"buffered": testdata, "buffered2": 3} + + def test_global_buffered(self, synced_collection, testdata): + assert len(synced_collection) == 0 + synced_collection["buffered"] = testdata + assert "buffered" in synced_collection + assert synced_collection["buffered"] == testdata + with self._collection_type.buffer_backend(): + assert "buffered" in synced_collection + assert synced_collection["buffered"] == testdata + synced_collection["buffered2"] = 1 + assert "buffered2" in synced_collection + assert synced_collection["buffered2"] == 1 + assert len(synced_collection) == 2 + assert "buffered2" in synced_collection + assert synced_collection["buffered2"] == 1 + with self._collection_type.buffer_backend(): + del synced_collection["buffered"] + assert len(synced_collection) == 1 + assert "buffered" not in synced_collection + assert len(synced_collection) == 1 + assert "buffered" not in synced_collection + assert "buffered2" in synced_collection + assert synced_collection["buffered2"] == 1 + + with pytest.raises(BufferedError): + with self._collection_type.buffer_backend(): + synced_collection["buffered2"] = 2 + self.store(synced_collection, {"test": 1}) + assert synced_collection["buffered2"] == 2 + assert "test" in synced_collection + assert synced_collection["test"] == 1 + + def test_nested_same_collection(self, synced_collection): + """Test nesting global buffering.""" + assert len(synced_collection) == 0 + + for outer_buffer, inner_buffer in itertools.product( + [synced_collection.buffered, self._collection_type.buffer_backend()], + repeat=2, + ): + err_msg = ( + f"outer_buffer: {type(outer_buffer).__qualname__}, " + f"inner_buffer: {type(inner_buffer).__qualname__}" + ) + synced_collection.reset({"outside": 1}) + with outer_buffer(): + synced_collection["inside_first"] = 2 + with inner_buffer(): + synced_collection["inside_second"] = 3 + + on_disk_dict = self.load(synced_collection) + assert "inside_first" not in on_disk_dict, err_msg + assert "inside_second" not in on_disk_dict, err_msg + assert "inside_first" in synced_collection, err_msg + assert "inside_second" in synced_collection, err_msg + + assert self.load(synced_collection) == synced_collection + + def test_nested_different_collections(self, synced_collection, synced_collection2): + """Test nested buffering for different collections.""" + assert len(synced_collection) == 0 + assert len(synced_collection2) == 0 + + synced_collection["outside"] = 1 + synced_collection2["outside"] = 1 + with synced_collection.buffered: + synced_collection["inside_first"] = 2 + on_disk_dict = self.load(synced_collection) + assert "inside_first" in synced_collection + assert "inside_first" not in on_disk_dict + + synced_collection2["inside_first"] = 2 + 
on_disk_dict2 = self.load(synced_collection2) + assert "inside_first" in synced_collection2 + assert "inside_first" in on_disk_dict2 + + with self._collection_type.buffer_backend(): + synced_collection["inside_second"] = 3 + synced_collection2["inside_second"] = 3 + + on_disk_dict = self.load(synced_collection) + assert "inside_second" in synced_collection + assert "inside_second" not in on_disk_dict + on_disk_dict2 = self.load(synced_collection2) + assert "inside_second" in synced_collection2 + assert "inside_second" not in on_disk_dict2 + + on_disk_dict = self.load(synced_collection) + on_disk_dict2 = self.load(synced_collection2) + + assert "inside_first" in synced_collection + assert "inside_first" not in on_disk_dict + + assert "inside_second" in synced_collection + assert "inside_second" not in on_disk_dict + assert "inside_second" in synced_collection2 + assert "inside_second" in on_disk_dict2 + + on_disk_dict = self.load(synced_collection) + on_disk_dict2 = self.load(synced_collection2) + + assert "inside_first" in synced_collection + assert "inside_first" in on_disk_dict + + assert "inside_second" in synced_collection + assert "inside_second" in on_disk_dict + assert "inside_second" in synced_collection2 + assert "inside_second" in on_disk_dict2 + + def test_nested_copied_collection(self, synced_collection): + """Test modifying two collections pointing to the same data.""" + synced_collection2 = self._collection_type(filename=synced_collection._filename) + + assert len(synced_collection) == 0 + assert len(synced_collection2) == 0 + + synced_collection["outside"] = 1 + with synced_collection.buffered: + synced_collection["inside_first"] = 2 + + on_disk_dict = self.load(synced_collection) + assert synced_collection["inside_first"] == 2 + assert "inside_first" not in on_disk_dict + + with self._collection_type.buffer_backend(): + synced_collection["inside_second"] = 3 + synced_collection2["inside_second"] = 4 + + on_disk_dict = self.load(synced_collection) + assert synced_collection["inside_second"] == 4 + assert synced_collection2["inside_second"] == 4 + assert "inside_second" not in on_disk_dict + + on_disk_dict = self.load(synced_collection) + assert on_disk_dict["inside_second"] == 4 + + @pytest.mark.skip("This is an example of unsupported (and undefined) behavior).") + def test_nested_copied_collection_invalid(self, synced_collection): + """Test the behavior of invalid modifications of copied objects.""" + synced_collection2 = self._collection_type(filename=synced_collection._filename) + + assert len(synced_collection) == 0 + assert len(synced_collection2) == 0 + + synced_collection["outside"] = 1 + finished = False + with pytest.raises(MetadataError): + with synced_collection.buffered: + synced_collection["inside_first"] = 2 + # Modifying synced_collection2 here causes problems. It is + # unbuffered, so it directly writes to file. Then, when + # entering global buffering in the context below, + # synced_collection2 sees that synced_collection has already + # saved data for this file to the buffer, so it loads that + # data, which also means that synced_collection2 becomes + # associated with the metadata stored when synced_collection + # entered buffered mode. As a result, when the global buffering + # exits, we see metadata errors because synced_collection2 lost + # track of the fact that it saved changes to the file made + # prior to entering the global buffer. 
While this case could be + # given some specific behavior, there's no obvious canonical + # source of truth here, so we simply choose to skip it + # altogether. + synced_collection2["inside_first"] = 3 + + on_disk_dict = self.load(synced_collection) + assert synced_collection["inside_first"] == 2 + assert on_disk_dict["inside_first"] == 3 + + with self._collection_type.buffer_backend(): + synced_collection["inside_second"] = 3 + synced_collection2["inside_second"] = 4 + + on_disk_dict = self.load(synced_collection) + assert synced_collection["inside_second"] == 4 + assert synced_collection2["inside_second"] == 4 + assert "inside_second" not in on_disk_dict + + on_disk_dict = self.load(synced_collection) + assert on_disk_dict["inside_second"] == 4 + # Check that all the checks ran before the assertion failure. + finished = True + assert finished + + def test_buffer_flush(self, synced_collection): + """Test that the buffer gets flushed when enough data is written.""" + original_buffer_capacity = self._collection_type.get_buffer_capacity() + assert self._collection_type.get_current_buffer_size() == 0 + + try: + self._collection_type.set_buffer_capacity(20) + + # Ensure that the file exists on disk by executing a clear operation so + # that load operations work as expected. + assert len(synced_collection) == 0 + synced_collection.clear() + + with synced_collection.buffered: + synced_collection["foo"] = 1 + assert self._collection_type.get_current_buffer_size() == len( + repr(synced_collection) + ) + assert synced_collection != self.load(synced_collection) + + # Add a long enough value to force a flush. + synced_collection["bar"] = 100 + assert self._collection_type.get_current_buffer_size() == 0 + + # Make sure the file on disk now matches. + assert synced_collection == self.load(synced_collection) + finally: + # Reset buffer capacity for other tests. + self._collection_type.set_buffer_capacity(original_buffer_capacity) + + # To avoid confusing test failures later, make sure the buffer is + # truly flushed correctly. + assert self._collection_type.get_current_buffer_size() == 0 + + def multithreaded_buffering_test(self, op, tmpdir): + """Test that buffering in a multithreaded context is safe for different operations. + + This method encodes the logic for the test, but can be used to test different + operations on the dict. + """ + original_buffer_capacity = self._collection_type.get_buffer_capacity() + try: + # Choose some arbitrarily low value that will ensure intermittent + # forced buffer flushes. + new_buffer_capacity = 20 + self._collection_type.set_buffer_capacity(new_buffer_capacity) + + with self._collection_type.buffer_backend(): + num_dicts = 100 + dicts = [] + dict_data = [] + for i in range(num_dicts): + fn = os.path.join(tmpdir, f"test_dict{i}.json") + dicts.append(self._collection_type(filename=fn)) + dict_data.append({str(j): j for j in range(i)}) + + num_threads = 10 + try: + with ThreadPoolExecutor(max_workers=num_threads) as executor: + list(executor.map(op, dicts, dict_data)) + except KeyError as e: + raise RuntimeError( + "Buffering in parallel failed due to different threads " + "simultaneously modifying the buffer." + ) from e + + # First validate inside buffer. + assert all(dicts[i] == dict_data[i] for i in range(num_dicts)) + # Now validate outside buffer. + assert all(dicts[i] == dict_data[i] for i in range(num_dicts)) + finally: + # Reset buffer capacity for other tests in case this fails. 
+ self._collection_type.set_buffer_capacity(original_buffer_capacity) + + # To avoid confusing test failures later, make sure the buffer is + # truly flushed correctly. + assert self._collection_type.get_current_buffer_size() == 0 + + def test_multithreaded_buffering_setitem(self, tmpdir): + """Test setitem in a multithreaded buffering context.""" + + def setitem_dict(sd, data): + for k, v in data.items(): + sd[k] = v + + self.multithreaded_buffering_test(setitem_dict, tmpdir) + + def test_multithreaded_buffering_update(self, tmpdir): + """Test update in a multithreaded buffering context.""" + + def update_dict(sd, data): + sd.update(data) + + self.multithreaded_buffering_test(update_dict, tmpdir) + + def test_multithreaded_buffering_reset(self, tmpdir): + """Test reset in a multithreaded buffering context.""" + + def reset_dict(sd, data): + sd.reset(data) + + self.multithreaded_buffering_test(reset_dict, tmpdir) + + def test_multithreaded_buffering_clear(self, tmpdir): + """Test clear in a multithreaded buffering context. + + Since clear requires ending up with an empty dict, it's easier to + write a separate test from the others. + """ + original_buffer_capacity = self._collection_type.get_buffer_capacity() + try: + # Choose some arbitrarily low value that will ensure intermittent + # forced buffer flushes. + new_buffer_capacity = 20 + self._collection_type.set_buffer_capacity(new_buffer_capacity) + + # Initialize the data outside the buffered context so that it's + # already present on disk for testing both. + num_dicts = 100 + dicts = [] + for i in range(num_dicts): + fn = os.path.join(tmpdir, f"test_dict{i}.json") + dicts.append(self._collection_type(filename=fn)) + dicts[-1].update({str(j): j for j in range(i)}) + + with self._collection_type.buffer_backend(): + num_threads = 10 + try: + with ThreadPoolExecutor(max_workers=num_threads) as executor: + list(executor.map(lambda sd: sd.clear(), dicts)) + except KeyError as e: + raise RuntimeError( + "Buffering in parallel failed due to different threads " + "simultaneously modifying the buffer." + ) from e + + # First validate inside buffer. + assert all(not dicts[i] for i in range(num_dicts)) + # Now validate outside buffer. + assert all(not dicts[i] for i in range(num_dicts)) + finally: + # Reset buffer capacity for other tests in case this fails. + self._collection_type.set_buffer_capacity(original_buffer_capacity) + + # To avoid confusing test failures later, make sure the buffer is + # truly flushed correctly. + assert self._collection_type.get_current_buffer_size() == 0 + + def test_multithreaded_buffering_load(self, tmpdir): + """Test loading data in a multithreaded buffering context. + + This test is primarily for verifying that multithreaded buffering does + not lead to concurrency errors in flushing data from the buffer due to + too many loads. This test is primarily for buffering methods with a maximum + capacity, even for read-only operations. + """ + original_buffer_capacity = self._collection_type.get_buffer_capacity() + try: + # Choose some arbitrarily low value that will ensure intermittent + # forced buffer flushes. + new_buffer_capacity = 1000 + self._collection_type.set_buffer_capacity(new_buffer_capacity) + + # Must initialize the data outside the buffered context so that + # we only execute read operations inside the buffered context. 
+ num_dicts = 100 + dicts = [] + for i in range(num_dicts): + fn = os.path.join(tmpdir, f"test_dict{i}.json") + dicts.append(self._collection_type(filename=fn)) + # Go to i+1 so that every dict contains the 0 element. + dicts[-1].update({str(j): j for j in range(i + 1)}) + + with self._collection_type.buffer_backend(): + num_threads = 100 + try: + with ThreadPoolExecutor(max_workers=num_threads) as executor: + list(executor.map(lambda sd: sd["0"], dicts * 5)) + except KeyError as e: + raise RuntimeError( + "Buffering in parallel failed due to different threads " + "simultaneously modifying the buffer." + ) from e + + finally: + # Reset buffer capacity for other tests in case this fails. + self._collection_type.set_buffer_capacity(original_buffer_capacity) + + # To avoid confusing test failures later, make sure the buffer is + # truly flushed correctly. + assert self._collection_type.get_current_buffer_size() == 0 + + def test_buffer_first_load(self, synced_collection): + """Ensure that existing data is preserved if the first load is in buffered mode.""" + fn = synced_collection.filename + write_concern = self._write_concern + + sc = self._collection_type(fn, write_concern) + sc["foo"] = 1 + sc["bar"] = 2 + del sc + + sc = self._collection_type(fn, write_concern) + with sc.buffered: + sc["foo"] = 3 + + assert "bar" in sc + + def test_data_type_with_buffered(self, synced_collection): + """Make sure that the _data attribute has the right types after buffering.""" + with self._collection_type.buffer_backend(): + synced_collection["foo"] = {"bar": "baz"} + + # This will fail if synced_collection['foo'] is a raw dict. + synced_collection["foo"]() + + +class TestBufferedJSONList(BufferedJSONCollectionTest, TestJSONList): + """Tests of buffering JSONLists.""" + + _collection_type = BufferedJSONList # type: ignore + + def test_buffered(self, synced_collection): + synced_collection.extend([1, 2, 3]) + assert len(synced_collection) == 3 + assert synced_collection == [1, 2, 3] + with synced_collection.buffered: + assert len(synced_collection) == 3 + assert synced_collection == [1, 2, 3] + synced_collection[0] = 4 + assert len(synced_collection) == 3 + assert synced_collection == [4, 2, 3] + assert len(synced_collection) == 3 + assert synced_collection == [4, 2, 3] + with synced_collection.buffered: + assert len(synced_collection) == 3 + assert synced_collection == [4, 2, 3] + del synced_collection[0] + assert len(synced_collection) == 2 + assert synced_collection == [2, 3] + assert len(synced_collection) == 2 + assert synced_collection == [2, 3] + + # Explicitly check that the file has not been changed when buffering. 
+ raw_list = synced_collection() + with synced_collection.buffered: + synced_collection.append(10) + on_disk_list = self.load(synced_collection) + assert 10 not in on_disk_list + assert on_disk_list == raw_list + + on_disk_list = self.load(synced_collection) + assert 10 in on_disk_list + assert on_disk_list == synced_collection + + def test_global_buffered(self, synced_collection): + assert len(synced_collection) == 0 + with self._collection_type.buffer_backend(): + synced_collection.reset([1, 2, 3]) + assert len(synced_collection) == 3 + assert len(synced_collection) == 3 + assert synced_collection == [1, 2, 3] + with self._collection_type.buffer_backend(): + assert len(synced_collection) == 3 + assert synced_collection == [1, 2, 3] + synced_collection[0] = 4 + assert len(synced_collection) == 3 + assert synced_collection == [4, 2, 3] + assert len(synced_collection) == 3 + assert synced_collection == [4, 2, 3] + + # metacheck failure + with pytest.raises(BufferedError): + with self._collection_type.buffer_backend(): + synced_collection.reset([1]) + assert synced_collection == [1] + # Unfortunately the resolution of os.stat is + # platform dependent and may not always be + # high enough for our check to work. Since + # this unit test is artificially simple we + # must add some amount of minimum waiting time + # to ensure that the change in time will be + # detected. + time.sleep(0.01) + self.store(synced_collection, [1, 2, 3]) + assert synced_collection == [1] + assert len(synced_collection) == 3 + assert synced_collection == [1, 2, 3] + + def multithreaded_buffering_test(self, op, requires_init, tmpdir): + """Test that buffering in a multithreaded context is safe for different operations. + + This method encodes the logic for the test, but can be used to test different + operations on the list. + """ + original_buffer_capacity = self._collection_type.get_buffer_capacity() + try: + # Choose some arbitrarily low value that will ensure intermittent + # forced buffer flushes. + new_buffer_capacity = 20 + self._collection_type.set_buffer_capacity(new_buffer_capacity) + + num_lists = 100 + lists = [] + list_data = [] + for i in range(num_lists): + # Initialize data with zeros, but prepare other data for + # updating in place. + fn = os.path.join(tmpdir, f"test_list{i}.json") + lists.append(self._collection_type(filename=fn)) + list_data.append([j for j in range(i)]) + if requires_init: + lists[-1].extend([0 for j in range(i)]) + + with self._collection_type.buffer_backend(): + num_threads = 10 + try: + with ThreadPoolExecutor(max_workers=num_threads) as executor: + list(executor.map(op, lists, list_data)) + except KeyError as e: + raise RuntimeError( + "Buffering in parallel failed due to different threads " + "simultaneously modifying the buffer." + ) from e + + # First validate inside buffer. + assert all(lists[i] == list_data[i] for i in range(num_lists)) + # Now validate outside buffer. + assert all(lists[i] == list_data[i] for i in range(num_lists)) + finally: + # Reset buffer capacity for other tests in case this fails. + self._collection_type.set_buffer_capacity(original_buffer_capacity) + + # To avoid confusing test failures later, make sure the buffer is + # truly flushed correctly. 
+ assert self._collection_type.get_current_buffer_size() == 0 + + def test_multithreaded_buffering_setitem(self, tmpdir): + """Test setitem in a multithreaded buffering context.""" + + def setitem_list(sd, data): + for i, val in enumerate(data): + sd[i] = val + + self.multithreaded_buffering_test(setitem_list, True, tmpdir) + + def test_multithreaded_buffering_extend(self, tmpdir): + """Test extend in a multithreaded buffering context.""" + + def extend_list(sd, data): + sd.extend(data) + + self.multithreaded_buffering_test(extend_list, False, tmpdir) + + def test_multithreaded_buffering_append(self, tmpdir): + """Test append in a multithreaded buffering context.""" + + def append_list(sd, data): + for val in data: + sd.append(val) + + self.multithreaded_buffering_test(append_list, False, tmpdir) + + def test_multithreaded_buffering_load(self, tmpdir): + """Test loading data in a multithreaded buffering context. + + This test is primarily for verifying that multithreaded buffering does + not lead to concurrency errors in flushing data from the buffer due to + too many loads. This test is primarily for buffering methods with a maximum + capacity, even for read-only operations. + """ + original_buffer_capacity = self._collection_type.get_buffer_capacity() + try: + # Choose some arbitrarily low value that will ensure intermittent + # forced buffer flushes. + new_buffer_capacity = 1000 + self._collection_type.set_buffer_capacity(new_buffer_capacity) + + # Must initialize the data outside the buffered context so that + # we only execute read operations inside the buffered context. + num_lists = 100 + lists = [] + for i in range(num_lists): + fn = os.path.join(tmpdir, f"test_list{i}.json") + lists.append(self._collection_type(filename=fn)) + # Go to i+1 so that every list contains the 0 element. + lists[-1].extend([j for j in range(i + 1)]) + + with self._collection_type.buffer_backend(): + num_threads = 100 + try: + with ThreadPoolExecutor(max_workers=num_threads) as executor: + list(executor.map(lambda sd: sd[0], lists * 5)) + except KeyError as e: + raise RuntimeError( + "Buffering in parallel failed due to different threads " + "simultaneously modifying the buffer." + ) from e + finally: + # Reset buffer capacity for other tests in case this fails. + self._collection_type.set_buffer_capacity(original_buffer_capacity) + + # To avoid confusing test failures later, make sure the buffer is + # truly flushed correctly. + assert self._collection_type.get_current_buffer_size() == 0 + + +class TestMemoryBufferedJSONDict(TestBufferedJSONDict): + """Tests of MemoryBufferedJSONDicts.""" + + _collection_type = MemoryBufferedJSONDict # type: ignore + + def test_buffer_flush(self, synced_collection, synced_collection2): + """Test that the buffer gets flushed when enough data is written.""" + original_buffer_capacity = self._collection_type.get_buffer_capacity() + + assert self._collection_type.get_current_buffer_size() == 0 + self._collection_type.set_buffer_capacity(1) + + # Ensure that the file exists on disk by executing a clear operation so + # that load operations work as expected. 
+ assert len(synced_collection) == 0 + assert len(synced_collection2) == 0 + synced_collection.clear() + synced_collection2.clear() + + with self._collection_type.buffer_backend(): + synced_collection["foo"] = 1 + assert self._collection_type.get_current_buffer_size() == 1 + assert synced_collection != self.load(synced_collection) + + # This buffering mode is based on the number of files buffered, so + # we need to write to the second collection. + assert "bar" not in synced_collection2 + + # Simply loading the second collection into memory shouldn't + # trigger a flush, because it hasn't been modified and we flush + # based on the total number of modifications. + assert synced_collection != self.load(synced_collection) + + # Modifying the second collection should exceed buffer capacity and + # trigger a flush. + synced_collection2["bar"] = 2 + assert synced_collection == self.load(synced_collection) + assert synced_collection2 == self.load(synced_collection2) + + # Reset buffer capacity for other tests. + self._collection_type.set_buffer_capacity(original_buffer_capacity) + + +class TestMemoryBufferedJSONList(TestBufferedJSONList): + """Tests of MemoryBufferedJSONLists.""" + + _collection_type = MemoryBufferedJSONList # type: ignore + + +class TestBufferedJSONDictWriteConcern(TestBufferedJSONDict): + _write_concern = True + + +class TestBufferedJSONListWriteConcern(TestBufferedJSONList): + _write_concern = True + + +class TestBufferedJSONAttrDict(TestBufferedJSONDict, AttrDictTest): + + _collection_type = BufferedJSONAttrDict # type: ignore + + +class TestBufferedJSONAttrList(TestBufferedJSONList, AttrListTest): + + _collection_type = BufferedJSONAttrList # type: ignore + + +class TestMemoryBufferedJSONAttrDict(TestMemoryBufferedJSONDict, AttrDictTest): + + _collection_type = MemoryBufferedJSONAttrDict # type: ignore + + +class TestMemoryBufferedJSONAttrList(TestMemoryBufferedJSONList, AttrListTest): + + _collection_type = MemoryBufferedJSONAttrList # type: ignore diff --git a/tests/test_synced_collections/test_json_collection.py b/tests/test_synced_collections/test_json_collection.py new file mode 100644 index 000000000..4868afb7b --- /dev/null +++ b/tests/test_synced_collections/test_json_collection.py @@ -0,0 +1,81 @@ +# Copyright (c) 2020 The Regents of the University of Michigan +# All rights reserved. +# This software is licensed under the BSD 3-Clause License. 
+import json +import os + +import pytest +from attr_dict_test import AttrDictTest, AttrListTest +from synced_collection_test import SyncedDictTest, SyncedListTest + +from signac.synced_collections.backends.collection_json import ( + JSONAttrDict, + JSONAttrList, + JSONDict, + JSONList, +) + + +class JSONCollectionTest: + + _write_concern = False + _fn = "test.json" + + def store(self, synced_collection, data): + with open(synced_collection.filename, "wb") as f: + f.write(json.dumps(data).encode()) + + @pytest.fixture + def synced_collection(self, tmpdir): + yield self._collection_type( + filename=os.path.join(tmpdir, self._fn), + write_concern=self._write_concern, + ) + + @pytest.fixture + def synced_collection_positional(self, tmpdir): + """Fixture that initializes the object using positional arguments.""" + yield self._collection_type( + os.path.join(tmpdir, "test2.json"), self._write_concern + ) + + def test_filename(self, synced_collection): + assert os.path.basename(synced_collection.filename) == self._fn + + +class TestJSONDict(JSONCollectionTest, SyncedDictTest): + + _collection_type = JSONDict + + # The following test tests the support for non-str keys + # for JSON backend which will be removed in version 2.0. + # See issue: https://github.com/glotzerlab/signac/issues/316. + def test_keys_non_str_valid_type(self, synced_collection, testdata): + for key in (0, None, True): + with pytest.deprecated_call(match="Use of.+as key is deprecated"): + synced_collection[key] = testdata + assert str(key) in synced_collection + assert synced_collection[str(key)] == testdata + + +class TestJSONList(JSONCollectionTest, SyncedListTest): + + _collection_type = JSONList + + +class TestJSONDictWriteConcern(TestJSONDict): + _write_concern = True + + +class TestJSONListWriteConcern(TestJSONList): + _write_concern = True + + +class TestJSONAttrDict(TestJSONDict, AttrDictTest): + + _collection_type = JSONAttrDict + + +class TestJSONAttrList(TestJSONList, AttrListTest): + + _collection_type = JSONAttrList diff --git a/tests/test_synced_collections/test_mongodb_collection.py b/tests/test_synced_collections/test_mongodb_collection.py new file mode 100644 index 000000000..5c24a4ac1 --- /dev/null +++ b/tests/test_synced_collections/test_mongodb_collection.py @@ -0,0 +1,103 @@ +# Copyright (c) 2020 The Regents of the University of Michigan +# All rights reserved. +# This software is licensed under the BSD 3-Clause License. +import pytest +from synced_collection_test import SyncedDictTest, SyncedListTest + +from signac.synced_collections.backends.collection_mongodb import ( + MongoDBDict, + MongoDBList, +) + +try: + import pymongo + + try: + # Test the mongodb server. Set a short timeout so that tests don't + # appear to hang while waiting for a connection. + mongo_client = pymongo.MongoClient(serverSelectionTimeoutMS=1000) + tmp_collection = mongo_client["test_db"]["test"] + tmp_collection.insert_one({"test": "0"}) + ret = tmp_collection.find_one({"test": "0"}) + assert ret["test"] == "0" + tmp_collection.drop() + PYMONGO = True + except (pymongo.errors.ServerSelectionTimeoutError, AssertionError): + PYMONGO = False +except ImportError: + PYMONGO = False + + +try: + import numpy + + NUMPY = True + + from synced_collection_test import NUMPY_INT_TYPES, NUMPY_SHAPES + + # BSON does not support >8-byte ints. We remove larger types since some are + # architecture-dependent. 
+ NUMPY_INT_TYPES = tuple( + [ + dtype + for dtype in NUMPY_INT_TYPES + if issubclass(dtype, numpy.number) + and numpy.log2(numpy.iinfo(dtype).max) / 8 < 8 + ] + ) +except ImportError: + NUMPY = False + + NUMPY_INT_TYPES = () + NUMPY_SHAPES = () + + +class MongoDBCollectionTest: + + _uid = {"MongoDBCollection::name": "test"} + + def store(self, synced_collection, data): + data_to_insert = {**synced_collection.uid, "data": data} + synced_collection.collection.replace_one(synced_collection.uid, data_to_insert) + + @pytest.fixture + def synced_collection(self, request): + yield self._collection_type( + uid=self._uid, collection=mongo_client.test_db.test_dict + ) + mongo_client.test_db.test_dict.drop() + + @pytest.fixture + def synced_collection_positional(self): + """Fixture that initializes the object using positional arguments.""" + yield self._collection_type(mongo_client.test_db.test_dict, self._uid) + mongo_client.test_db.test_dict.drop() + + def test_uid(self, synced_collection): + assert synced_collection.uid == self._uid + + @pytest.mark.parametrize("dtype", NUMPY_INT_TYPES) + @pytest.mark.parametrize("shape", NUMPY_SHAPES) + def test_set_get_numpy_int_data(self, synced_collection, dtype, shape): + """Override parent test to use the subset of int types.""" + super().test_set_get_numpy_int_data(synced_collection, dtype, shape) + + +@pytest.mark.skipif( + not PYMONGO, reason="test requires the pymongo package and mongodb server" +) +class TestMongoDBDict(MongoDBCollectionTest, SyncedDictTest): + _collection_type = MongoDBDict + + +@pytest.mark.skipif( + not PYMONGO, reason="test requires the pymongo package and mongodb server" +) +class TestMongoDBList(MongoDBCollectionTest, SyncedListTest): + _collection_type = MongoDBList + + @pytest.mark.parametrize("dtype", NUMPY_INT_TYPES) + @pytest.mark.parametrize("shape", (None, (1,), (2,))) + def test_reset_numpy_int_data(self, synced_collection, dtype, shape): + """Override parent test to use the subset of int types.""" + super().test_reset_numpy_int_data(synced_collection, dtype, shape) diff --git a/tests/test_synced_collections/test_redis_collection.py b/tests/test_synced_collections/test_redis_collection.py new file mode 100644 index 000000000..26687a072 --- /dev/null +++ b/tests/test_synced_collections/test_redis_collection.py @@ -0,0 +1,62 @@ +# Copyright (c) 2020 The Regents of the University of Michigan +# All rights reserved. +# This software is licensed under the BSD 3-Clause License. 
+import json +import uuid + +import pytest +from synced_collection_test import SyncedDictTest, SyncedListTest + +from signac.synced_collections.backends.collection_redis import RedisDict, RedisList + +try: + import redis + + try: + # try to connect to server + redis_client = redis.Redis() + test_key = str(uuid.uuid4()) + redis_client.set(test_key, 0) + assert redis_client.get(test_key) == b"0" # redis stores data as bytes + redis_client.delete(test_key) + REDIS = True + except (redis.exceptions.ConnectionError, AssertionError): + REDIS = False +except ImportError: + REDIS = False + + +class RedisCollectionTest: + + _key = "test" + + def store(self, synced_collection, data): + synced_collection.client.set(synced_collection.key, json.dumps(data).encode()) + + @pytest.fixture + def synced_collection(self, request): + request.addfinalizer(redis_client.flushall) + yield self._collection_type(key=self._key, client=redis_client) + + @pytest.fixture + def synced_collection_positional(self, request): + """Fixture that initializes the object using positional arguments.""" + request.addfinalizer(redis_client.flushall) + yield self._collection_type(redis_client, self._key) + + def test_key(self, synced_collection): + assert synced_collection.key == self._key + + +@pytest.mark.skipif( + not REDIS, reason="test requires the redis package and running redis-server" +) +class TestRedisDict(RedisCollectionTest, SyncedDictTest): + _collection_type = RedisDict + + +@pytest.mark.skipif( + not REDIS, reason="test requires the redis package and running redis-server" +) +class TestRedisList(RedisCollectionTest, SyncedListTest): + _collection_type = RedisList diff --git a/tests/test_synced_collections/test_utils.py b/tests/test_synced_collections/test_utils.py new file mode 100644 index 000000000..3df1774ba --- /dev/null +++ b/tests/test_synced_collections/test_utils.py @@ -0,0 +1,77 @@ +# Copyright (c) 2020 The Regents of the University of Michigan +# All rights reserved. +# This software is licensed under the BSD 3-Clause License. +import json +import os +from collections.abc import Collection, MutableSequence + +import pytest + +from signac.synced_collections import SyncedList +from signac.synced_collections.backends.collection_json import JSONDict +from signac.synced_collections.numpy_utils import NumpyConversionWarning +from signac.synced_collections.utils import ( + AbstractTypeResolver, + SyncedCollectionJSONEncoder, +) + +try: + import numpy + + NUMPY = True +except ImportError: + NUMPY = False + + +def test_type_resolver(): + resolver = AbstractTypeResolver( + { + "dict": lambda obj: isinstance(obj, dict), + "tuple": lambda obj: isinstance(obj, tuple), + "str": lambda obj: isinstance(obj, str), + "mutablesequence": lambda obj: isinstance(obj, MutableSequence), + "collection": lambda obj: isinstance(obj, Collection), + "set": lambda obj: isinstance(obj, set), + } + ) + + assert resolver.get_type({}) == "dict" + assert resolver.get_type((0, 1)) == "tuple" + assert resolver.get_type("abc") == "str" + assert resolver.get_type([]) == "mutablesequence" + + # Make sure that order matters; collection should be found before list. + assert resolver.get_type(set()) == "collection" + + +def test_json_encoder(tmpdir): + # Raw dictionaries should be encoded transparently. 
+ data = {"foo": 1, "bar": 2, "baz": 3} + json_str_data = '{"foo": 1, "bar": 2, "baz": 3}' + assert json.dumps(data) == json_str_data + assert json.dumps(data, cls=SyncedCollectionJSONEncoder) == json_str_data + assert json.dumps(data, cls=SyncedCollectionJSONEncoder) == json.dumps(data) + + fn = os.path.join(tmpdir, "test_json_encoding.json") + synced_data = JSONDict(fn) + synced_data.update(data) + with pytest.raises(TypeError): + json.dumps(synced_data) + assert json.dumps(synced_data, cls=SyncedCollectionJSONEncoder) == json_str_data + + if NUMPY: + # Test both scalar and array numpy types since they could have + # different problems. + array = numpy.array(3) + with pytest.warns(NumpyConversionWarning): + synced_data["foo"] = array + assert isinstance(synced_data["foo"], int) + + array = numpy.random.rand(3) + with pytest.warns(NumpyConversionWarning): + synced_data["foo"] = array + assert isinstance(synced_data["foo"], SyncedList) + assert ( + json.loads(json.dumps(synced_data, cls=SyncedCollectionJSONEncoder)) + == synced_data() + ) diff --git a/tests/test_synced_collections/test_validators.py b/tests/test_synced_collections/test_validators.py new file mode 100644 index 000000000..07f5ab5fc --- /dev/null +++ b/tests/test_synced_collections/test_validators.py @@ -0,0 +1,114 @@ +# Copyright (c) 2020 The Regents of the University of Michigan +# All rights reserved. +# This software is licensed under the BSD 3-Clause License. +import pytest + +from signac.synced_collections.errors import InvalidKeyError, KeyTypeError +from signac.synced_collections.validators import ( + json_format_validator, + no_dot_in_key, + require_string_key, +) + +try: + import numpy + + NUMPY = True +except ImportError: + NUMPY = False + + +class TestRequireStringKey: + def test_valid_data(self, testdata): + test_dict = {} + + key = "valid_str" + test_dict[key] = testdata + require_string_key(test_dict) + assert key in test_dict + assert test_dict[key] == testdata + + def test_invalid_data(self, testdata): + # invalid key types + for key in (0.0, 1.0 + 2.0j, (1, 2, 3), 1, False, None): + with pytest.raises(KeyTypeError): + require_string_key({key: testdata}) + + +class TestNoDotInKey: + def test_valid_data(self, testdata): + test_dict = {} + # valid data + for key in ("valid_str", 1, False, None): + test_dict[key] = testdata + no_dot_in_key(test_dict) + assert key in test_dict + assert test_dict[key] == testdata + + def test_invalid_data(self, testdata): + # dict key containing dot + with pytest.raises(InvalidKeyError): + no_dot_in_key({"a.b": testdata}) + # nested dict key containing dot + with pytest.raises(InvalidKeyError): + no_dot_in_key({"nested": {"a.b": 1}}) + # list containing dict + with pytest.raises(InvalidKeyError): + no_dot_in_key([{"a.b": 1}]) + # invalid key types + for key in (0.0, 1.0 + 2.0j, (1, 2, 3)): + with pytest.raises(KeyTypeError): + no_dot_in_key({key: testdata}) + + +class TestJSONFormatValidator: + def test_valid_data(self): + for data in ("foo", 1, 1.0, True, None, {}, []): + json_format_validator(data) + json_format_validator({"test_key": data}) + json_format_validator(("foo", 1, 1.0, True, None, {}, [])) + + def test_dict_data(self, testdata): + for data in ("foo", 1, 1.0, True, None): + json_format_validator({"test_key": data}) + for key in (0.0, (1, 2, 3)): + with pytest.raises(KeyTypeError): + json_format_validator({key: testdata}) + + @pytest.mark.skipif(not NUMPY, reason="test requires the numpy package") + def test_numpy_data(self): + data = numpy.random.rand(3, 4) + 
json_format_validator(data) + json_format_validator(numpy.float_(3.14)) + # numpy data as dict value + json_format_validator({"test": data}) + json_format_validator({"test": numpy.float_(1.0)}) + # numpy data in list + json_format_validator([data, numpy.float_(1.0), 1, "test"]) + + def test_invalid_data(self): + class A: + pass + + invalid_data = (1.0 + 2.0j, A()) + for data in invalid_data: + with pytest.raises(TypeError): + json_format_validator(data) + # invalid data as dict value + for data in invalid_data: + with pytest.raises(TypeError): + json_format_validator({"test": data}) + # invalid data in tuple + with pytest.raises(TypeError): + json_format_validator(invalid_data) + + @pytest.mark.skipif(not NUMPY, reason="test requires the numpy package") + def test_numpy_invalid_data(self): + # complex data + data = numpy.complex(1 + 2j) + with pytest.raises(TypeError): + json_format_validator(data) + # complex data in ndarray + data = numpy.asarray([1, 2, 1j, 1 + 2j]) + with pytest.raises(TypeError): + json_format_validator(data) diff --git a/tests/test_synced_collections/test_zarr_collection.py b/tests/test_synced_collections/test_zarr_collection.py new file mode 100644 index 000000000..0df9c032b --- /dev/null +++ b/tests/test_synced_collections/test_zarr_collection.py @@ -0,0 +1,54 @@ +# Copyright (c) 2020 The Regents of the University of Michigan +# All rights reserved. +# This software is licensed under the BSD 3-Clause License. +import pytest +from synced_collection_test import SyncedDictTest, SyncedListTest + +from signac.synced_collections.backends.collection_zarr import ZarrDict, ZarrList + +try: + import numcodecs # zarr depends on numcodecs + import zarr + + ZARR = True +except ImportError: + ZARR = False + + +class ZarrCollectionTest: + + _name = "test" + + def store(self, synced_collection, data): + dataset = synced_collection.group.require_dataset( + "test", + overwrite=True, + shape=1, + dtype="object", + object_codec=numcodecs.JSON(), + ) + dataset[0] = data + + @pytest.fixture + def synced_collection(self, tmpdir): + yield self._collection_type( + name=self._name, group=zarr.group(zarr.DirectoryStore(tmpdir)) + ) + + @pytest.fixture + def synced_collection_positional(self, tmpdir): + """Fixture that initializes the object using positional arguments.""" + yield self._collection_type(zarr.group(zarr.DirectoryStore(tmpdir)), self._name) + + def test_name(self, synced_collection): + assert synced_collection.name == self._name + + +@pytest.mark.skipif(not ZARR, reason="test requires the zarr package") +class TestZarrDict(ZarrCollectionTest, SyncedDictTest): + _collection_type = ZarrDict + + +@pytest.mark.skipif(not ZARR, reason="test requires the zarr package") +class TestZarrList(ZarrCollectionTest, SyncedListTest): + _collection_type = ZarrList
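The buffering tests above exercise a small public API. The following sketch is illustrative only; it uses a hypothetical temporary file rather than anything from the patch, and shows the calls the tests rely on: the per-instance ``.buffered`` context, the class-wide ``buffer_backend()`` context, and the buffer-capacity accessors.

import os
import tempfile

from signac.synced_collections.backends.collection_json import BufferedJSONDict

# Hypothetical file used only for this illustration.
with tempfile.TemporaryDirectory() as tmpdir:
    sd = BufferedJSONDict(filename=os.path.join(tmpdir, "example.json"))

    # Per-instance buffering: writes stay in memory until the context exits.
    with sd.buffered:
        sd["a"] = 1  # Not yet written to example.json.
    # The buffer is flushed here, and example.json now contains {"a": 1}.

    # Class-level buffering applies to all BufferedJSONDict instances at once.
    with BufferedJSONDict.buffer_backend():
        sd["b"] = 2

    # The buffer capacity can be inspected and changed; exceeding it inside a
    # buffered context forces an intermediate flush, as the tests verify.
    capacity = BufferedJSONDict.get_buffer_capacity()
    BufferedJSONDict.set_buffer_capacity(capacity)

The ``MemoryBuffered*`` variants expose the same interface but, as the corresponding tests check, measure buffer usage by the number of modified files rather than by the size of the serialized data.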
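The MongoDB, Redis, and Zarr test fixtures construct their collections directly from a client object plus an identifier. A minimal sketch, assuming locally running MongoDB and Redis servers and a throwaway Zarr directory (all assumptions, mirroring the fixture arguments above):

import pymongo
import redis
import zarr

from signac.synced_collections.backends.collection_mongodb import MongoDBDict
from signac.synced_collections.backends.collection_redis import RedisDict
from signac.synced_collections.backends.collection_zarr import ZarrDict

# MongoDB: the data lives in a single document selected by the given uid.
mongo_client = pymongo.MongoClient(serverSelectionTimeoutMS=1000)
mongo_dict = MongoDBDict(
    uid={"MongoDBCollection::name": "example"},
    collection=mongo_client.test_db.example,
)

# Redis: the data is JSON-serialized under a single key.
redis_dict = RedisDict(key="example", client=redis.Redis())

# Zarr: the data is stored as a named object-codec dataset inside a group.
zarr_dict = ZarrDict(name="example", group=zarr.group(zarr.DirectoryStore("example_zarr")))

Each of the test modules skips itself when the corresponding package is missing or the server cannot be reached, so the snippet above only makes sense in an environment where those optional dependencies are available.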
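Finally, the ``test_utils.py`` module above relies on the fact that a synced collection is not a plain dict or list and therefore needs ``SyncedCollectionJSONEncoder`` to be serialized with ``json.dumps``. A short sketch of that behavior, with a hypothetical filename:

import json

from signac.synced_collections.backends.collection_json import JSONDict
from signac.synced_collections.utils import SyncedCollectionJSONEncoder

sd = JSONDict("encoder_example.json")
sd.update({"foo": 1, "bar": 2})

# json.dumps(sd) raises TypeError because a JSONDict is not a plain dict.
# The custom encoder serializes the collection's underlying data instead.
print(json.dumps(sd, cls=SyncedCollectionJSONEncoder))  # {"foo": 1, "bar": 2}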