Skip to content

Commit

Permalink
Merge pull request #25 from wehs7661/hist_correction
Browse files Browse the repository at this point in the history
Implement the method for histogram correction
  • Loading branch information
wehs7661 authored Oct 24, 2023
2 parents a3023ff + 9ee86f9 commit 9cb910f
Show file tree
Hide file tree
Showing 4 changed files with 179 additions and 84 deletions.
107 changes: 70 additions & 37 deletions ensemble_md/analysis/analyze_traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,15 +965,76 @@ def plot_swaps(swaps, swap_type='', stack=True, figsize=None):
plt.savefig(f'{swap_type}_swaps.png', dpi=600)


def get_dg_evolution(log_file, start_state, end_state):
def get_g_evolution(log_files, N_states, avg_frac=0):
    """
    For weight-updating simulations, gets the time series of the alchemical
    weights of all states.

    Parameters
    ----------
    log_files : list
        The list of log file names. The files are parsed in the given order
        and the frames found in them are concatenated.
    N_states : int
        The total number of states in the whole alchemical range.
    avg_frac : float
        The fraction of the last part of the simulation to be averaged. The
        default is 0, which means no averaging.

    Returns
    -------
    g_vecs_all : list
        The alchemical weights of all states as a function of time.
        It should be a list of lists.
    g_vecs_avg : numpy.ndarray or None
        The alchemical weights of all states averaged over the last part of
        the simulation. :code:`None` is returned if :code:`avg_frac` is 0 or
        if no weights were found in the log files.
    """
    g_vecs_all = []
    for log_file in log_files:
        # A context manager guarantees the file is closed even if reading fails.
        with open(log_file, "r") as f:
            lines = f.readlines()

        for n, line in enumerate(lines):
            # NOTE(review): the marker text is reproduced exactly as it appears
            # in this listing; confirm its spacing against real log headers.
            if "Count G(in kT)" in line:  # this line is lines[n]
                w = []  # the list of weights at this time frame
                for i in range(1, N_states + 1):
                    fields = lines[n + i].split()
                    # A row carrying the "<<" marker has one extra trailing
                    # token, so the weight sits one column further from the end.
                    if "<<" in lines[n + i]:
                        w.append(float(fields[-3]))
                    else:
                        w.append(float(fields[-2]))
                g_vecs_all.append(w)

            if "Weights have equilibrated" in line:
                # The final weights are read from the text after the last ':'
                # two lines above the message. Nothing after this point is
                # parsed for this file.
                w = [float(i) for i in lines[n - 2].split(':')[-1].split()]
                g_vecs_all.append(w)
                break

    if avg_frac != 0 and g_vecs_all:
        # Average over at least one frame: with n_avg == 0 the slice [-0:]
        # would select the whole trajectory rather than its tail.
        n_avg = max(1, int(avg_frac * len(g_vecs_all)))
        g_vecs_avg = np.mean(g_vecs_all[-n_avg:], axis=0)
    else:
        g_vecs_avg = None

    return g_vecs_all, g_vecs_avg


def get_dg_evolution(log_files, start_state, end_state):
"""
For weight-updating simulations, gets the time series of the weight
difference (:math:`Δg = g_2-g_1`) between the specified states.
Parameters
----------
log_file : str
The log file name.
log_files : list
The list of log file names.
start_state : int
The index of the state (starting from 0) whose weight is :math:`g_1`.
end_state : int
Expand All @@ -984,45 +1045,22 @@ def get_dg_evolution(log_file, start_state, end_state):
dg : list
A list of :math:`Δg` values.
"""
f = open(log_file, "r")
lines = f.readlines()
f.close()

n = -1
find_equil = False
dg = []
N_states = end_state - start_state + 1 # number of states for the range of insterest
for line in lines:
n += 1
if "Count G(in kT)" in line: # this line is lines[n]
w = [] # the list of weights at this time frame
for i in range(1, N_states + 1):
if "<<" in lines[n + i]:
w.append(float(lines[n + i].split()[-3]))
else:
w.append(float(lines[n + i].split()[-2]))

if find_equil is False:
dg.append(w[end_state] - w[start_state])

if "Weights have equilibrated" in line:
find_equil = True
w = [float(i) for i in lines[n - 2].split(':')[-1].split()]
dg.append(w[end_state] - w[start_state])
break
g_vecs = get_g_evolution(log_files, N_states)
dg = [g_vecs[i][end_state] - g_vecs[i][start_state] for i in range(len(g_vecs))]

return dg


def plot_dg_evolution(log_file, start_state, end_state, start_idx=0, end_idx=-1, dt_log=2):
def plot_dg_evolution(log_files, start_state, end_state, start_idx=0, end_idx=-1, dt_log=2):
"""
For weight-updating simulations, plots the time series of the weight
difference (:math:`Δg = g_2-g_1`) between the specified states.
Parameters
----------
log_file : str or list
The log file name or a list of log file names.
log_files : list
The list of log file names.
start_state : int
The index of the state (starting from 0) whose weight is :math:`g_1`.
end_state : int
Expand All @@ -1035,12 +1073,7 @@ def plot_dg_evolution(log_file, start_state, end_state, start_idx=0, end_idx=-1,
The time interval between two consecutive frames in the log file. The
default is 2 ps.
"""
if isinstance(log_file, str):
dg = get_dg_evolution(log_file, start_state, end_state)
else:
dg = []
for f in log_file:
dg += get_dg_evolution(f, start_state, end_state)
dg = get_dg_evolution(log_files, start_state, end_state)

