Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add replicas #36

Merged
merged 15 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 207 additions & 44 deletions Sample_Notebooks/BAR_Estimator_Basic.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,35 @@
"import os\n",
"from alchemlyb.parsing import namd\n",
"from IPython.display import display, Markdown\n",
"\n",
"from pathlib import Path\n",
"from dataclasses import dataclass\n",
"import scipy as sp\n",
"from alchemlyb.estimators import BAR\n",
"import warnings\n",
"warnings.simplefilter(action='ignore', category=FutureWarning)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bf843fad",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"@dataclass\n",
"class FepRun:\n",
" u_nk: pd.DataFrame\n",
" perWindow: pd.DataFrame\n",
" cumulative: pd.DataFrame\n",
" forward: pd.DataFrame\n",
" forward_error: pd.DataFrame\n",
" backward: pd.DataFrame\n",
" backward_error: pd.DataFrame\n",
" per_lambda_convergence: pd.DataFrame\n",
" color: str"
]
},
{
"cell_type": "markdown",
"id": "a9d06104",
Expand All @@ -38,19 +61,25 @@
"metadata": {},
"outputs": [],
"source": [
"path='/path/to/data/'\n",
"filename='*.fepout'\n",
"dataroot = Path('.')\n",
"replica_pattern='Replica?'\n",
"replicas = dataroot.glob(replica_pattern)\n",
"filename_pattern='*.fepout'\n",
"\n",
"temperature = 303.15\n",
"RT = 0.00198720650096 * temperature\n",
"decorrelate = True #Flag for decorrelation of samples\n",
"detectEQ = True #Flag for automated equilibrium detection\n",
"\n",
"fepoutFiles = glob(path+filename)\n",
"totalSize = 0\n",
"for file in fepoutFiles:\n",
" totalSize += os.path.getsize(file)\n",
"print(f\"Will process {len(fepoutFiles)} fepout files.\\nTotal size:{np.round(totalSize/10**9, 2)}GB\")"
"detectEQ = True #Flag for automated equilibrium detection"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a6864313",
"metadata": {},
"outputs": [],
"source": [
"colors = ['blue', 'red', 'green', 'purple', 'orange', 'violet', 'cyan']\n",
"itcolors = iter(colors)"
]
},
{
Expand All @@ -64,27 +93,59 @@
"Note: alchemlyb operates in units of kT by default. We multiply by RT to convert to units of kcal/mol."
]
},
{
"cell_type": "markdown",
"id": "80b1034d",
"metadata": {},
"source": [
"# Read and plot number of samples after detecting EQ"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7c747b75-8fa3-48f9-9ab3-ac4a12982c01",
"metadata": {},
"outputs": [],
"source": [
"fig, ax = plt.subplots()\n",
"u_nk = namd.extract_u_nk(fepoutFiles, temperature)\n",
"safep.plot_samples(ax, u_nk, color='blue', label='Raw Data')\n",
"\n",
"if detectEQ:\n",
" print(\"Detecting equilibrium\")\n",
" u_nk = safep.detect_equilibrium_u_nk(u_nk)\n",
" safep.plot_samples(ax, u_nk, color='orange', label='Equilibrium-Detected')\n",
"if decorrelate and not detectEQ:\n",
" print(\"Decorrelating\")\n",
" u_nk = safep.decorrelate_u_nk(u_nk)\n",
" safep.plot_samples(ax, u_nk, color='green', label='Decorrelated')\n",
"fepruns = {}\n",
"for replica in replicas:\n",
" print(f\"Reading {replica}\")\n",
" unkpath = replica.joinpath('decorrelated.csv')\n",
" u_nk = None\n",
" if unkpath.is_file():\n",
" print(f\"Found existing dataframe. Reading.\")\n",
" u_nk = safep.read_UNK(unkpath)\n",
" else:\n",
" print(f\"Didn't find existing dataframe at {unkpath}. Checking for raw fepout files.\")\n",
" fepoutFiles = list(replica.glob(filename_pattern))\n",
" totalSize = 0\n",
" for file in fepoutFiles:\n",
" totalSize += os.path.getsize(file)\n",
" print(f\"Will process {len(fepoutFiles)} fepout files.\\nTotal size:{np.round(totalSize/10**9, 2)}GB\")\n",
"\n",
" if len(list(fepoutFiles))>0:\n",
" print(\"Reading fepout files\")\n",
" fig, ax = plt.subplots()\n",
"\n",
" u_nk = namd.extract_u_nk(fepoutFiles, temperature)\n",
" u_nk = u_nk.sort_index(axis=0, level=1).sort_index(axis=1)\n",
" safep.plot_samples(ax, u_nk, color='blue', label='Raw Data')\n",
"\n",
" if detectEQ:\n",
" print(\"Detecting equilibrium\")\n",
" u_nk = safep.detect_equilibrium_u_nk(u_nk)\n",
" safep.plot_samples(ax, u_nk, color='orange', label='Equilibrium-Detected')\n",
"\n",
" plt.savefig(f\"./{str(replica)}_FEP_number_of_samples.pdf\")\n",
" plt.show()\n",
" safep.save_UNK(u_nk, unkpath)\n",
" else:\n",
" print(f\"WARNING: no fepout files found for {replica}. Skipping.\")\n",
" \n",
"plt.savefig(f\"{path}FEP_number_of_samples.pdf\")"
" if u_nk is not None:\n",
" fepruns[str(replica)] = FepRun(u_nk, None, None, None, None, None, None, None, next(itcolors))\n",
" "
]
},
{
Expand All @@ -94,9 +155,20 @@
"metadata": {},
"outputs": [],
"source": [
"perWindow, cumulative = safep.do_estimation(u_nk) #Run the BAR estimator on the fep data\n",
"forward, forward_error, backward, backward_error = safep.do_convergence(u_nk) #Used later in the convergence plot'\n",
"per_lambda_convergence = safep.do_per_lambda_convergence(u_nk)"
"for key, feprun in fepruns.items():\n",
" u_nk = feprun.u_nk\n",
" feprun.perWindow, feprun.cumulative = safep.do_estimation(u_nk) #Run the BAR estimator on the fep data\n",
" feprun.forward, feprun.forward_error, feprun.backward, feprun.backward_error = safep.do_convergence(u_nk) #Used later in the convergence plot'\n",
" feprun.per_lambda_convergence = safep.do_per_lambda_convergence(u_nk)\n",
"\n"
]
},
{
"cell_type": "markdown",
"id": "caae40d3",
"metadata": {},
"source": [
"# Plot data"
]
},
{
Expand All @@ -106,11 +178,59 @@
"metadata": {},
"outputs": [],
"source": [
"dG = np.round(cumulative.BAR.f.iloc[-1]*RT, 1)\n",
"error = np.round(cumulative.BAR.errors.iloc[-1]*RT, 1)\n",
"toprint = \"\"\n",
"dGs = []\n",
"errors = []\n",
"for key, feprun in fepruns.items():\n",
" cumulative = feprun.cumulative\n",
" dG = np.round(cumulative.BAR.f.iloc[-1]*RT, 1)\n",
" error = np.round(cumulative.BAR.errors.iloc[-1]*RT, 1)\n",
" dGs.append(dG)\n",
" errors.append(error)\n",
"\n",
" changeAndError = f'{key}: \\u0394G = {dG}\\u00B1{error} kcal/mol'\n",
" toprint += '<font size=5>{}</font><br/>'.format(changeAndError)\n",
"\n",
"toprint += '<font size=5>{}</font><br/>'.format('__________________')\n",
"mean = np.average(dGs)\n",
"\n",
"#If there are only a few replicas, the MBAR estimated error will be more reliable, albeit underestimated\n",
"if len(dGs)<3:\n",
" sterr = np.sqrt(np.sum(np.square(errors)))\n",
"else:\n",
" sterr = np.round(np.std(dGs),1)\n",
"toprint += '<font size=5>{}</font><br/>'.format(f'mean: {mean}')\n",
"toprint += '<font size=5>{}</font><br/>'.format(f'sterr: {sterr}')\n",
"Markdown(toprint)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d1fde633",
"metadata": {},
"outputs": [],
"source": [
"def do_agg_data(dataax, plotax):\n",
" agg_data = []\n",
" lines = dataax.lines\n",
" for line in lines:\n",
" agg_data.append(line.get_ydata())\n",
" flat = np.array(agg_data).flatten()\n",
" kernel = sp.stats.gaussian_kde(flat)\n",
" pdfX = np.linspace(-1, 1, 1000)\n",
" pdfY = kernel(pdfX)\n",
" std = np.std(flat)\n",
" mean = np.average(flat)\n",
" temp = pd.Series(pdfY, index=pdfX)\n",
" mode = temp.idxmax()\n",
"\n",
" textstr = r\"$\\rm mode=$\"+f\"{np.round(mode,2)}\"+\"\\n\"+fr\"$\\mu$={np.round(mean,2)}\"+\"\\n\"+fr\"$\\sigma$={np.round(std,2)}\"\n",
" props = dict(boxstyle='square', facecolor='white', alpha=1)\n",
" plotax.text(0.175, 0.95, textstr, transform=plotax.transAxes, fontsize=14,\n",
" verticalalignment='top', bbox=props)\n",
"\n",
"changeAndError = f'\\u0394G = {dG}\\u00B1{error} kcal/mol'\n",
"Markdown('<font size=5>{}</font><br/>'.format(changeAndError))"
" return plotax"
]
},
{
Expand All @@ -120,9 +240,21 @@
"metadata": {},
"outputs": [],
"source": [
"fig, axes = safep.plot_general(cumulative, None, perWindow, None, RT)\n",
"fig.suptitle(changeAndError)\n",
"plt.savefig(f'{path}FEP_general_figures.pdf')"
"fig = None\n",
"for key, feprun in fepruns.items():\n",
" if fig is None:\n",
" fig, axes = safep.plot_general(feprun.cumulative, None, feprun.perWindow, None, RT, hysttype='lines', label=key, color=feprun.color)\n",
" axes[1].legend()\n",
" else:\n",
" fig, axes = safep.plot_general(feprun.cumulative, None, feprun.perWindow, None, RT, fig=fig, axes=axes, hysttype='lines', label=key, color=feprun.color)\n",
" #fig.suptitle(changeAndError)\n",
"\n",
"# hack to get aggregate data:\n",
"axes[3] = do_agg_data(axes[2], axes[3])\n",
"\n",
"axes[0].set_title(str(mean)+r'$\\pm$'+str(sterr)+' kcal/mol')\n",
"axes[0].legend()\n",
"plt.savefig(dataroot.joinpath('FEP_general_figures.pdf'))"
]
},
{
Expand All @@ -141,28 +273,59 @@
"outputs": [],
"source": [
"fig, convAx = plt.subplots(1,1)\n",
"convAx = safep.convergence_plot(convAx, forward*RT, forward_error*RT, backward*RT, backward_error*RT)\n",
"plt.savefig(f'{path}FEP_convergence.pdf')"
"\n",
"for key, feprun in fepruns.items():\n",
" convAx = safep.convergence_plot(convAx, \n",
" feprun.forward*RT, \n",
" feprun.forward_error*RT, \n",
" feprun.backward*RT,\n",
" feprun.backward_error*RT,\n",
" fwd_color=feprun.color,\n",
" bwd_color=feprun.color,\n",
" errorbars=False\n",
" )\n",
" convAx.get_legend().remove()\n",
"\n",
"forward_line, = convAx.plot([],[],linestyle='-', color='black', label='Forward Time Sampling')\n",
"backward_line, = convAx.plot([],[],linestyle='--', color='black', label='Backward Time Sampling')\n",
"convAx.legend(handles=[forward_line, backward_line])\n",
"ymin = np.min(dGs)-1\n",
"ymax = np.max(dGs)+1\n",
"convAx.set_ylim((ymin,ymax))\n",
"plt.savefig(dataroot.joinpath('FEP_convergence.pdf'))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1bf6ac36-f102-4da3-82d2-36a7741584e5",
"id": "e6ab4621",
"metadata": {},
"outputs": [],
"source": [
"fig, ax = plt.subplots()\n",
"ax.errorbar(per_lambda_convergence.index, per_lambda_convergence.BAR.df*RT)\n",
"ax.set_xlabel(r\"$\\lambda$\")\n",
"ax.set_ylabel(r\"$D_{last-first}$ (kcal/mol)\")\n",
"plt.savefig(f\"{path}FEP_perLambda_convergence.pdf\")"
"genfig = None\n",
"for key, feprun in fepruns.items():\n",
" if genfig is None:\n",
" genfig, genaxes = safep.plot_general(feprun.cumulative, None, feprun.perWindow, None, RT, hysttype='lines', label=key, color=feprun.color)\n",
" else:\n",
" genfig, genaxes = safep.plot_general(feprun.cumulative, None, feprun.perWindow, None, RT, fig=genfig, axes=genaxes, hysttype='lines', label=key, color=feprun.color)\n",
"plt.delaxes(genaxes[0])\n",
"plt.delaxes(genaxes[1])\n",
"\n",
"genaxes[3] = do_agg_data(axes[2], axes[3])\n",
"genaxes[2].set_title(str(mean)+r'$\\pm$'+str(sterr)+' kcal/mol')\n",
"\n",
"for txt in genfig.texts:\n",
" print(1)\n",
" txt.set_visible(False)\n",
" txt.set_text(\"\")\n",
"plt.show()\n",
"plt.savefig(dataroot.joinpath('FEP_perLambda_convergence.pdf'))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b237a8ee-f33e-4c05-b4ed-856bb4f34ca2",
"id": "96418460",
"metadata": {},
"outputs": [],
"source": []
Expand All @@ -184,7 +347,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
"version": "3.12.3"
}
},
"nbformat": 4,
Expand Down
1 change: 1 addition & 0 deletions Sample_Notebooks/Replica1
1 change: 1 addition & 0 deletions Sample_Notebooks/Replica2
1 change: 1 addition & 0 deletions Sample_Notebooks/Replica3
6 changes: 3 additions & 3 deletions safep/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,9 @@ def plot_general(cumulative,

# Per-window change in kcal/mol
if errorbars:
each_ax.errorbar(per_window.index, per_window.BAR.df*RT, yerr=per_window.BAR.ddf, marker=None, linewidth=1, color=color)
each_ax.errorbar(per_window.index, -per_window.EXP.dG_b*RT, marker=None, linewidth=1, alpha=0.5, linestyle='--', color=color)
each_ax.plot(per_window.index, per_window.EXP.dG_f*RT, marker=None, linewidth=1, alpha=0.5, color=color)
each_ax.errorbar(per_window.index, per_window.BAR.df*RT, yerr=per_window.BAR.ddf, marker=None, linewidth=1, alpha=0.5, color=color, label="forward EXP")
each_ax.errorbar(per_window.index, -per_window.EXP.dG_b*RT, marker=None, linewidth=1, alpha=0.5, linestyle='--', color=color, label="backward EXP")
each_ax.plot(per_window.index, per_window.EXP.dG_f*RT, marker=None, linewidth=1, color=color, label="BAR")

each_ax.set(ylabel=r'$\mathrm{\Delta} G_\lambda$'+'\n'+r'$\left(kcal/mol\right)$', ylim=per_window_ylim)

Expand Down
Loading