Skip to content

Commit

Permalink
Added new parameter rmse_cutoff and modified w_combine; Modified run_…
Browse files Browse the repository at this point in the history
…EXEE.py to use prepare_weights
  • Loading branch information
wehs7661 committed Aug 29, 2023
1 parent 45b1a39 commit 6f7e787
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 34 deletions.
22 changes: 18 additions & 4 deletions docs/simulations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -282,9 +282,22 @@ include parameters for data analysis here.
- :code:`acceptance`: (Optional, Default: :code:`metropolis`)
The Monte Carlo method for swapping simulations. Available options include :code:`same-state`/:code:`same_state`, :code:`metropolis`, and :code:`metropolis-eq`/:code:`metropolis_eq`.
For more details, please refer to :ref:`doc_acceptance`.
- :code:`w_combine`: (Optional, Default: :code:`False`)
Whether to combine weights across multiple replicas for a weight-updating EEXE simulation.
For more details, please refer to :ref:`doc_w_schemes`.
- :code:`w_combine`: (Optional, Default: :code:`None`)
The type of weights to be combined across multiple replicas in a weight-updating EEXE simulation. The following options are available:

- :code:`None`: No weight combination.
- :code:`final`: Combine the final weights.
- :code:`avg`: Combine the weights averaged since the last time the Wang-Landau incrementor was updated. Notably, the time-averaged weights tend to be very noisy at the beginning of the simulation and can drive the combined weights in a bad direction. Therefore, we recommend specifying the parameter :code:`rmse_cutoff` to only use the time-averaged weights when the weights are not changing too much. For more details, check the description of the parameter :code:`rmse_cutoff` below.

For more details about weight combination, please refer to :ref:`doc_w_schemes`.

- :code:`rmse_cutoff`: (Optional, Default: :code:`None`)
The cutoff for the root-mean-square error (RMSE) between the weights at the end of the current iteration
and the weights averaged since the last time the Wang-Landau incrementor was updated.
For each replica, the time-averaged weights will be used in weight combination only if the RMSE is smaller than the cutoff.
Otherwise, the final weights will still be used. If this parameter is not specified, then time-averaged weights will always be used, which could be problematic
since time-averaged weights tend to be very noisy at the beginning of the simulation. Note that this parameter is only meaningful when :code:`w_combine` is set to :code:`avg`.
The units of the cutoff are :math:`k_B T`.
- :code:`N_cutoff`: (Optional, Default: 1000)
The histogram cutoff. -1 means that no histogram correction will be performed.
- :code:`n_ex`: (Optional, Default: 1)
Expand Down Expand Up @@ -373,7 +386,8 @@ parameters left with a blank. Note that specifying :code:`null` is the same as l
add_swappables: null
proposal: 'exhaustive'
acceptance: 'metropolis'
w_combine: False
w_combine: null
rmse_cutoff: null
N_cutoff: 1000
n_ex: 1
mdp_args: null
Expand Down
23 changes: 11 additions & 12 deletions ensemble_md/cli/run_EEXE.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from datetime import datetime

from ensemble_md.utils import utils
from ensemble_md.utils import gmx_parser
from ensemble_md.ensemble_EXE import EnsembleEXE


Expand Down Expand Up @@ -164,35 +163,35 @@ def main():
# 3-3. Perform histogram correction/weight combination
if wl_delta != [None for i in range(EEXE.n_sim)]: # weight-updating
print(f'\nCurrent Wang-Landau incrementors: {wl_delta}')

# (1) First we prepare the weights to be combined. For each replica, weights to be combined
# could be either the final weights or the weights averaged since the last update of the
# Wang-Landau incrementor. See the function `prepare_weights` for details.
# Note that although averaged weights are sometimes used for histogram correction/weight combination,
# the final weights are always used for calculating the acceptance ratio.
if EEXE.N_cutoff != -1 or EEXE.w_combine is True:
if EEXE.N_cutoff != -1 or EEXE.w_combine is not None:
# Only when histogram correction/weight combination is needed.
weights_avg, weights_err = EEXE.get_averaged_weights(log_files)
weights_input = EEXE.prepare_weights(weights_avg, weights) # weights_input is for weight combination # noqa: E501