# Now we plot
dg = dg[start_idx:end_idx]
Expand Down
37 changes: 19 additions & 18 deletions ensemble_md/cli/run_EEXE.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,71 +146,72 @@ def main():
dhdl_files = [f'sim_{j}/iteration_{i - 1}/dhdl.xvg' for j in range(EEXE.n_sim)]
log_files = [f'sim_{j}/iteration_{i - 1}/md.log' for j in range(EEXE.n_sim)]
states_ = EEXE.extract_final_dhdl_info(dhdl_files)
wl_delta, weights_, counts = EEXE.extract_final_log_info(log_files)
wl_delta, weights_, counts_ = EEXE.extract_final_log_info(log_files)
print()

# 3-2. Identify swappable pairs, propose swap(s), calculate P_acc, and accept/reject swap(s)
# Note that after `get_swapping_pattern`, `states_` and `weights_` won't necessarily be the same
# since they are updated by `get_swapping_pattern`. (Even if the function does not explicitly
# return `states_` and `weights_`, `states_` and `weights_` can still be different after
# the use of the function.) Therefore, here we create copies for `states_` and `weights_`
# before the use of `get_swapping_pattern`, so we can use them in `histogram_correction`,
# before the use of `get_swapping_pattern`, so we can use them in `weight_correction`,
# `combine_weights` and `update_MDP`.
states = copy.deepcopy(states_)
weights = copy.deepcopy(weights_)
counts = copy.deepcopy(counts_)
swap_pattern, swap_list = EEXE.get_swapping_pattern(dhdl_files, states_, weights_) # swap_list will only be used for modify_coords # noqa: E501

# 3-3. Perform histogram correction/weight combination
# 3-3. Perform weight correction/weight combination
if wl_delta != [None for i in range(EEXE.n_sim)]: # weight-updating
print(f'\nCurrent Wang-Landau incrementors: {wl_delta}\n')

# (1) First we prepare the weights to be combined.
# Note that although averaged weights are sometimes used for histogram correction/weight combination,
# Note that although averaged weights are sometimes used for weight correction/weight combination,
# the final weights are always used for calculating the acceptance ratio.
if EEXE.N_cutoff != -1 or EEXE.w_combine is not None:
# Only when histogram correction/weight combination is needed.
# Only when weight correction/weight combination is needed.
weights_avg, weights_err = EEXE.get_averaged_weights(log_files)
weights_input = EEXE.prepare_weights(weights_avg, weights) # weights_input is for weight combination # noqa: E501

# (2) Now we perform histogram correction/weight combination.
# (2) Now we perform weight correction/weight combination.
# The product of this step should always be named as "weights" to be used in update_MDP
if EEXE.N_cutoff != -1 and EEXE.w_combine is not None:
# perform both
if weights_input is None:
# Then only histogram correction will be performed
# Then only weight correction will be performed
print('Note: Weight combination is deactivated because the weights are too noisy.')
weights = EEXE.histogram_correction(weights, counts)
_ = EEXE.combine_weights(weights, print_weights=False)[1] # just to print the combiend weights
weights = EEXE.weight_correction(weights, counts)
_ = EEXE.combine_weights(counts_, weights, print_values=False)[1] # just to print the combiend weights # noqa: E501
else:
weights_preprocessed = EEXE.histogram_correction(weights_input, counts)
weights_preprocessed = EEXE.weight_correction(weights_input, counts)
if EEXE.verbose is True:
print('Performing weight combination ...')
else:
print('Performing weight combination ...', end='')
weights, g_vec = EEXE.combine_weights(weights_preprocessed) # inverse-variance weighting seems worse # noqa: E501
counts, weights, g_vec = EEXE.combine_weights(counts_, weights_preprocessed) # inverse-variance weighting seems worse # noqa: E501
EEXE.g_vecs.append(g_vec)
elif EEXE.N_cutoff == -1 and EEXE.w_combine is not None:
# only perform weight combination
print('Note: No histogram correction will be performed.')
print('Note: No weight correction will be performed.')
if weights_input is None:
print('Note: Weight combination is deactivated because the weights are too noisy.')
_ = EEXE.combine_weights(weights, print_weights=False)[1] # just to print the combined weights
_ = EEXE.combine_weights(counts_, weights, print_values=False)[1] # just to print the combined weights # noqa: E501
else:
if EEXE.verbose is True:
print('Performing weight combination ...')
else:
print('Performing weight combination ...', end='')
weights, g_vec = EEXE.combine_weights(weights_input) # inverse-variance weighting seems worse
counts, weights, g_vec = EEXE.combine_weights(counts_, weights_input) # inverse-variance weighting seems worse # noqa: E501
EEXE.g_vecs.append(g_vec)
elif EEXE.N_cutoff != -1 and EEXE.w_combine is None:
# only perform histogram correction
# only perform weight correction
print('Note: No weight combination will be performed.')
weights = EEXE.histogram_correction(weights_input, counts)
_ = EEXE.combine_weights(weights, print_weights=False)[1] # just to print the combined weights
_ = EEXE.combine_weights(counts_, weights, print_values=False)[1] # just to print the combined weights # noqa: E501
else:
print('Note: No histogram correction will be performed.')
print('Note: No weight correction will be performed.')
print('Note: No weight combination will be performed.')
_ = EEXE.combine_weights(weights, print_weights=False)[1] # just to print the combiend weights
_ = EEXE.combine_weights(counts_, weights, print_values=False)[1] # just to print the combiend weights # noqa: E501

# 3-5. Modify the MDP files and swap out the GRO files (if needed)
# Here we keep the lambda range set in mdp the same across different iterations in the same folder but swap out the gro file # noqa: E501
Expand Down
Loading

0 comments on commit 9cb910f

Please sign in to comment.