Skip to content

Commit

Permalink
Developed clustering.py in ensemble_md.analysis
Browse files Browse the repository at this point in the history
  • Loading branch information
wehs7661 committed Aug 20, 2023
1 parent b25f4a1 commit ab230b4
Showing 1 changed file with 237 additions and 0 deletions.
237 changes: 237 additions & 0 deletions ensemble_md/analysis/clustering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
import numpy as np
from ensemble_md.utils.utils import run_gmx_cmd


def cluster_traj(gmx_executable, inputs, grps, method='linkage', cutoff=0.1, suffix=None):
"""
Performs clustering analysis on a trajectory using the GROMACS command :code:`gmx cluster`.
Note that only fully coupled configurations are considered.
Parameters
----------
gmx_executable : str
The path to the GROMACS executable.
inputs : dict
A dictionary that contains the file names of the input trajectory file (XTC or TRR),
the configuration file (TPR or GRO), the file that contains the time series of the
state index, and the index file (NDX). The must include the keys :code:`traj`, :code:`config`,
:code:`xvg`, and :code:`index`. Note that the value for the key :code:`index` can be :code:`None`.
grps : dict
A dictionary that contains the names of the groups in the index file (NDX) for
centering the system, calculating the RMSD, and outputting. The keys are
:code:`center`, :code:`rmsd`, and :code:`output`.
method : str
The method for clustering available for the GROMACS command :code:`gmx cluster`. The default is 'linkage'.
cutoff : float
The cutoff in RMSD for clustering. The default is 0.1.
suffix : str
The suffix for the output files. The default is :code:`None`, which means no suffix will be added.
"""
outputs = {
'nojump': 'nojump.xtc',
'center': 'center.xtc',
'rmsd-clust': 'rmsd-clust.xpm',
'rmsd-dist': 'rmsd-dist.xvg',
'cluster-log': 'cluster.log',
'cluster-pdb': 'clusters.pdb',
'rmsd': 'rmsd.xvg', # inter-medoid RMSD
}
if suffix is not None:
for key in outputs:
outputs[key] = outputs[key].replace('.', f'_{suffix}.')

print('Eliminating jumps across periodic boundaries for the input trajectory ...')
args = [
gmx_executable, 'trjconv',
'-f', inputs['traj'],
'-s', inputs['config'],
'-o', outputs['nojump'],
'-center', 'yes',
'-pbc', 'nojump'
'-drop', inputs['xvg'],
'-dropover', 0
]
if inputs['index'] is not None:
args.extend(['-n', inputs['index']])
run_gmx_cmd(args, prompt_input=f'{grps["center"]}\n{grps["output"]}\n')

print('Centering the system ...')
args = [
gmx_executable, 'trjconv',
'-f', outputs['nojump'],
'-s', inputs['config'],
'-o', outputs['center'],
'-center', 'yes',
'-pbc', 'mol',
'-ur', 'compact',
]
if inputs['index'] is not None:
args.extend(['-n', inputs['index']])
run_gmx_cmd(args, prompt_input=f'{grps["center"]}\n{grps["output"]}\n')

print('Performing clustering analysis ...')
args = [
gmx_executable, 'cluster',
'-f', outputs['center'],
'-s', inputs['config'],
'-o', outputs['rmsd-clust'],
'-dist', outputs['rmsd-dist'],
'-g', outputs['cluster-log'],
'-cl', outputs['cluster-pdb'],
'-cutoff', cutoff,
'-method', method,
]
if inputs['index'] is not None:
args.extend(['-n', inputs['index']])
run_gmx_cmd(args, prompt_input=f'{grps["rmsd"]}\n{grps["output"]}\n')

rmsd_range, rmsd_avg, n_clusters = get_cluster_info(outputs['cluster-log'])

print(f'Range of RMSD values: {rmsd_range} nm')
print(f'Average RMSD: {rmsd_avg} nm')
print(f'Number of clusters: {n_clusters}')

if n_clusters > 1:
clusters, sizes = get_cluster_members(outputs['cluster-log'])
for i in range(1, n_clusters + 1):
print(f' - Cluster {i} accounts for {sizes[i] * 100:.2f}% of the total configurations.')

