Skip to content

Commit

Permalink
Added convert_npy2xvg and moved run_gmx_cmd from EnsembleEXE to utils
Browse files Browse the repository at this point in the history
  • Loading branch information
wehs7661 committed Aug 19, 2023
1 parent 54e5264 commit c6979f8
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 42 deletions.
59 changes: 49 additions & 10 deletions ensemble_md/analysis/analyze_traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,16 @@ def extract_state_traj(dhdl):
-------
traj : list
A list that represents the state-space trajectory
t : list
A list that represents the time series of the trajectory
"""
traj = list(extract_dataframe(dhdl, headers=get_headers(dhdl))['Thermodynamic state'])
t = list(np.loadtxt(dhdl, comments=['#', '@'])[:, 0])

return traj
return traj, t


def stitch_time_series(files, rep_trajs, shifts=None, dhdl=True, col_idx=-1, save=True):
def stitch_time_series(files, rep_trajs, shifts=None, dhdl=True, col_idx=-1, save_npy=True, save_xvg=False):
"""
Stitches the state-space/CV-space trajectories for each starting configuration from DHDL files
or PLUMED output files generated at different iterations.
Expand Down Expand Up @@ -67,8 +70,11 @@ def stitch_time_series(files, rep_trajs, shifts=None, dhdl=True, col_idx=-1, sav
col_idx : int
The index of the column to be extracted from the input files. This is only needed when :code:`dhdl=False`,
By default, we extract the last column.
save : bool
save_npy : bool
Whether to save the output trajectories as an NPY file.
save_xvg : bool
Whether to save the time series for each trajectory as an XVG file. The first column is the time and the
second column is the state index.
Returns
-------
Expand All @@ -94,29 +100,62 @@ def stitch_time_series(files, rep_trajs, shifts=None, dhdl=True, col_idx=-1, sav
for j in range(n_iter):
if j == 0:
if dhdl:
traj = extract_state_traj(files_sorted[i][j])
traj, t = extract_state_traj(files_sorted[i][j])
else:
traj = np.loadtxt(files_sorted[i][j], comments=['#', '@'])[:, col_idx]
t = np.loadtxt(files_sorted[i][j], comments=['#', '@'])[:, 0] # only used if save_xvg is True
else:
# Starting from the 2nd iteration, we get rid of the first time frame, since the first
# frame of iteration n+1 is the same as the last frame of iteration n
if dhdl:
traj = extract_state_traj(files_sorted[i][j])[1:]
traj, t = extract_state_traj(files_sorted[i][j])[1:]
else:
traj = np.loadtxt(files_sorted[i][j], comments=['#', '@'])[:, col_idx][1:]
t = np.loadtxt(files_sorted[i][j], comments=['#', '@'])[:, 0][1:] # only used if save_xvg is True

if dhdl: # Trajectories of global alchemical indices will be generated.
shift_idx = rep_trajs[i][j]
traj = list(np.array(traj) + shifts[shift_idx])
trajs[i].extend(traj)

# Save the trajectories as an NPY file if desired
if save is True:
np.save('state_trajs.npy', trajs)
if save_npy is True:
if dhdl:
np.save('state_trajs.npy', trajs)
else:
np.save('cv_trajs.npy', trajs)

if save_xvg is True:
# This is probably only useful for clustering analysis
convert_npy2xvg(trajs, t[1] - t[0])

return trajs


def convert_npy2xvg(trajs, dt):
    r"""
    Convert the trajectories stored in a :code:`state_trajs.npy` or :code:`cv_trajs.npy` file
    to :math:`N_{\text{rep}}` XVG files that have two columns: time (ps) and state index (or CV value).

    One file named :code:`traj_{i}.xvg` is written to the current working directory for the
    :math:`i`-th trajectory in :code:`trajs`. Existing files with the same names are overwritten.

    Parameters
    ----------
    trajs : ndarray or list
        The state-space or CV-space trajectories read from :code:`state_trajs.npy` or :code:`cv_trajs.npy`,
        or the list of lists returned by :code:`stitch_time_series`. Each element is one trajectory.
    dt : float
        The time interval between consecutive frames of the trajectories.
    """
    for i, traj in enumerate(trajs):
        # A trajectory may arrive as a plain Python list (this is how stitch_time_series
        # assembles it), which has no ``dtype`` attribute, so coerce it to an array first.
        traj = np.asarray(traj)
        t = np.arange(len(traj)) * dt

        # Integer-valued trajectories are state indices; anything else is treated as a CV.
        if np.issubdtype(traj.dtype, np.integer):
            label, value_fmt = 'State index', '%4.0f'
        else:
            label, value_fmt = 'CV', '%8.6f'

        headers = ['This file was created by ensemble_md', f'Time (ps) v.s. {label}']
        np.savetxt(
            f'traj_{i}.xvg',
            np.transpose([t, traj]),
            header='\n'.join(headers),
            fmt=['%-8.1f', value_fmt],
        )


def stitch_time_series_for_sim(files, shifts=None, dhdl=True, col_idx=-1, save=True):
"""
Stitches the state-space/CV-space time series in the same replica/simulation folder.
Expand Down Expand Up @@ -157,7 +196,7 @@ def stitch_time_series_for_sim(files, shifts=None, dhdl=True, col_idx=-1, save=T
for i in range(n_sim):
for j in range(n_iter):
if dhdl:
traj = extract_state_traj(files[i][j])
traj, _ = extract_state_traj(files[i][j])
else:
traj = np.loadtxt(files[i][j], comments=['#', '@'])[:, col_idx]

Expand Down Expand Up @@ -204,7 +243,7 @@ def stitch_trajs(gmx_executable, files, rep_trajs):

# Then, stitch the trajectories for each starting configuration
for i in range(n_sim):
print(f'Concatenating trajectories initiated by configuration {i} ...')
print(f'Recovering the continuous trajectory {i} by concatenating the XTC files ...')
arguments = [gmx_executable, 'trjcat', '-f']
arguments.extend(files_sorted[i])
arguments.extend(['-o', f'traj_{i}.xtc'])
Expand Down
2 changes: 1 addition & 1 deletion ensemble_md/cli/analyze_EEXE.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def main():
print('2-1. Stitching trajectories for each starting configuration from dhdl files ...')
dhdl_files = [natsort.natsorted(glob.glob(f'sim_{i}/iteration_*/*dhdl*xvg')) for i in range(EEXE.n_sim)]
shifts = np.arange(EEXE.n_sim) * EEXE.s
state_trajs = analyze_traj.stitch_time_series(dhdl_files, rep_trajs, shifts=shifts, save=True) # length: the number of replicas # noqa: E501
state_trajs = analyze_traj.stitch_time_series(dhdl_files, rep_trajs, shifts=shifts, save_npy=True) # length: the number of replicas # noqa: E501

# 2-2. Plot the state-space trajectory
print('\n2-2. Plotting transitions between different alchemical states ...')
Expand Down
33 changes: 2 additions & 31 deletions ensemble_md/ensemble_EXE.py
Original file line number Diff line number Diff line change
Expand Up @@ -1241,35 +1241,6 @@ def combine_weights(self, weights, weights_err=None):

return weights, g_vec

@staticmethod
def run_gmx_cmd(arguments):
"""
Run a GROMACS command as a subprocess
Parameters
----------
arguments : list
A list of arguments that compose of the GROMACS command to run, e.g.
:code:`['gmx', 'mdrun', '-deffnm', 'sys']`.
Returns
-------
return_code : int
The exit code of the GROMACS command. Any number other than 0 indicates an error.
stdout : str or None
The STDOUT of the process.
stderr: str or None
The STDERR or the process.
"""
try:
result = subprocess.run(arguments, capture_output=True, text=True, check=True)
return_code, stdout, stderr = result.returncode, result.stdout, None
except subprocess.CalledProcessError as e:
return_code, stdout, stderr = e.returncode, None, e.stderr

return return_code, stdout, stderr

def run_grompp(self, n, swap_pattern):
"""
Prepares TPR files for the simulation ensemble using the GROMACS :code:`grompp` command.
Expand Down Expand Up @@ -1322,7 +1293,7 @@ def run_grompp(self, n, swap_pattern):
if rank == 0:
print('Generating TPR files ...')
if rank < self.n_sim:
returncode, stdout, stderr = EnsembleEXE.run_gmx_cmd(args_list[rank])
returncode, stdout, stderr = utils.run_gmx_cmd(args_list[rank])
if returncode != 0:
print(f'Error on rank {rank} (return code: {returncode}):\n{stderr}')

Expand Down Expand Up @@ -1361,7 +1332,7 @@ def run_mdrun(self, n):
print('Running EXE simulations ...')
if rank < self.n_sim:
os.chdir(f'sim_{rank}/iteration_{n}')
returncode, stdout, stderr = EnsembleEXE.run_gmx_cmd(arguments)
returncode, stdout, stderr = utils.run_gmx_cmd(arguments)
if returncode != 0:
print(f'Error on rank {rank} (return code: {returncode}):\n{stderr}')
if self.rm_cpt is True:
Expand Down
30 changes: 30 additions & 0 deletions ensemble_md/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import glob
import natsort
import datetime
import subprocess
import collections
import numpy as np
from itertools import combinations
Expand Down Expand Up @@ -71,6 +72,35 @@ def flush(self):
pass


def run_gmx_cmd(arguments):
    """
    Execute a GROMACS command in a subprocess and report how it finished.

    Parameters
    ----------
    arguments : list
        The pieces of the GROMACS command line to run, e.g.
        :code:`['gmx', 'mdrun', '-deffnm', 'sys']`.

    Returns
    -------
    return_code : int
        The exit code of the GROMACS command. Any number other than 0 indicates an error.
    stdout : str or None
        The captured STDOUT when the command succeeded, otherwise :code:`None`.
    stderr : str or None
        The captured STDERR when the command failed, otherwise :code:`None`.
    """
    proc = subprocess.run(arguments, capture_output=True, text=True)

    # Only one of the two streams is handed back, depending on success or failure.
    if proc.returncode != 0:
        return proc.returncode, None, proc.stderr

    return proc.returncode, proc.stdout, None


def compare_MDPs(mdp_list):
"""
Given a list of MDP files, identify the parameters for which not all MDP
Expand Down

0 comments on commit c6979f8

Please sign in to comment.