Skip to content

Commit

Permalink
More streamlined context caches usage (#547)
Browse files Browse the repository at this point in the history
* Removing context_cache kwargs. Users have to use instance attributes for context caches.
* Adding context cache attributes and examples in docstring.
  • Loading branch information
ijpulidos authored Feb 14, 2022
1 parent f1dc942 commit 7e6926f
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 80 deletions.
7 changes: 6 additions & 1 deletion docs/releasehistory.rst
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
Release History
***************

0.2Y.Z - XXXXXX release
0.21.1 - Bugfix release
=======================

Bugfixes
--------
- More streamlined context cache usage using instance attributes (`#547 <https://github.com/choderalab/openmmtools/pull/547>`_).
- Improved docstring and examples for ``MultiStateSampler`` object.

0.21.0 - Bugfix release
=======================

Expand Down
145 changes: 67 additions & 78 deletions openmmtools/multistate/multistatesampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,15 +126,56 @@ class MultiStateSampler(object):
sampler_states
metadata
is_completed
:param number_of_iterations: Maximum number of integer iterations that will be run
:param online_analysis_interval: How frequently to carry out online analysis in number of iterations
:param online_analysis_target_error: Target free energy difference error float at which simulation will be stopped during online analysis, in dimensionless energy
:param online_analysis_minimum_iterations: Minimum number of iterations needed before online analysis is run as int
energy_context_cache : openmmtools.cache.ContextCache, default=openmmtools.cache.global_context_cache
Context cache to be used for energy computations. Defaults to using global context cache.
sampler_context_cache : openmmtools.cache.ContextCache, default=openmmtools.cache.global_context_cache
Context cache to be used for propagation. Defaults to using global context cache.
Examples
--------
Sampling multiple states of an alanine dipeptide in implicit solvent system.
>>> import math
>>> import tempfile
>>> from openmm import unit
>>> from openmmtools import testsystems, states, mcmc
>>> from openmmtools.multistate import MultiStateSampler, MultiStateReporter
>>> testsystem = testsystems.AlanineDipeptideImplicit()
Create thermodynamic states
>>> n_replicas = 3
>>> T_min = 298.0 * unit.kelvin # Minimum temperature.
>>> T_max = 600.0 * unit.kelvin # Maximum temperature.
>>> temperatures = [T_min + (T_max - T_min) * (math.exp(float(i) / float(n_replicas-1)) - 1.0) / (math.e - 1.0)
... for i in range(n_replicas)]
>>> temperatures = [T_min + (T_max - T_min) * (math.exp(float(i) / float(n_replicas-1)) - 1.0) / (math.e - 1.0)
... for i in range(n_replicas)]
>>> thermodynamic_states = [states.ThermodynamicState(system=testsystem.system, temperature=T)
... for T in temperatures]
Initialize simulation object with options. Run with a GHMC integrator.
>>> move = mcmc.GHMCMove(timestep=2.0*unit.femtoseconds, n_steps=50)
>>> simulation = MultiStateSampler(mcmc_moves=move, number_of_iterations=2)
Create simulation and store output in temporary file
>>> storage_path = tempfile.NamedTemporaryFile(delete=False).name + '.nc'
>>> reporter = MultiStateReporter(storage_path, checkpoint_interval=1)
>>> simulation.create(thermodynamic_states=thermodynamic_states,
... sampler_states=states.SamplerState(testsystem.positions), storage=reporter)
Optionally, specify unlimited context cache attributes using the fastest mixed precision platform
>>> from openmmtools.cache import ContextCache
>>> from openmmtools.utils import get_fastest_platform
>>> platform = get_fastest_platform(minimum_precision='mixed')
>>> simulation.energy_context_cache = ContextCache(capacity=None, time_to_live=None, platform=platform)
>>> simulation.sampler_context_cache = ContextCache(capacity=None, time_to_live=None, platform=platform)
Run the simulation
>>> simulation.run()
"""

# -------------------------------------------------------------------------
Expand Down Expand Up @@ -194,8 +235,8 @@ def __init__(self, mcmc_moves=None, number_of_iterations=1,

self._have_displayed_citations_before = False

# Initializing context cache attributes to global cache
self.energy_context_cache, self.sampler_context_cache = cache.global_context_cache, cache.global_context_cache
# Initializing context cache attributes
self._initialize_context_caches()

# Check convergence.
if self.number_of_iterations == np.inf:
Expand All @@ -207,7 +248,7 @@ def __init__(self, mcmc_moves=None, number_of_iterations=1,
"specified maximum number of iterations!")

@classmethod
def from_storage(cls, storage, energy_context_cache=None, propagation_context_cache=None):
def from_storage(cls, storage):
"""Constructor from an existing storage file.
Parameters
Expand All @@ -217,12 +258,6 @@ def from_storage(cls, storage, energy_context_cache=None, propagation_context_ca
If :class:`Reporter`: uses the :class:`Reporter` options
In the future this will be able to take a Storage class as well.
energy_context_cache : openmmtools.cache.ContextCache or None, optional, default None
Context cache to be used for energy computations. If None, a new fresh cache will be used.
propagation_context_cache : openmmtools.cache.ContextCache or None, optional, default None
Context cache to be used for move/integrator propagation. If None, a new fresh cache will be used.
Returns
-------
sampler : MultiStateSampler
Expand All @@ -237,17 +272,15 @@ def from_storage(cls, storage, energy_context_cache=None, propagation_context_ca
# Open the reporter to read the data.
reporter.open(mode='r')
sampler = cls._instantiate_sampler_from_reporter(reporter)
sampler._restore_sampler_from_reporter(reporter,
energy_context_cache=energy_context_cache,
propagation_context_cache=propagation_context_cache)
sampler._restore_sampler_from_reporter(reporter)
finally:
# Close reporter in reading mode.
reporter.close()

# We open the reporter only in node 0 in append mode ready for use
sampler._reporter = reporter
mpiplus.run_single_node(0, sampler._reporter.open, mode='a',
broadcast_result=False, sync_nodes=False)
broadcast_result=False, sync_nodes=False)
# Don't write the new last iteration, we have not technically
# written anything yet, so there is no "junk".
return sampler
Expand Down Expand Up @@ -491,7 +524,7 @@ def is_completed(self):

def create(self, thermodynamic_states: list, sampler_states, storage,
initial_thermodynamic_states=None, unsampled_thermodynamic_states=None,
metadata=None, energy_context_cache=None, sampler_context_cache=None):
metadata=None):
"""Create new multistate sampler simulation.
Parameters
Expand Down Expand Up @@ -532,10 +565,6 @@ def create(self, thermodynamic_states: list, sampler_states, storage,
is None).
metadata : dict, optional, default=None
Simulation metadata to be stored in the file.
energy_context_cache : openmmtools.cache.ContextCache or None, optional, default None
Context cache to be used for energy computations. If None, global context cache will be used.
sampler_context_cache : openmmtools.cache.ContextCache or None, optional, default None
Context cache to be used for move/integrator propagation. If None, global context cache will be used.
"""
# Handle case in which storage is a string and not a Reporter object.
self._reporter = self._reporter_from_storage(storage, check_exist=False)
Expand All @@ -556,9 +585,7 @@ def create(self, thermodynamic_states: list, sampler_states, storage,
self._pre_write_create(thermodynamic_states, sampler_states, storage,
initial_thermodynamic_states=initial_thermodynamic_states,
unsampled_thermodynamic_states=unsampled_thermodynamic_states,
metadata=metadata,
energy_context_cache=energy_context_cache,
sampler_context_cache=sampler_context_cache)
metadata=metadata)

# Display papers to be cited.
self._display_citations()
Expand Down Expand Up @@ -769,9 +796,7 @@ def _pre_write_create(self,
storage,
initial_thermodynamic_states=None,
unsampled_thermodynamic_states=None,
metadata=None,
energy_context_cache=None,
sampler_context_cache=None):
metadata=None,):
"""
Internal function which allocates and sets up ALL variables prior to actually using them.
This is helpful to ensure subclasses have all variables created prior to writing them out with
Expand Down Expand Up @@ -808,13 +833,6 @@ def _pre_write_create(self,
metadata['title'] = default_title
self._metadata = metadata

# Handling context cache parameters and attributes
# update context caches attributes handling inputs
self.energy_context_cache, self.sampler_context_cache = self._initialize_context_caches(
energy_context_cache,
sampler_context_cache
)

# Save thermodynamic states. This sets n_replicas.
self._thermodynamic_states = copy.deepcopy(thermodynamic_states)

Expand Down Expand Up @@ -892,7 +910,7 @@ def _instantiate_sampler_from_reporter(cls, reporter):
sampler._display_citations()
return sampler

def _restore_sampler_from_reporter(self, reporter, energy_context_cache=None, propagation_context_cache=None):
def _restore_sampler_from_reporter(self, reporter):
"""
(Re-)initialize the instanced sampler from the reporter. Intended to be called as the second half of a
:func:`from_storage` method after the :class:`MultiStateSampler` has been instanced from disk.
Expand Down Expand Up @@ -978,12 +996,8 @@ def _read_options(check_iteration):
self._last_mbar_f_k = last_mbar_f_k
self._last_err_free_energy = last_err_free_energy