# (2) Now we perform histogram correction/weight combination.
# The product of this step should always be named as "weights" to be used in update_MDP
if EEXE.N_cutoff != -1 and EEXE.w_combine is True:
if EEXE.N_cutoff != -1 and EEXE.w_combine is not None:
# perform both
weights_avg = EEXE.histogram_correction(weights_avg, counts)
weights, g_vec = EEXE.combine_weights(weights_avg) # inverse-variance weighting seems worse
weights_preprocessed = EEXE.histogram_correction(weights_input, counts)
weights, g_vec = EEXE.combine_weights(weights_preprocessed) # inverse-variance weighting seems worse # noqa: E501
EEXE.g_vecs.append(g_vec)
elif EEXE.N_cutoff == -1 and EEXE.w_combine is True:
elif EEXE.N_cutoff == -1 and EEXE.w_combine is not None:
# only perform weight combination
print('\nNote: No histogram correction will be performed.')
weights, g_vec = EEXE.combine_weights(weights_avg) # inverse-variance weighting seems worse
weights, g_vec = EEXE.combine_weights(weights_input) # inverse-variance weighting seems worse
EEXE.g_vecs.append(g_vec)
elif EEXE.N_cutoff != -1 and EEXE.w_combine is False:
elif EEXE.N_cutoff != -1 and EEXE.w_combine is None:
# only perform histogram correction
print('\nNote: No weight combination will be performed.')
weights = EEXE.histogram_correction(weights_avg, counts)
weights = EEXE.histogram_correction(weights_input, counts)
else:
weights_current = [gmx_parser.parse_log(log_files[i])[0][-1] for i in range(EEXE.n_sim)]
w_for_printing = EEXE.combine_weights(weights_current)[1]
w_for_printing = EEXE.combine_weights(weights)[1]
print('\nNote: No histogram correction will be performed.')
print('Note: No weight combination will be performed.')
print(f'The alchemical weights of all states: \n {list(np.round(w_for_printing, decimals=3))}')
Expand Down
40 changes: 26 additions & 14 deletions ensemble_md/ensemble_EXE.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,8 @@ def set_params(self, analysis):
"nst_sim": None,
"proposal": 'exhaustive',
"acceptance": "metropolis",
"w_combine": False,
"w_combine": None,
"rmse_cutoff": None,
"N_cutoff": 1000,
"n_ex": 'N^3', # only active for multiple swaps.
"verbose": True,
Expand Down Expand Up @@ -195,12 +196,22 @@ def set_params(self, analysis):
if self.acceptance not in [None, 'same-state', 'same_state', 'metropolis', 'metropolis-eq', 'metropolis_eq']:
raise ParameterError("The specified acceptance scheme is not available. Available options include 'same-state', 'metropolis', and 'metropolis-eq'.") # noqa: E501

if self.w_combine not in [None, 'final', 'avg']:
raise ParameterError("The specified type of weight to be combined is not available. Available options include 'final' and 'avg'.") # noqa: E501

if self.df_method not in [None, 'TI', 'BAR', 'MBAR']:
raise ParameterError("The specified free energy estimator is not available. Available options include 'TI', 'BAR', and 'MBAR'.") # noqa: E501

if self.err_method not in [None, 'propagate', 'bootstrap']:
raise ParameterError("The specified method for error estimation is not available. Available options include 'propagate', and 'bootstrap'.") # noqa: E501

if self.w_combine == 'avg' and self.rmse_cutoff is None:
self.warnings.append('Warning: We recommend setting rmse_cutoff when w_combine is set to "avg".')

if self.rmse_cutoff is not None:
if type(self.rmse_cutoff) is not float:
raise ParameterError("The parameter 'rmse_cutoff' should be a float.")

params_int = ['n_sim', 'n_iter', 's', 'N_cutoff', 'df_spacing', 'n_ckpt', 'n_bootstrap'] # integer parameters # noqa: E501
if self.nst_sim is not None:
params_int.append('nst_sim')
Expand All @@ -215,6 +226,8 @@ def set_params(self, analysis):
params_pos = ['n_sim', 'n_iter', 'n_ckpt', 'df_spacing', 'n_bootstrap'] # positive parameters
if self.nst_sim is not None:
params_pos.append('nst_sim')
if self.rmse_cutoff is not None:
params_pos.append('rmse_cutoff')
for i in params_pos:
if getattr(self, i) <= 0:
raise ParameterError(f"The parameter '{i}' should be positive.")
Expand Down Expand Up @@ -242,7 +255,7 @@ def set_params(self, analysis):
if type(getattr(self, i)) != str:
raise ParameterError(f"The parameter '{i}' should be a string.")

params_bool = ['verbose', 'rm_cpt', 'w_combine', 'msm', 'free_energy', 'subsampling_avg']
params_bool = ['verbose', 'rm_cpt', 'msm', 'free_energy', 'subsampling_avg']
for i in params_bool:
if type(getattr(self, i)) != bool:
raise ParameterError(f"The parameter '{i}' should be a boolean variable.")
Expand Down Expand Up @@ -305,11 +318,11 @@ def set_params(self, analysis):
self.equilibrated_weights = [None for i in range(self.n_sim)]

if self.fixed_weights is True:
if self.N_cutoff != -1 or self.w_combine is not False:
if self.N_cutoff != -1 or self.w_combine is not None:
self.warnings.append('Warning: The histogram correction/weight combination method is specified but will not be used since the weights are fixed.') # noqa: E501
# In the case that the warning is ignored, enforce the defaults.
self.N_cutoff = -1
self.w_combine = False
self.w_combine = None

