From 18bbbfedeba21b3bf61538ceec6364086a210dbe Mon Sep 17 00:00:00 2001 From: Wei-Tse Hsu Date: Wed, 6 Sep 2023 15:06:13 -0600 Subject: [PATCH] Modified run_EEXE.py --- ensemble_md/cli/run_EEXE.py | 15 ++++++++------- ensemble_md/tests/test_ensemble_EXE.py | 8 +++----- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/ensemble_md/cli/run_EEXE.py b/ensemble_md/cli/run_EEXE.py index 9a47a78e..cfc69c82 100644 --- a/ensemble_md/cli/run_EEXE.py +++ b/ensemble_md/cli/run_EEXE.py @@ -146,7 +146,7 @@ def main(): dhdl_files = [f'sim_{j}/iteration_{i - 1}/dhdl.xvg' for j in range(EEXE.n_sim)] log_files = [f'sim_{j}/iteration_{i - 1}/md.log' for j in range(EEXE.n_sim)] states_ = EEXE.extract_final_dhdl_info(dhdl_files) - wl_delta, weights_, counts = EEXE.extract_final_log_info(log_files) + wl_delta, weights_, counts_ = EEXE.extract_final_log_info(log_files) print() # 3-2. Identify swappable pairs, propose swap(s), calculate P_acc, and accept/reject swap(s) @@ -158,6 +158,7 @@ def main(): # `combine_weights` and `update_MDP`. states = copy.deepcopy(states_) weights = copy.deepcopy(weights_) + counts = copy.deepcopy(counts_) swap_pattern, swap_list = EEXE.get_swapping_pattern(dhdl_files, states_, weights_) # swap_list will only be used for modify_coords # noqa: E501 # 3-3. Perform weight correction/weight combination @@ -180,37 +181,37 @@ def main(): # Then only weight correction will be performed print('Note: Weight combination is deactivated because the weights are too noisy.') weights = EEXE.weight_correction(weights, counts) - _ = EEXE.combine_weights(weights, print_weights=False)[1] # just to print the combiend weights + _ = EEXE.combine_weights(counts_, weights, print_values=False)[1] # just to print the combiend weights # noqa: E501 else: weights_preprocessed = EEXE.weight_correction(weights_input, counts) if EEXE.verbose is True: print('Performing weight combination ...') else: print('Performing weight combination ...', end='') - weights, g_vec = EEXE.combine_weights(weights_preprocessed) # inverse-variance weighting seems worse # noqa: E501 + counts, weights, g_vec = EEXE.combine_weights(counts_, weights_preprocessed) # inverse-variance weighting seems worse # noqa: E501 EEXE.g_vecs.append(g_vec) elif EEXE.N_cutoff == -1 and EEXE.w_combine is not None: # only perform weight combination print('Note: No weight correction will be performed.') if weights_input is None: print('Note: Weight combination is deactivated because the weights are too noisy.') - _ = EEXE.combine_weights(weights, print_weights=False)[1] # just to print the combined weights + _ = EEXE.combine_weights(counts_, weights, print_values=False)[1] # just to print the combined weights # noqa: E501 else: if EEXE.verbose is True: print('Performing weight combination ...') else: print('Performing weight combination ...', end='') - weights, g_vec = EEXE.combine_weights(weights_input) # inverse-variance weighting seems worse + counts, weights, g_vec = EEXE.combine_weights(counts_, weights_input) # inverse-variance weighting seems worse # noqa: E501 EEXE.g_vecs.append(g_vec) elif EEXE.N_cutoff != -1 and EEXE.w_combine is None: # only perform weight correction print('Note: No weight combination will be performed.') weights = EEXE.histogram_correction(weights_input, counts) - _ = EEXE.combine_weights(weights, print_weights=False)[1] # just to print the combined weights + _ = EEXE.combine_weights(counts_, weights, print_values=False)[1] # just to print the combined weights # noqa: E501 else: print('Note: No weight correction will be performed.') print('Note: No weight combination will be performed.') - _ = EEXE.combine_weights(weights, print_weights=False)[1] # just to print the combiend weights + _ = EEXE.combine_weights(counts_, weights, print_values=False)[1] # just to print the combiend weights # noqa: E501 # 3-5. Modify the MDP files and swap out the GRO files (if needed) # Here we keep the lambda range set in mdp the same across different iterations in the same folder but swap out the gro file # noqa: E501 diff --git a/ensemble_md/tests/test_ensemble_EXE.py b/ensemble_md/tests/test_ensemble_EXE.py index 38a1bda7..a832515b 100644 --- a/ensemble_md/tests/test_ensemble_EXE.py +++ b/ensemble_md/tests/test_ensemble_EXE.py @@ -607,7 +607,7 @@ def test_weight_correction(self, params_dict): def test_combine_weights_1(self, params_dict): """ - Here we just test the combined weights, so the values of hist does not matter. + Here we just test the combined weights, so the values of hist does not matter. """ EEXE = get_EEXE_instance(params_dict) EEXE.n_tot = 6 @@ -618,7 +618,6 @@ def test_combine_weights_1(self, params_dict): weights = [[0, 2.1, 4.0, 3.7], [0, 1.7, 1.2, 2.6], [0, -0.4, 0.9, 1.9]] hist = [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]] - EEXE.w_combine = True _, w_1, g_vec_1 = EEXE.combine_weights(hist, weights) assert np.allclose(w_1, [ [0, 2.1, 3.9, 3.5], @@ -637,7 +636,7 @@ def test_combine_weights_1(self, params_dict): def test_combine_weights_2(self, params_dict): """ - Here we just test the modified histograms, so the values of weights does not matter. + Here we just test the modified histograms, so the values of weights does not matter. """ EEXE = get_EEXE_instance(params_dict) EEXE.n_tot = 6 @@ -647,7 +646,6 @@ def test_combine_weights_2(self, params_dict): EEXE.state_ranges = [[0, 1, 2, 3, 4], [1, 2, 3, 4, 5]] weights = [[0, 2.1, 4.0, 3.7, 5], [0, 1.7, 1.2, 2.6, 4]] hist = [[416, 332, 130, 71, 61], [303, 181, 123, 143, 260]] - - EEXE.w_combine = True + hist_modified, _, _ = EEXE.combine_weights(hist, weights) assert hist_modified == [[416, 332, 161, 98, 98], [332, 161, 98, 98, 178]]