# Handle with context caches as specified
# update context caches attributes handling inputs
self.energy_context_cache, self.sampler_context_cache = self._initialize_context_caches(
energy_context_cache,
propagation_context_cache
)
# Initialize context caches
self._initialize_context_caches()

def _check_nan_energy(self):
"""Checks that energies are finite and abort otherwise.
Expand Down Expand Up @@ -1652,41 +1666,16 @@ def __init__(self, error_message):

raise RestorationError(message)

@staticmethod
def _initialize_context_caches(energy_context_cache=None, propagation_context_cache=None):
def _initialize_context_caches(self):
"""Handle energy and propagation context cache default behavior.
.. note:: As of 03-Feb-22 default behavior is to use the global cache.
Parameters
----------
energy_context_cache : openmmtools.cache.ContextCache or None
Context cache to be used in energy computations. If None,
it will use the global context cache.
propagation_context_cache : openmmtools.cache.ContextCache or None
Context cache to be used in the propagation of the mcmc moves. If None,
it will use the global context cache.
Centralized API point where to initialize the context cache instance attributes.
Returns
-------
energy_context_cache : openmmtools.cache.ContextCache
Context cache to be used in energy computations.
propagation_context_cache : openmmtools.cache.ContextCache
Context cache to be used in the propagation of the mcmc moves.
.. note:: As of 03-Feb-22 default behavior is to use the global cache.
"""
# Handling energy context cache
if energy_context_cache is None:
# Default behavior, global context cache
energy_context_cache = cache.global_context_cache
elif not isinstance(energy_context_cache, cache.ContextCache):
raise ValueError("Energy context cache input is not a valid ContextCache or None type.")
# Handling propagation context cache
if propagation_context_cache is None:
# Default behavior, global context cache
propagation_context_cache = cache.global_context_cache
elif not isinstance(propagation_context_cache, cache.ContextCache):
raise ValueError("MCMC move context cache input is not a valid ContextCache or None type.")
return energy_context_cache, propagation_context_cache
# Default is using global context cache
self.energy_context_cache = cache.global_context_cache
self.sampler_context_cache = cache.global_context_cache

# -------------------------------------------------------------------------
# Internal-usage: Test globals
Expand Down
2 changes: 1 addition & 1 deletion openmmtools/multistate/paralleltempering.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def _compute_replica_energies(self, replica_id):
reference_thermodynamic_state = self._thermodynamic_states[0]

# Get the context, any Integrator works.
context, integrator = cache.global_context_cache.get_context(reference_thermodynamic_state)
context, integrator = self.energy_context_cache.get_context(reference_thermodynamic_state)

# Update positions and box vectors.
sampler_state.apply_to_context(context)
Expand Down
43 changes: 43 additions & 0 deletions openmmtools/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import mpiplus

import openmmtools as mmtools
from openmmtools import cache
from openmmtools import testsystems
from openmmtools.multistate import MultiStateReporter
from openmmtools.multistate import MultiStateSampler, MultiStateSamplerAnalyzer
Expand Down Expand Up @@ -1442,6 +1443,48 @@ def test_online_analysis_stops(self):
assert sampler._iteration < n_iterations
assert sampler.is_completed

def test_context_cache_default(self):
"""Test default behavior of context cache attributes."""
sampler = self.SAMPLER()
global_context_cache = cache.global_context_cache
# Default is to use global context cache for both context cache attributes
assert sampler.sampler_context_cache is global_context_cache
assert sampler.energy_context_cache is global_context_cache

def test_context_cache_energy_propagation(self):
"""Test specifying different context caches for energy and propagation in a short simulation."""
thermodynamic_states, sampler_states, unsampled_states = copy.deepcopy(self.alanine_test)
n_replicas = len(sampler_states)
if n_replicas == 1:
# This test is intended for use with more than one replica
return

with self.temporary_storage_path() as storage_path:
# Create a replica exchange that propagates only 1 femtosecond
# per iteration so that positions won't change much.
move = mmtools.mcmc.IntegratorMove(openmm.VerletIntegrator(1.0 * unit.femtosecond), n_steps=1)
sampler = self.SAMPLER(mcmc_moves=move)
reporter = self.REPORTER(storage_path)
self.call_sampler_create(sampler, reporter,
thermodynamic_states, sampler_states,
unsampled_states)
# Set context cache attributes
sampler.energy_context_cache = cache.ContextCache(capacity=None, time_to_live=None)
sampler.sampler_context_cache = cache.ContextCache(capacity=None, time_to_live=None)
# Compute energies
sampler._compute_energies()
# Check only energy context cache has been accessed
assert sampler.energy_context_cache._lru._n_access > 0, \
f"Expected more than 0 accesses, received {sampler.energy_context_cache._lru._n_access }."
assert sampler.sampler_context_cache._lru._n_access == 0, \
f"{sampler.sampler_context_cache._lru._n_access} accesses, expected 0."

# Propagate replicas
sampler._propagate_replicas()
# Check propagation context cache has been accessed after propagation
assert sampler.sampler_context_cache._lru._n_access > 0, \
f"Expected more than 0 accesses, received {sampler.energy_context_cache._lru._n_access }."


#############

Expand Down

0 comments on commit 7e6926f

Please sign in to comment.