n_transitions, t_transitions = count_transitions(clusters)
print(f'Number of transitions between the two biggest clusters: {n_transitions}')
print(f'Time frames of the transitions (ps): {t_transitions}')

print('Calculating the inter-medoid RMSD between the two biggest clusters ...')
args = [
gmx_executable, 'rms',
'-f', outputs['center'],
'-s', outputs['center'], # so that the first medoid will be used as the reference
'-o', outputs['rmsd'],
]
if inputs['index'] is not None:
args.extend(['-n', inputs['index']])

# Here we simply assume same groups for least-squares fitting and RMSD calculation
run_gmx_cmd(args, prompt_input=f'{grps["rmsd"]}\n{grps["rmsd"]}\n')

rmsd = np.transpose(np.loadtxt(outputs['rmsd']), comments=['@', '#'])[1][1] # inter-medoid RMSD
print('Inter-medoid RMSD between the two biggest clusters: ', rmsd)


def get_cluster_members(cluster_log):
"""
Gets the members of each cluster from the LOG file generated by the GROMACS :code:`gmx cluster` command.
Parameters
----------
cluster_log : str
The LOG file generated by the GROMACS :code:`gmx cluster` command.
Returns
-------
clusters : dict
A dictionary that contains the cluster index (starting from 1) as the key and the list of members
(configurations at different timeframes) as the value.
sizes : dict
A dictionary that contains the cluster index (starting from 1) as the key and the size of the cluster
(in fraction) as the value.
"""
clusters = {}
current_cluster = 0
start_processing = False

f = open(cluster_log, 'r')
lines = f.readlines()
f.close()

for line in lines:
# Start processing when we reach the line that starts with "cl."
if line.strip().startswith("cl."):
start_processing = True
continue # Skip this line and continue to the next iteration

if start_processing:
parts = line.split('|')
try:
current_cluster = int(parts[0].strip())
clusters[current_cluster] = []
except ValueError:
pass

# This is either a new cluster or continuation of it, add members
members = parts[-1].split()
clusters[current_cluster].extend([int(i) for i in members])

sizes_list = [len(clusters[i]) for i in clusters]
sizes = {i: sizes_list[i - 1] / sum(sizes_list) for i in clusters}

return clusters, sizes


def count_transitions(clusters):
"""
Counts the number of transitions between the two biggest clusters.
Parameters
----------
clusters : dict
A dictionary that contains the cluster index (starting from 1) as the key and the list of members
(configurations at different timeframes) as the value.
Returns
-------
n_transitions : int
The number of transitions between the two biggest clusters.
t_transitions : list
The list of time frames when the transitions occur.
"""
# Combine and sort all cluster members for the first two biggest clusters while keeping track of their origin
all_members = [(member, 1) for member in clusters[1]] + [(member, 2) for member in clusters[2]]
all_members.sort()

# Count transitions and record time frames
n_transitions = 0
t_transitions = []
last_cluster = all_members[0][1] # the cluster index of the last time frame in the previous iteration
for member in all_members[1:]:
if member[1] != last_cluster:
n_transitions += 1
last_cluster = member[1]
t_transitions.append(member[0])

return n_transitions, t_transitions


def get_cluster_info(cluster_log):
"""
Gets the metadata of the LOG file generated by the GROMACS :code:`gmx cluster` command.
Parameters
----------
cluster_log : str
The LOG file generated by the GROMACS :code:`gmx cluster` command.
Returns
-------
rmsd_range: list
The range of RMSD values
rmsd_avg: float
The average RMSD value.
n_clusters : int
The number of clusters.
"""
f = open(cluster_log, 'r')
lines = f.readlines()
f.close()

rmsd_range = []
for line in lines:
if 'The RMSE ranges from' in line:
rmsd_range.append(float(line.split('from')[-1].split('to')[0]))
rmsd_range.append(float(line.split('from')[-1].split('to')[-1].split('nm')[0]))
if 'Average RMSD' in line:
rmsd_avg = float(line.split('is')[-1])
if 'Found' in line:
n_clusters = int(line.split()[1])
break

return rmsd_range, rmsd_avg, n_clusters

0 comments on commit ab230b4

Please sign in to comment.