Skip to content

Commit

Permalink
Modified compare_MDPs
Browse files Browse the repository at this point in the history
  • Loading branch information
wehs7661 committed Aug 25, 2023
1 parent c94c80c commit 403f5fc
Showing 1 changed file with 33 additions and 19 deletions.
52 changes: 33 additions & 19 deletions ensemble_md/utils/gmx_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import six
import logging
import warnings
from itertools import combinations
from collections import OrderedDict as odict

from ensemble_md.utils import utils
Expand Down Expand Up @@ -386,7 +385,7 @@ def write(self, filename=None, skipempty=False):
mdp.write("{} = {}\n".format(k, " ".join(map(str, v))))


def compare_MDPs(mdp_list):
def compare_MDPs(mdp_list, print_diff=False):
"""
Given a list of MDP files, identify the parameters for which not all MDP
files have the same values. Note that this function is not aware of the default
Expand All @@ -396,31 +395,46 @@ def compare_MDPs(mdp_list):
Returns
-------
diff_params : list
The list of parameters differing between the input MDP files.
diff_params : dict
A dictionary of parameters that are different among the MDP files.
The keys are the parameter names and the values is a list of values of the
parameters in the MDP files.
print_diff : bool
If :code:`True`, print to screen the parameters that are different among the MDP files
and the values of the parameters in the MDP files in a more readable format.
"""
compare_list = list(combinations(mdp_list, r=2))
diff_params = []
for i in range(len(compare_list)):
params_1 = MDP(compare_list[i][0])
params_2 = MDP(compare_list[i][1])

mdp_1 = odict([(k.replace('-', '_'), v) if type(v) is str else (k.replace('-', '_'), v.replace('-', '_')) for k, v in params_1.items()]) # noqa: E501
mdp_2 = odict([(k.replace('-', '_'), v) if type(v) is str else (k.replace('-', '_'), v.replace('-', '_')) for k, v in params_2.items()]) # noqa: E501
diff_params = {}
for i in range(len(mdp_list)):
mdps = [MDP(mdp_list[i]) for i in range(len(mdp_list))]
params_dicts = [odict([(k.replace('-', '_'), v) if type(v) is not str else (k.replace('-', '_'), v.replace('-', '_')) for k, v in p.items()]) for p in mdps] # noqa: E501

# First figure out the union set of the parameters and exclude blanks and comments
all_params = set(list(mdp_1.keys()) + list(mdp_2.keys()))
all_params = set([key for d in params_dicts for key in d.keys()])
all_params = [p for p in all_params if not p.startswith(('B', 'C'))]

for p in all_params:
if p in diff_params:
pass # already in the list, no need to compare again
pass # already in the dictionary, no need to compare again
else:
if p not in mdp_1 or p not in mdp_2:
diff_params.append(p)
if not all(p in d for d in params_dicts):
# the parameter is not in all MDP files
diff_params[p] = [d[p] if p in d else None for d in params_dicts]
else:
# the parameter is in both MDP files
if mdp_1[p] != mdp_2[p]:
diff_params.append(p)
# the parameter is in all MDP files (Note that "set([1, 1, 1]={1}.)")
if isinstance(params_dicts[0][p], list):
# the parameter is a list, which is unhashable
if len(set([tuple(d[p]) for d in params_dicts])) > 1:
diff_params[p] = [d[p] for d in params_dicts]
else:
if len(set([d[p] for d in params_dicts])) > 1:
diff_params[p] = [d[p] for d in params_dicts]

if print_diff:
print("The following parameters are different among the MDP files:")
for k, v in diff_params.items():
print(k)
for i in range(len(mdp_list)):
print(f' - {mdp_list[i]}: {v[i]}')
print()

return diff_params

0 comments on commit 403f5fc

Please sign in to comment.