Merge pull request #24 from wehs7661/refactor_w_combine
Refactoring codes for weight combination
wehs7661 authored Sep 3, 2023
2 parents 006b608 + 9a1373e commit a3023ff
Showing 5 changed files with 201 additions and 99 deletions.
27 changes: 22 additions & 5 deletions docs/simulations.rst
@@ -282,9 +282,24 @@ include parameters for data analysis here.
- :code:`acceptance`: (Optional, Default: :code:`metropolis`)
The Monte Carlo method for swapping simulations. Available options include :code:`same-state`/:code:`same_state`, :code:`metropolis`, and :code:`metropolis-eq`/:code:`metropolis_eq`.
For more details, please refer to :ref:`doc_acceptance`.
- :code:`w_combine`: (Optional, Default: :code:`False`)
Whether to combine weights across multiple replicas for an weight-updating EEXE simulations.
For more details, please refer to :ref:`doc_w_schemes`.
- :code:`w_combine`: (Optional, Default: :code:`None`)
The type of weights to be combined across multiple replicas in a weight-updating EEXE simulation. The following options are available:

- :code:`None`: No weight combination.
- :code:`final`: Combine the final weights.
- :code:`avg`: Combine the weights averaged since the last update of the Wang-Landau incrementor.

For more details about weight combination, please refer to :ref:`doc_w_schemes`.

- :code:`rmse_cutoff`: (Optional, Default: :code:`None`)
The cutoff for the root-mean-square error (RMSE) between the weights of the current iteration
and the weights averaged since the last update of the Wang-Landau incrementor.
For each replica, the RMSE between the averaged weights and the current weights will be calculated.
When :code:`rmse_cutoff` is specified, weight combination will be performed only if the maximum RMSE across all replicas
is smaller than the cutoff. Otherwise, weight combination is deactivated (even if :code:`w_combine` is specified)
because a larger RMSE indicates that the weights are noisy and should not be combined.
The default value :code:`None` is converted to infinity internally, which means that weight combination will always be performed if :code:`w_combine` is specified.
The units of the cutoff are :math:`k_B T`.
- :code:`N_cutoff`: (Optional, Default: 1000)
The histogram cutoff. -1 means that no histogram correction will be performed.
- :code:`n_ex`: (Optional, Default: 1)
@@ -352,7 +367,8 @@ include parameters for data analysis here.
-------------------------------
For convenience, here is a template of the input YAML file, with each optional parameter specified with the default and required
parameters left with a blank. Note that specifying :code:`null` is the same as leaving the parameter unspecified (i.e. :code:`None`).

Note that the default value :code:`None` for the parameter :code:`rmse_cutoff` will be converted to
infinity internally.

.. code-block:: yaml
@@ -373,7 +389,8 @@ parameters left with a blank. Note that specifying :code:`null` is the same as l
add_swappables: null
proposal: 'exhaustive'
acceptance: 'metropolis'
w_combine: False
w_combine: null
rmse_cutoff: null
N_cutoff: 1000
n_ex: 1
mdp_args: null
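
The :code:`rmse_cutoff` rule documented in this file can be illustrated with a short sketch. The helper names ``rmse`` and ``combination_allowed`` are hypothetical and are not part of ensemble_md; the sketch only assumes that each replica provides one list of averaged weights and one list of current weights, both in units of :math:`k_B T`, and that :code:`None` stands for an infinite cutoff as the documentation states.

```python
import math


def rmse(weights_a, weights_b):
    """Root-mean-square error between two equally long weight vectors."""
    squared = [(a - b) ** 2 for a, b in zip(weights_a, weights_b)]
    return math.sqrt(sum(squared) / len(squared))


def combination_allowed(weights_avg, weights_current, rmse_cutoff=None):
    """Return True if weight combination would proceed under the documented rule.

    Hypothetical helper (not ensemble_md API): weights_avg and weights_current
    each hold one weight vector per replica.
    """
    # None is treated as infinity, i.e. combination is never blocked.
    cutoff = math.inf if rmse_cutoff is None else rmse_cutoff
    # Compare the largest per-replica RMSE against the cutoff.
    max_rmse = max(rmse(a, c) for a, c in zip(weights_avg, weights_current))
    return max_rmse < cutoff


# Illustrative (made-up) weights for two replicas with three states each.
avg = [[0.0, 1.0, 2.0], [0.0, 1.5, 2.5]]
cur = [[0.0, 1.2, 2.1], [0.0, 1.4, 2.9]]
print(combination_allowed(avg, cur))                   # no cutoff specified
print(combination_allowed(avg, cur, rmse_cutoff=0.2))  # stricter than the largest RMSE
```

With these made-up numbers the largest per-replica RMSE is about 0.24, so a cutoff of 0.2 deactivates weight combination while the default leaves it active.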
68 changes: 43 additions & 25 deletions ensemble_md/cli/run_EEXE.py
@@ -19,7 +19,6 @@
from datetime import datetime

from ensemble_md.utils import utils
from ensemble_md.utils import gmx_parser
from ensemble_md.ensemble_EXE import EnsembleEXE


@@ -154,45 +153,64 @@ def main():
# Note after `get_swapping_pattern`, `states_` and `weights_` won't necessarily be the same
# since they are updated by `get_swapping_pattern`. (Even if the function does not explicitly
# return `states_` and `weights_`, `states_` and `weights_` can still be different after
# the use of the function.) Therefore, here we create copyes for `states_` and `weights_`
# the use of the function.) Therefore, here we create copies for `states_` and `weights_`
# before the use of `get_swapping_pattern`, so we can use them in `histogram_correction`,
# `combine_weights` and `update_MDP`.
states = copy.deepcopy(states_)
weights = copy.deepcopy(weights_)
swap_pattern, swap_list = EEXE.get_swapping_pattern(dhdl_files, states_, weights_) # swap_list will only be used for modify_coords # noqa: E501

# 3-3. Calculate the weights averaged since the last update of the Wang-Landau incrementor.
# Note that the averaged weights are used for histogram correction/weight combination.
# For calculating the acceptance ratio (inside get_swapping_pattern), final weights should be used.
if EEXE.N_cutoff != -1 or EEXE.w_combine is True:
# 3-3. Perform histogram correction/weight combination
if wl_delta != [None for i in range(EEXE.n_sim)]: # weight-updating
print(f'\nCurrent Wang-Landau incrementors: {wl_delta}\n')

# (1) First we prepare the weights to be combined.
# Note that although averaged weights are sometimes used for histogram correction/weight combination,
# the final weights are always used for calculating the acceptance ratio.
if EEXE.N_cutoff != -1 or EEXE.w_combine is not None:
# Only when histogram correction/weight combination is needed.
weights_avg, weights_err = EEXE.get_averaged_weights(log_files)
weights_input = EEXE.prepare_weights(weights_avg, weights) # weights_input is for weight combination # noqa: E501

# 3-4. Perform histogram correction/weight combination
# Note that we never use final weights but averaged weights here.
# (2) Now we perform histogram correction/weight combination.
# The product of this step should always be named as "weights" to be used in update_MDP
if wl_delta != [None for i in range(EEXE.n_sim)]: # weight-updating
print(f'\nCurrent Wang-Landau incrementors: {wl_delta}')
if EEXE.N_cutoff != -1 and EEXE.w_combine is True:
if EEXE.N_cutoff != -1 and EEXE.w_combine is not None:
# perform both
weights_avg = EEXE.histogram_correction(weights_avg, counts)
weights, g_vec = EEXE.combine_weights(weights_avg) # inverse-variance weighting seems worse
EEXE.g_vecs.append(g_vec)
elif EEXE.N_cutoff == -1 and EEXE.w_combine is True:
if weights_input is None:
# Then only histogram correction will be performed
print('Note: Weight combination is deactivated because the weights are too noisy.')
weights = EEXE.histogram_correction(weights, counts)
_ = EEXE.combine_weights(weights, print_weights=False)[1] # just to print the combined weights
else:
weights_preprocessed = EEXE.histogram_correction(weights_input, counts)
if EEXE.verbose is True:
print('Performing weight combination ...')
else:
print('Performing weight combination ...', end='')
weights, g_vec = EEXE.combine_weights(weights_preprocessed) # inverse-variance weighting seems worse # noqa: E501
EEXE.g_vecs.append(g_vec)
elif EEXE.N_cutoff == -1 and EEXE.w_combine is not None:
# only perform weight combination
print('\nNote: No histogram correction will be performed.')
weights, g_vec = EEXE.combine_weights(weights_avg) # inverse-variance weighting seems worse
EEXE.g_vecs.append(g_vec)
elif EEXE.N_cutoff != -1 and EEXE.w_combine is False:
print('Note: No histogram correction will be performed.')
if weights_input is None:
print('Note: Weight combination is deactivated because the weights are too noisy.')
_ = EEXE.combine_weights(weights, print_weights=False)[1] # just to print the combined weights
else:
if EEXE.verbose is True:
print('Performing weight combination ...')
else:
print('Performing weight combination ...', end='')
weights, g_vec = EEXE.combine_weights(weights_input) # inverse-variance weighting seems worse
EEXE.g_vecs.append(g_vec)
elif EEXE.N_cutoff != -1 and EEXE.w_combine is None:
# only perform histogram correction
print('\nNote: No weight combination will be performed.')
weights = EEXE.histogram_correction(weights_avg, counts)
print('Note: No weight combination will be performed.')
weights = EEXE.histogram_correction(weights_input, counts)
_ = EEXE.combine_weights(weights, print_weights=False)[1] # just to print the combined weights
else:
weights_current = [gmx_parser.parse_log(log_files[i])[0][-1] for i in range(EEXE.n_sim)]
w_for_printing = EEXE.combine_weights(weights_current)[1]
print('\nNote: No histogram correction will be performed.')
print('Note: No histogram correction will be performed.')
print('Note: No weight combination will be performed.')
print(f'The alchemical weights of all states: \n {list(np.round(w_for_printing, decimals=3))}')
_ = EEXE.combine_weights(weights, print_weights=False)[1] # just to print the combined weights

# 3-5. Modify the MDP files and swap out the GRO files (if needed)
# Here we keep the lambda range set in mdp the same across different iterations in the same folder but swap out the gro file # noqa: E501
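
The branching in step 3-3 can be summarized as two independent decisions, sketched below. The function ``plan_weight_processing`` and its ``too_noisy`` argument are hypothetical and do not exist in ensemble_md; ``too_noisy`` stands in for the case where ``weights_input`` is ``None``, and the sketch assumes the simulation is weight-updating (a Wang-Landau incrementor is defined for at least one replica).

```python
def plan_weight_processing(n_cutoff, w_combine, too_noisy):
    """Return (do_histogram_correction, do_weight_combination).

    Hypothetical summary of the refactored branches, not ensemble_md API:
    - histogram correction runs whenever N_cutoff is not -1;
    - weight combination runs only if w_combine is set and the weights
      were not judged too noisy.
    """
    do_histogram_correction = n_cutoff != -1
    do_weight_combination = w_combine is not None and not too_noisy
    return do_histogram_correction, do_weight_combination


# Enumerate the cases handled by the four branches in the diff.
for n_cutoff in (1000, -1):
    for w_combine in ("final", "avg", None):
        for too_noisy in (False, True):
            plan = plan_weight_processing(n_cutoff, w_combine, too_noisy)
            print(n_cutoff, w_combine, too_noisy, plan)
```

Writing the decision this way makes it easier to check that a noisy set of weights suppresses only the combination step and leaves histogram correction untouched.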