Skip to content

Commit

Permalink
Add function to create concatenated end state trajectories
Browse files Browse the repository at this point in the history
  • Loading branch information
ajfriedman22 committed Oct 9, 2024
1 parent 01c4de3 commit d28ded5
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 0 deletions.
87 changes: 87 additions & 0 deletions ensemble_md/analysis/analyze_traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -1330,3 +1330,90 @@ def get_delta_w_updates(log_file, plot=False):
plt.savefig('delta_w_updates.png', dpi=600)

return t_updates, delta_w_updates, equil

def end_states_only_traj(working_dir, n_sim, n_iter, l0_states, l1_states, swap_rep_pattern, ps_per_frame):
import pandas as pd
import os
import mdtraj as md

#Determine how many end states are present, which simulations and lambdas those end states correspond to
state_name = ['A']
considered_swaps = [[0,0]]
cat = ord('A') + 1
for swap in swap_rep_pattern:
part_1, part_2 = swap
if part_1 in considered_swaps and part_2 in considered_swaps:
continue
elif part_1 in considered_swaps:
index = considered_swaps.index(part_1)
state_name.append(state_name[index])
considered_swaps.append(part_2)
elif part_2 in considered_swaps:
index = considered_swaps.index(part_2)
state_name.append(state_name[index])
considered_swaps.append(part_1)
else:
state_name.append(chr(cat))
state_name.append(chr(cat))
considered_swaps.append(part_1)
considered_swaps.append(part_2)
cat += 1
for i in range(n_sim):
for j in [0, 1]:
if [i, j] not in considered_swaps:
state_name.append(chr(cat))
considered_swaps.append([i, j])
cat += 1

#Determine which frames correspond to which end states
state_frame_df = pd.DataFrame()
for n in range(n_sim):
for i in range(n_iter):
l0_frame, l1_frame = [],[]
dhdl_file = open(f'{working_dir}/sim_{n}/iteration_{i}/dhdl.xvg', 'r').readlines()
start = True
for line in dhdl_file:
split_line = line.split(' ')
while '' in split_line:
split_line.remove('')
if '#' not in split_line[0] and '@' not in split_line[0]:
time = float(split_line[0])
if start:
start_time = time
start = False
state = float(split_line[1])
if time%ps_per_frame == 0:
if state in l0_states:
l0_frame.append(int((time-start_time)/ps_per_frame))
elif state in l1_states:
l1_frame.append(int((time-start_time)/ps_per_frame))
if len(l0_frame) != 0:
df_0 = pd.DataFrame({'Sim': n, 'Iteration': i, 'Frame': l0_frame, 'Lambda': 0})
state_frame_df = pd.concat([state_frame_df, df_0])
if len(l1_frame) != 0:
df_1 = pd.DataFrame({'Sim': n, 'Iteration': i, 'Frame': l1_frame, 'Lambda': 1})
state_frame_df = pd.concat([state_frame_df, df_1])

#Concatenate all frames from each set of trajectories for each end state
unique_states = list(set(state_name))
for state in unique_states:
indices = [i for i, value in enumerate(state_name) if value == state]
for i, index in enumerate(indices):
rep, l = considered_swaps[index]
started = False
if os.path.exists(f'{working_dir}/sim_{rep}/iteration_0/confout_backup.gro'):
name = 'confout_backup'
else:
name = 'confout'
for iteration in range(n_iter):
frames_select = state_frame_df[(state_frame_df['Sim'] == rep) & (state_frame_df['Iteration'] == iteration) & (state_frame_df['Lambda'] == l)]['Frame'].to_numpy()
if len(frames_select) != 0:
if not started:
traj = md.load(f'{working_dir}/sim_{rep}/iteration_{iteration}/traj.trr', top=f'{working_dir}/sim_{rep}/iteration_0/{name}.gro')
started = True
else:
traj_add = md.load(f'{working_dir}/sim_{rep}/iteration_{iteration}/traj.trr', top=f'{working_dir}/sim_{rep}/iteration_0/{name}.gro')
traj = md.join(traj, traj_add)
traj.save_xtc(f'{working_dir}/analysis/{state}_{rep}.xtc')


11 changes: 11 additions & 0 deletions ensemble_md/cli/analyze_REXEE.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
warnings.simplefilter(action='ignore', category=UserWarning)

from ensemble_md.utils import utils # noqa: E402
from ensemble_md.utils import gmx_parser # noqa: E402
from ensemble_md.analysis import analyze_traj # noqa: E402
from ensemble_md.analysis import analyze_matrix # noqa: E402
from ensemble_md.analysis import msm_analysis # noqa: E402
Expand Down Expand Up @@ -120,6 +121,7 @@ def main():
print('\nData analysis of the simulation ensemble')
print('========================================')


# Section 1. Analysis based on transitions between state sets
print('[ Section 1. Analysis based on transitions between state sets/replicas ]')
section_idx += 1
Expand All @@ -128,6 +130,15 @@ def main():
print('1-0. Reading in the replica-space trajectory ...')
rep_trajs = np.load(args.rep_trajs) # Shape: (n_sim, n_iter)

# ***** Testing Section *******
if REXEE.modify_coords is not None:
l0, l1, ps_per_frame = gmx_parser.get_end_states(f'{REXEE.working_dir}/sim_0/iteration_0/expanded.mdp')
n_sim, n_iter = np.shape(rep_trajs)
if REXEE.swap_rep_pattern is None:
raise Exception('MT-REXEE trajectory analysis requires swap_rep_pattern to be defined')
analyze_traj.end_states_only_traj(REXEE.working_dir, n_sim, n_iter, l0, l1, REXEE.swap_rep_pattern, ps_per_frame)
exit()

# 1-1. Plot the replica-sapce trajectory
print('1-1. Plotting transitions between state sets/replicas ...')
dt_swap = REXEE.nst_sim * REXEE.dt # dt for swapping replicas
Expand Down
19 changes: 19 additions & 0 deletions ensemble_md/utils/gmx_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,3 +408,22 @@ def read_top(file_name, resname):
if len(line_sep) > 4 and line_sep[3] == resname:
return input_file
raise Exception(f'Residue {resname} can not be found in {file_name}')

def get_end_states(mdp_path):
mdp = MDP(mdp_path)
end_0, end_1 = [], []
coul_lambda = mdp['coul_lambdas']
vdw_lambda = mdp['vdw_lambdas']
n = 0
for vdw, coul in zip(coul_lambda, vdw_lambda):
if vdw == 0.0 and coul == 0.0:
end_0.append(n)
elif vdw == 1.0 and coul == 1.0:
end_1.append(n)
n += 1
dt = mdp['dt']
steps_per_frame = mdp['nstxout']
ps_per_frame = dt*steps_per_frame

return end_0, end_1, ps_per_frame

0 comments on commit d28ded5

Please sign in to comment.