Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added a script for comparing distances from CPPTraj output #56

Merged
merged 1 commit into from
Jan 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions pyqmmm/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def cli():
@click.option("--rmsf", "-rmsf", is_flag=True, help="Calculates the RMSF.")
@click.option("--quick_csa", "-csa", is_flag=True, help="Performs charge shift analysis.")
@click.option("--cc_coupling", "-cc", is_flag=True, help="Plots the results from cc coupling analysis.")
@click.option("--compare_distances", "-cd", is_flag=True, help="Plots distance metrics together.")
@click.help_option('--help', '-h', is_flag=True, help='Exiting pyQMMM.')
def md(
gbsa_submit,
Expand All @@ -62,6 +63,7 @@ def md(
rmsf,
quick_csa,
cc_coupling,
compare_distances,
):
"""
Functions for molecular dynamics (MD) simulations.
Expand Down Expand Up @@ -181,6 +183,11 @@ def md(
out_file="matrix_geom",
)

elif compare_distances:
import pyqmmm.md.compare_distances
files = input("What distance files would you like to plot? ").split(",")
pyqmmm.md.compare_distances.get_plot(files)


@cli.command()
@click.option("--plot_energy", "-pe", is_flag=True, help="Plot the energy of a xyz traj.")
Expand Down
77 changes: 77 additions & 0 deletions pyqmmm/md/compare_distances.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import re
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

def format_plot() -> None:
"""
General plotting parameters for the Kulik Lab.
"""
font = {"family": "sans-serif", "weight": "bold", "size": 10}
plt.rc("font", **font)
plt.rcParams["xtick.major.pad"] = 5
plt.rcParams["ytick.major.pad"] = 5
plt.rcParams["axes.linewidth"] = 2
plt.rcParams["xtick.major.size"] = 7
plt.rcParams["xtick.major.width"] = 2
plt.rcParams["ytick.major.size"] = 7
plt.rcParams["ytick.major.width"] = 2
plt.rcParams["xtick.direction"] = "in"
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["xtick.top"] = True
plt.rcParams["ytick.right"] = True
plt.rcParams["svg.fonttype"] = "none"

def read_data(file_name):
# Read the second column from the file
return np.loadtxt(file_name, usecols=[1], skiprows=1)

def get_legend_labels(file):
atoms = re.split("[.-]", file)[1]
legend = f"{atoms[0]}···{atoms[1]}"

return legend


def get_colors(files):
# Assign the color palette
if len(files) == 1:
colors = ['#08415c']
elif len(files) == 2:
colors = ['#cc2936', '#08415c']
elif len(files) == 3:
colors = ['#cc2936', '#08415c', "#ABABAB"]

return colors


def get_plot(files):

colors = get_colors(files)
legend = []
for index,file in enumerate(files):

# Read data from the files
data_NC = read_data(file)

format_plot()

# Create a histogram with a KDE line for NC data
bin_count = 150
label = get_legend_labels(file)
legend.append(label)
sns.histplot(data_NC, bins=bin_count, kde=True, color=colors[index], linewidth=0, alpha=0.55, label=label)

# Add labels and title if desired
plt.xlabel('distance (Å)', fontsize=10, weight="bold")
plt.ylabel('frequency', fontsize=10, weight="bold")
plt.legend(frameon=False)

extensions = ["png", "svg"]
out_name = "_".join(legend)
for ext in extensions:
plt.savefig(f"{out_name}.{ext}", bbox_inches="tight", format=ext, dpi=600)

if __name__ == "__main__":
files = input("What distance files would you like to plot? ").split(",")
get_plot(files)
Loading