Skip to content

Commit

Permalink
Added stitch_trajs and renamed the original functions to stitch_time_…
Browse files Browse the repository at this point in the history
…series and stitch_time_series_for_sim
  • Loading branch information
wehs7661 committed Aug 19, 2023
1 parent 10ea945 commit 54e5264
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 9 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Ensemble Molecular Dynamics
[![GitHub Actions Lint Status](https://github.com/wehs7661/ensemble_md/actions/workflows/lint.yaml/badge.svg)](https://github.com/wehs7661/ensemble_md/actions/workflows/lint.yaml)
[![PyPI version](https://badge.fury.io/py/ensemble-md.svg)](https://badge.fury.io/py/ensemble-md)
[![DOI](https://img.shields.io/badge/DOI-arxiv.org%2Fabs%2F2308.06938-green)](https://arxiv.org/abs/2308.06938)

[![MIT license](https://img.shields.io/badge/License-MIT-blue.svg)](https://lbesson.mit-license.org/)
[![Downloads](https://static.pepy.tech/badge/ensemble-md)](https://pepy.tech/project/ensemble-md)
[![Twitter Follow](https://img.shields.io/twitter/follow/WeiTseHsu?style=social)](https://twitter.com/WeiTseHsu)
Expand Down
44 changes: 40 additions & 4 deletions ensemble_md/analysis/analyze_traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def extract_state_traj(dhdl):
return traj


def stitch_trajs(files, rep_trajs, shifts=None, dhdl=True, col_idx=-1, save=True):
def stitch_time_series(files, rep_trajs, shifts=None, dhdl=True, col_idx=-1, save=True):
"""
Stitches the state-space/CV-space trajectories for each starting configuration from DHDL files
or PLUMED output files generated at different iterations.
Expand All @@ -49,7 +49,7 @@ def stitch_trajs(files, rep_trajs, shifts=None, dhdl=True, col_idx=-1, save=True
files : list
A list of lists of file names of GROMACS DHDL files or general GROMACS XVG files or PLUMED ouptput files.
Specifically, :code:`files[i]` should be a list containing the files of interest from all iterations in
replica :code:`i`.
replica :code:`i`. The files should be sorted naturally.
rep_trajs : list
A list of lists that represents the replica space trajectories for each starting configuration. For example,
:code:`rep_trajs[0] = [0, 2, 3, 0, 1, ...]` means that starting configuration 0 transitioned to replica 2, then
Expand Down Expand Up @@ -117,7 +117,7 @@ def stitch_trajs(files, rep_trajs, shifts=None, dhdl=True, col_idx=-1, save=True
return trajs


def stitch_trajs_for_sim(files, shifts=None, dhdl=True, col_idx=-1, save=True):
def stitch_time_series_for_sim(files, shifts=None, dhdl=True, col_idx=-1, save=True):
"""
Stitches the state-space/CV-space time series in the same replica/simulation folder.
That is, the output time series is contributed by multiple different trajectories (initiated by
Expand All @@ -128,7 +128,7 @@ def stitch_trajs_for_sim(files, shifts=None, dhdl=True, col_idx=-1, save=True):
files : list
A list of lists of file names of GROMACS DHDL files or general GROMACS XVG files
or PLUMED output files. Specifically, :code:`files[i]` should be a list containing
the files of interest from all iterations in replica :code:`i`.
the files of interest from all iterations in replica :code:`i`. The files should be sorted naturally.
shifts : list
A list of values for shifting the state indices for each replica. The length of the list
should be equal to the number of replicas. This is only needed when :code:`dhdl=True`.
Expand Down Expand Up @@ -175,6 +175,42 @@ def stitch_trajs_for_sim(files, shifts=None, dhdl=True, col_idx=-1, save=True):
return trajs


def stitch_trajs(gmx_executable, files, rep_trajs):
"""
Demuxes GROMACS trajectories from different replicas into individual continuous trajectories.
Parameters
----------
gmx_executable : str
The path to the GROMACS executable.
files : list
A list of lists of file names of GROMACS XTC files. Specifically, :code:`files[i]` should be a list containing
the files of interest from all iterations in replica :code:`i`. The files should be sorted naturally.
rep_trajs : list
A list of lists that represents the replica space trajectories for each starting configuration. For example,
:code:`rep_trajs[0] = [0, 2, 3, 0, 1, ...]` means that starting configuration 0 transitioned to replica 2, then
3, 0, 1, in iterations 1, 2, 3, 4, ..., respectively.
"""
n_sim = len(files) # number of replicas
n_iter = len(files[0]) # number of iterations per replica

# First figure out which xtc files each starting configuration corresponds to
# files_sorted[i] contains the xtc files for starting configuration i sorted
# based on iteration indices
files_sorted = [[] for i in range(n_sim)]
for i in range(n_sim):
for j in range(n_iter):
files_sorted[i].append(files[rep_trajs[i][j]][j])

# Then, stitch the trajectories for each starting configuration
for i in range(n_sim):
print(f'Concatenating trajectories initiated by configuration {i} ...')
arguments = [gmx_executable, 'trjcat', '-f']
arguments.extend(files_sorted[i])
arguments.extend(['-o', f'traj_{i}.xtc'])
utils.run_gmx_cmd(arguments)


def traj2transmtx(traj, N, normalize=True):
"""
Computes the transition matrix given a trajectory. For example, if a state-space
Expand Down
4 changes: 2 additions & 2 deletions ensemble_md/cli/analyze_EEXE.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def main():
print('2-1. Stitching trajectories for each starting configuration from dhdl files ...')
dhdl_files = [natsort.natsorted(glob.glob(f'sim_{i}/iteration_*/*dhdl*xvg')) for i in range(EEXE.n_sim)]
shifts = np.arange(EEXE.n_sim) * EEXE.s
state_trajs = analyze_traj.stitch_trajs(dhdl_files, rep_trajs, shifts=shifts, save=True) # length: the number of replicas # noqa: E501
state_trajs = analyze_traj.stitch_time_series(dhdl_files, rep_trajs, shifts=shifts, save=True) # length: the number of replicas # noqa: E501

# 2-2. Plot the state-space trajectory
print('\n2-2. Plotting transitions between different alchemical states ...')
Expand Down Expand Up @@ -188,7 +188,7 @@ def main():
print('2-4. Stitching time series of state index for each alchemical range ...')
dhdl_files = [natsort.natsorted(glob.glob(f'sim_{i}/iteration_*/*dhdl*xvg')) for i in range(EEXE.n_sim)]
shifts = np.arange(EEXE.n_sim) * EEXE.s
state_trajs_for_sim = analyze_traj.stitch_trajs_for_sim(dhdl_files, shifts)
state_trajs_for_sim = analyze_traj.stitch_time_series_for_sim(dhdl_files, shifts)

# 2-5. Plot the time series of state index for different alchemical ranges
print('\n2-5. Plotting the time series of state index for different alchemical ranges ...')
Expand Down
7 changes: 4 additions & 3 deletions ensemble_md/ensemble_EXE.py
Original file line number Diff line number Diff line change
Expand Up @@ -1241,7 +1241,8 @@ def combine_weights(self, weights, weights_err=None):

return weights, g_vec

def run_gmx_cmd(self, arguments):
@staticmethod
def run_gmx_cmd(arguments):
"""
Run a GROMACS command as a subprocess
Expand Down Expand Up @@ -1321,7 +1322,7 @@ def run_grompp(self, n, swap_pattern):
if rank == 0:
print('Generating TPR files ...')
if rank < self.n_sim:
returncode, stdout, stderr = self.run_gmx_cmd(args_list[rank])
returncode, stdout, stderr = EnsembleEXE.run_gmx_cmd(args_list[rank])
if returncode != 0:
print(f'Error on rank {rank} (return code: {returncode}):\n{stderr}')

Expand Down Expand Up @@ -1360,7 +1361,7 @@ def run_mdrun(self, n):
print('Running EXE simulations ...')
if rank < self.n_sim:
os.chdir(f'sim_{rank}/iteration_{n}')
returncode, stdout, stderr = self.run_gmx_cmd(arguments)
returncode, stdout, stderr = EnsembleEXE.run_gmx_cmd(arguments)
if returncode != 0:
print(f'Error on rank {rank} (return code: {returncode}):\n{stderr}')
if self.rm_cpt is True:
Expand Down

0 comments on commit 54e5264

Please sign in to comment.