if 'lmc_seed' in self.template and self.template['lmc_seed'] != -1:
self.warnings.append('Warning: We recommend setting lmc_seed as -1 so the random seed is different for each iteration.') # noqa: E501
Expand Down Expand Up @@ -483,7 +496,7 @@ def print_params(self, params_analysis=False):
print(f"Verbose log file: {self.verbose}")
print(f"Proposal scheme: {self.proposal}")
print(f"Acceptance scheme for swapping simulations: {self.acceptance}")
print(f"Whether to perform weight combination: {self.w_combine}")
print(f"Type of weights to be combined: {self.w_combine}")
print(f"Histogram cutoff: {self.N_cutoff}")
print(f"Number of replicas: {self.n_sim}")
print(f"Number of iterations: {self.n_iter}")
Expand Down Expand Up @@ -810,9 +823,8 @@ def prepare_weights(self, weights_avg, weights_final):
else:
print('Note: The final weights will be used for weight combination, as the time-averaged weights still fluctuate too much.') # noqa: E501
weights_output.append(weights_final[i])

return weights_output

return weights_output

@staticmethod
def identify_swappable_pairs(states, state_ranges, neighbor_exchange, add_swappables=None):
Expand Down Expand Up @@ -1280,7 +1292,7 @@ def combine_weights(self, weights, weights_err=None):

return weights, g_vec

def run_grompp(self, n, swap_pattern):
def _run_grompp(self, n, swap_pattern):
"""
Prepares TPR files for the simulation ensemble using the GROMACS :code:`grompp` command.
Expand Down Expand Up @@ -1345,7 +1357,7 @@ def run_grompp(self, n, swap_pattern):
if code_list != [0] * self.n_sim:
MPI.COMM_WORLD.Abort(1) # Doesn't matter what non-zero returncode we put here as the code from GROMACS will be printed before this point anyway. # noqa: E501

def run_mdrun(self, n):
def _run_mdrun(self, n):
"""
Executes GROMACS mdrun commands in parallel.
Expand Down Expand Up @@ -1412,14 +1424,14 @@ def run_EEXE(self, n, swap_pattern=None):
print(iter_str + '\n' + '=' * (len(iter_str) - 1))

# 1st synchronizing point for all MPI processes: To make sure ranks other than 0 will not start executing
# run_grompp earlier and mess up the order of printing.
# _run_grompp earlier and mess up the order of printing.
comm.barrier()

# Generating all required TPR files simultaneously, then run all simulations simultaneously.
# No synchronizing point is needed between run_grompp and run_mdrun, since once rank i finishes run_grompp,
# it should run run_mdrun in the same working directory, so there won't be any I/O error.
self.run_grompp(n, swap_pattern)
self.run_mdrun(n)
# No synchronizing point is needed between _run_grompp and _run_mdrun, since once rank i finishes _run_grompp,
# it should run _run_mdrun in the same working directory, so there won't be any I/O error.
self._run_grompp(n, swap_pattern)
self._run_mdrun(n)

# 2nd synchronizing point for all MPI processes: To make sure no rank will start getting to the next
# iteration earlier than the others. For example, if rank 0 finishes the mdrun command earlier, we don't
Expand Down
4 changes: 2 additions & 2 deletions ensemble_md/tests/test_ensemble_EXE.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def test_set_params(self, params_dict):
# 2. Check the default values of the parameters not specified in params.yaml
assert EEXE.proposal == "exhaustive"
assert EEXE.acceptance == "metropolis"
assert EEXE.w_combine is False
assert EEXE.w_combine is None
assert EEXE.N_cutoff == 1000
assert EEXE.n_ex == 'N^3'
assert EEXE.verbose is True
Expand Down Expand Up @@ -270,7 +270,7 @@ def test_print_params(self, capfd, params_dict):
L += "Verbose log file: True\n"
L += "Proposal scheme: exhaustive\n"
L += "Acceptance scheme for swapping simulations: metropolis\n"
L += "Whether to perform weight combination: False\n"
L += "Type of weights to be combined: None\n"
L += "Histogram cutoff: 1000\nNumber of replicas: 4\nNumber of iterations: 10\n"
L += "Number of attempted swaps in one exchange interval: N^3\n"
L += "Length of each replica: 1.0 ps\nFrequency for checkpointing: 100 iterations\n"
Expand Down
4 changes: 2 additions & 2 deletions ensemble_md/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def format_time(t):
return t_str


def autoconvert(s):
def _autoconvert(s):
"""
Converts input to a numerical type if possible. Used for the MDP parser.
Modified from `utilities.py in GromacsWrapper <https://github.com/Becksteinlab/GromacsWrapper>`_.
Expand Down Expand Up @@ -179,7 +179,7 @@ def autoconvert(s):
raise ValueError("Failed to autoconvert {0!r}".format(s))


def get_subplot_dimension(n_panels):
def _get_subplot_dimension(n_panels):
"""
Gets the numbers of rows and columns in a subplot such that
the arrangement of the panels is as close to a square as possible.
Expand Down

0 comments on commit 6f7e787

Please sign in to comment.