Skip to content

Commit

Permalink
Modified run_EEXE.py
Browse files Browse the repository at this point in the history
  • Loading branch information
wehs7661 committed Sep 6, 2023
1 parent d040b5e commit 18bbbfe
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 12 deletions.
15 changes: 8 additions & 7 deletions ensemble_md/cli/run_EEXE.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def main():
dhdl_files = [f'sim_{j}/iteration_{i - 1}/dhdl.xvg' for j in range(EEXE.n_sim)]
log_files = [f'sim_{j}/iteration_{i - 1}/md.log' for j in range(EEXE.n_sim)]
states_ = EEXE.extract_final_dhdl_info(dhdl_files)
wl_delta, weights_, counts = EEXE.extract_final_log_info(log_files)
wl_delta, weights_, counts_ = EEXE.extract_final_log_info(log_files)
print()

# 3-2. Identify swappable pairs, propose swap(s), calculate P_acc, and accept/reject swap(s)
Expand All @@ -158,6 +158,7 @@ def main():
# `combine_weights` and `update_MDP`.
states = copy.deepcopy(states_)
weights = copy.deepcopy(weights_)
counts = copy.deepcopy(counts_)
swap_pattern, swap_list = EEXE.get_swapping_pattern(dhdl_files, states_, weights_) # swap_list will only be used for modify_coords # noqa: E501

# 3-3. Perform weight correction/weight combination
Expand All @@ -180,37 +181,37 @@ def main():
# Then only weight correction will be performed
print('Note: Weight combination is deactivated because the weights are too noisy.')
weights = EEXE.weight_correction(weights, counts)
_ = EEXE.combine_weights(weights, print_weights=False)[1] # just to print the combiend weights
_ = EEXE.combine_weights(counts_, weights, print_values=False)[1] # just to print the combiend weights # noqa: E501
else:
weights_preprocessed = EEXE.weight_correction(weights_input, counts)
if EEXE.verbose is True:
print('Performing weight combination ...')
else:
print('Performing weight combination ...', end='')
weights, g_vec = EEXE.combine_weights(weights_preprocessed) # inverse-variance weighting seems worse # noqa: E501
counts, weights, g_vec = EEXE.combine_weights(counts_, weights_preprocessed) # inverse-variance weighting seems worse # noqa: E501
EEXE.g_vecs.append(g_vec)
elif EEXE.N_cutoff == -1 and EEXE.w_combine is not None:
# only perform weight combination
print('Note: No weight correction will be performed.')
if weights_input is None:
print('Note: Weight combination is deactivated because the weights are too noisy.')
_ = EEXE.combine_weights(weights, print_weights=False)[1] # just to print the combined weights
_ = EEXE.combine_weights(counts_, weights, print_values=False)[1] # just to print the combined weights # noqa: E501
else:
if EEXE.verbose is True:
print('Performing weight combination ...')
else:
print('Performing weight combination ...', end='')
weights, g_vec = EEXE.combine_weights(weights_input) # inverse-variance weighting seems worse
counts, weights, g_vec = EEXE.combine_weights(counts_, weights_input) # inverse-variance weighting seems worse # noqa: E501
EEXE.g_vecs.append(g_vec)
elif EEXE.N_cutoff != -1 and EEXE.w_combine is None:
# only perform weight correction
print('Note: No weight combination will be performed.')
weights = EEXE.histogram_correction(weights_input, counts)
_ = EEXE.combine_weights(weights, print_weights=False)[1] # just to print the combined weights
_ = EEXE.combine_weights(counts_, weights, print_values=False)[1] # just to print the combined weights # noqa: E501
else:
print('Note: No weight correction will be performed.')
print('Note: No weight combination will be performed.')
_ = EEXE.combine_weights(weights, print_weights=False)[1] # just to print the combiend weights
_ = EEXE.combine_weights(counts_, weights, print_values=False)[1] # just to print the combiend weights # noqa: E501

# 3-5. Modify the MDP files and swap out the GRO files (if needed)
# Here we keep the lambda range set in mdp the same across different iterations in the same folder but swap out the gro file # noqa: E501
Expand Down
8 changes: 3 additions & 5 deletions ensemble_md/tests/test_ensemble_EXE.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,7 @@ def test_weight_correction(self, params_dict):

def test_combine_weights_1(self, params_dict):
"""
Here we just test the combined weights, so the values of hist does not matter.
Here we just test the combined weights, so the values of hist does not matter.
"""
EEXE = get_EEXE_instance(params_dict)
EEXE.n_tot = 6
Expand All @@ -618,7 +618,6 @@ def test_combine_weights_1(self, params_dict):
weights = [[0, 2.1, 4.0, 3.7], [0, 1.7, 1.2, 2.6], [0, -0.4, 0.9, 1.9]]
hist = [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]

EEXE.w_combine = True
_, w_1, g_vec_1 = EEXE.combine_weights(hist, weights)
assert np.allclose(w_1, [
[0, 2.1, 3.9, 3.5],
Expand All @@ -637,7 +636,7 @@ def test_combine_weights_1(self, params_dict):

def test_combine_weights_2(self, params_dict):
"""
Here we just test the modified histograms, so the values of weights does not matter.
Here we just test the modified histograms, so the values of weights does not matter.
"""
EEXE = get_EEXE_instance(params_dict)
EEXE.n_tot = 6
Expand All @@ -647,7 +646,6 @@ def test_combine_weights_2(self, params_dict):
EEXE.state_ranges = [[0, 1, 2, 3, 4], [1, 2, 3, 4, 5]]
weights = [[0, 2.1, 4.0, 3.7, 5], [0, 1.7, 1.2, 2.6, 4]]
hist = [[416, 332, 130, 71, 61], [303, 181, 123, 143, 260]]

EEXE.w_combine = True

hist_modified, _, _ = EEXE.combine_weights(hist, weights)
assert hist_modified == [[416, 332, 161, 98, 98], [332, 161, 98, 98, 178]]

0 comments on commit 18bbbfe

Please sign in to comment.