Skip to content

Commit

Permalink
Merge pull request #36 from BranniganLab/add_replicas
Browse files Browse the repository at this point in the history
Add replicas
  • Loading branch information
EzryStIago authored Jun 26, 2024
2 parents 44683f6 + 6f75ccc commit caa2a3a
Show file tree
Hide file tree
Showing 6 changed files with 240 additions and 58 deletions.
251 changes: 207 additions & 44 deletions Sample_Notebooks/BAR_Estimator_Basic.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,35 @@
"import os\n",
"from alchemlyb.parsing import namd\n",
"from IPython.display import display, Markdown\n",
"\n",
"from pathlib import Path\n",
"from dataclasses import dataclass\n",
"import scipy as sp\n",
"from alchemlyb.estimators import BAR\n",
"import warnings\n",
"warnings.simplefilter(action='ignore', category=FutureWarning)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bf843fad",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"from typing import Optional\n",
"\n",
"\n",
"@dataclass\n",
"class FepRun:\n",
"    \"\"\"Data and analysis results for one FEP replica.\n",
"\n",
"    Only `u_nk` is available when a replica is first read. The result tables\n",
"    default to None and are assigned later, once the estimators have been run.\n",
"    \"\"\"\n",
"    u_nk: pd.DataFrame  # u_nk table parsed from fepout files or read from the cached CSV\n",
"    perWindow: Optional[pd.DataFrame] = None  # first value returned by safep.do_estimation\n",
"    cumulative: Optional[pd.DataFrame] = None  # second value returned by safep.do_estimation\n",
"    forward: Optional[pd.DataFrame] = None  # values returned by safep.do_convergence, in order\n",
"    forward_error: Optional[pd.DataFrame] = None\n",
"    backward: Optional[pd.DataFrame] = None\n",
"    backward_error: Optional[pd.DataFrame] = None\n",
"    per_lambda_convergence: Optional[pd.DataFrame] = None  # from safep.do_per_lambda_convergence\n",
"    color: str = 'blue'  # matplotlib color used for this replica in the plots"
]
},
{
"cell_type": "markdown",
"id": "a9d06104",
Expand All @@ -38,19 +61,25 @@
"metadata": {},
"outputs": [],
"source": [
"path='/path/to/data/'\n",
"filename='*.fepout'\n",
"dataroot = Path('.')\n",
"replica_pattern='Replica?'\n",
"# Path.glob returns a one-shot generator in arbitrary order. Materialize and sort it\n",
"# so the reading cell can be re-run and replicas are always visited in the same order.\n",
"replicas = sorted(dataroot.glob(replica_pattern))\n",
"filename_pattern='*.fepout'\n",
"\n",
"temperature = 303.15\n",
"RT = 0.00198720650096 * temperature\n",
"decorrelate = True #Flag for decorrelation of samples\n",
"detectEQ = True #Flag for automated equilibrium detection\n",
"\n",
"fepoutFiles = glob(path+filename)\n",
"totalSize = 0\n",
"for file in fepoutFiles:\n",
" totalSize += os.path.getsize(file)\n",
"print(f\"Will process {len(fepoutFiles)} fepout files.\\nTotal size:{np.round(totalSize/10**9, 2)}GB\")"
"detectEQ = True #Flag for automated equilibrium detection"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a6864313",
"metadata": {},
"outputs": [],
"source": [
"from itertools import cycle\n",
"\n",
"colors = ['blue', 'red', 'green', 'purple', 'orange', 'violet', 'cyan']\n",
"# Cycle through the palette so next(itcolors) never raises StopIteration, even with\n",
"# more replicas than colors or when the reading cell is executed more than once.\n",
"itcolors = cycle(colors)"
]
},
{
Expand All @@ -64,27 +93,59 @@
"Note: alchemlyb operates in units of kT by default. We multiply by RT to convert to units of kcal/mol."
]
},
{
"cell_type": "markdown",
"id": "80b1034d",
"metadata": {},
"source": [
"# Read and plot number of samples after detecting EQ"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7c747b75-8fa3-48f9-9ab3-ac4a12982c01",
"metadata": {},
"outputs": [],
"source": [
"fig, ax = plt.subplots()\n",
"u_nk = namd.extract_u_nk(fepoutFiles, temperature)\n",
"safep.plot_samples(ax, u_nk, color='blue', label='Raw Data')\n",
"\n",
"if detectEQ:\n",
" print(\"Detecting equilibrium\")\n",
" u_nk = safep.detect_equilibrium_u_nk(u_nk)\n",
" safep.plot_samples(ax, u_nk, color='orange', label='Equilibrium-Detected')\n",
"if decorrelate and not detectEQ:\n",
" print(\"Decorrelating\")\n",
" u_nk = safep.decorrelate_u_nk(u_nk)\n",
" safep.plot_samples(ax, u_nk, color='green', label='Decorrelated')\n",
"fepruns = {}\n",
"for replica in replicas:\n",
"    print(f\"Reading {replica}\")\n",
"    # Per-replica cache of the processed u_nk table; it is reused when present.\n",
"    # NOTE(review): the `decorrelate` flag is not consulted in this cell even though the\n",
"    # cache file is named 'decorrelated.csv' - confirm whether decorrelation is intended.\n",
"    unkpath = replica.joinpath('decorrelated.csv')\n",
"    u_nk = None\n",
"    if unkpath.is_file():\n",
"        print(\"Found existing dataframe. Reading.\")\n",
"        u_nk = safep.read_UNK(unkpath)\n",
"    else:\n",
"        print(f\"Didn't find existing dataframe at {unkpath}. Checking for raw fepout files.\")\n",
"        fepoutFiles = list(replica.glob(filename_pattern))\n",
"        totalSize = sum(os.path.getsize(file) for file in fepoutFiles)\n",
"        print(f\"Will process {len(fepoutFiles)} fepout files.\\nTotal size:{np.round(totalSize/10**9, 2)}GB\")\n",
"\n",
"        if fepoutFiles:\n",
"            print(\"Reading fepout files\")\n",
"            fig, ax = plt.subplots()\n",
"\n",
"            u_nk = namd.extract_u_nk(fepoutFiles, temperature)\n",
"            u_nk = u_nk.sort_index(axis=0, level=1).sort_index(axis=1)\n",
"            safep.plot_samples(ax, u_nk, color='blue', label='Raw Data')\n",
"\n",
"            if detectEQ:\n",
"                print(\"Detecting equilibrium\")\n",
"                u_nk = safep.detect_equilibrium_u_nk(u_nk)\n",
"                safep.plot_samples(ax, u_nk, color='orange', label='Equilibrium-Detected')\n",
"\n",
"            fig.savefig(f\"./{str(replica)}_FEP_number_of_samples.pdf\")\n",
"            plt.show()\n",
"            plt.close(fig)  # one figure per replica: release it so figures do not pile up\n",
"            safep.save_UNK(u_nk, unkpath)\n",
"        else:\n",
"            print(f\"WARNING: no fepout files found for {replica}. Skipping.\")\n",
"\n",
"    if u_nk is not None:\n",
"        # The seven result fields start as None and are assigned after the estimators run.\n",
"        # Fall back to black if the color iterator has no entries left.\n",
"        fepruns[str(replica)] = FepRun(u_nk, None, None, None, None, None, None, None, next(itcolors, 'black'))\n",
" "
]
},
{
Expand All @@ -94,9 +155,20 @@
"metadata": {},
"outputs": [],
"source": [
"perWindow, cumulative = safep.do_estimation(u_nk) #Run the BAR estimator on the fep data\n",
"forward, forward_error, backward, backward_error = safep.do_convergence(u_nk) #Used later in the convergence plot'\n",
"per_lambda_convergence = safep.do_per_lambda_convergence(u_nk)"
"# Run each analysis once per replica and store the results on its FepRun record\n",
"for feprun in fepruns.values():\n",
"    u_nk = feprun.u_nk\n",
"    # Run the BAR estimator on the fep data\n",
"    estimates = safep.do_estimation(u_nk)\n",
"    feprun.perWindow, feprun.cumulative = estimates\n",
"    # Used later in the convergence plot\n",
"    (feprun.forward,\n",
"     feprun.forward_error,\n",
"     feprun.backward,\n",
"     feprun.backward_error) = safep.do_convergence(u_nk)\n",
"    feprun.per_lambda_convergence = safep.do_per_lambda_convergence(u_nk)\n",
"\n"
]
},
{
"cell_type": "markdown",
"id": "caae40d3",
"metadata": {},
"source": [
"# Plot data"
]
},
{
Expand All @@ -106,11 +178,59 @@
"metadata": {},
"outputs": [],
"source": [
"dG = np.round(cumulative.BAR.f.iloc[-1]*RT, 1)\n",
"error = np.round(cumulative.BAR.errors.iloc[-1]*RT, 1)\n",
"toprint = \"\"\n",
"dGs = []\n",
"errors = []\n",
"for key, feprun in fepruns.items():\n",
"    cumulative = feprun.cumulative\n",
"    # Last row of the cumulative BAR table, multiplied by RT to convert to kcal/mol\n",
"    dG = np.round(cumulative.BAR.f.iloc[-1]*RT, 1)\n",
"    error = np.round(cumulative.BAR.errors.iloc[-1]*RT, 1)\n",
"    dGs.append(dG)\n",
"    errors.append(error)\n",
"\n",
"    changeAndError = f'{key}: \u0394G = {dG}\u00B1{error} kcal/mol'\n",
"    toprint += '<font size=5>{}</font><br/>'.format(changeAndError)\n",
"\n",
"toprint += '<font size=5>{}</font><br/>'.format('__________________')\n",
"# Round the summary values like the per-replica ones so the printed text and the\n",
"# figure titles built from `mean` and `sterr` do not show floating-point noise.\n",
"mean = np.round(np.average(dGs), 1)\n",
"\n",
"# With fewer than three replicas, combine the per-replica BAR error estimates in\n",
"# quadrature; they are more reliable than the spread, albeit underestimated.\n",
"if len(dGs)<3:\n",
"    sterr = np.round(np.sqrt(np.sum(np.square(errors))), 1)\n",
"else:\n",
"    # NOTE(review): np.std is the population standard deviation of the replica values,\n",
"    # not the standard error of the mean - confirm this is the intended uncertainty.\n",
"    sterr = np.round(np.std(dGs),1)\n",
"toprint += '<font size=5>{}</font><br/>'.format(f'mean: {mean}')\n",
"toprint += '<font size=5>{}</font><br/>'.format(f'sterr: {sterr}')\n",
"Markdown(toprint)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d1fde633",
"metadata": {},
"outputs": [],
"source": [
"def do_agg_data(dataax, plotax, pdf_range=(-1, 1), pdf_points=1000):\n",
"    \"\"\"Annotate `plotax` with the mode, mean and standard deviation of the data in `dataax`.\n",
"\n",
"    The y-data of every line in `dataax` is pooled into one sample. The mode is the\n",
"    location of the maximum of a Gaussian kernel density estimate evaluated on a grid.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    dataax : axes whose `lines` supply the data (via `get_ydata()`).\n",
"    plotax : axes that receives the text box.\n",
"    pdf_range : (low, high) limits of the grid on which the density is evaluated.\n",
"    pdf_points : number of grid points.\n",
"\n",
"    Returns\n",
"    -------\n",
"    `plotax`, after the text box has been added.\n",
"    \"\"\"\n",
"    # Concatenate rather than stack: lines may hold different numbers of points\n",
"    flat = np.concatenate([np.ravel(line.get_ydata()) for line in dataax.lines])\n",
"    kernel = sp.stats.gaussian_kde(flat)\n",
"    pdfX = np.linspace(pdf_range[0], pdf_range[1], pdf_points)\n",
"    pdfY = kernel(pdfX)\n",
"    std = np.std(flat)\n",
"    mean = np.average(flat)\n",
"    mode = pdfX[np.argmax(pdfY)]  # grid point with the highest estimated density\n",
"\n",
"    textstr = r\"$\\rm mode=$\"+f\"{np.round(mode,2)}\"+\"\\n\"+fr\"$\\mu$={np.round(mean,2)}\"+\"\\n\"+fr\"$\\sigma$={np.round(std,2)}\"\n",
"    props = dict(boxstyle='square', facecolor='white', alpha=1)\n",
"    plotax.text(0.175, 0.95, textstr, transform=plotax.transAxes, fontsize=14,\n",
"                verticalalignment='top', bbox=props)\n",
"\n",
"    return plotax"
]
},
{
Expand All @@ -120,9 +240,21 @@
"metadata": {},
"outputs": [],
"source": [
"fig, axes = safep.plot_general(cumulative, None, perWindow, None, RT)\n",
"fig.suptitle(changeAndError)\n",
"plt.savefig(f'{path}FEP_general_figures.pdf')"
"# Draw all replicas onto one shared set of panels; the first call creates the figure\n",
"fig = None\n",
"for key, feprun in fepruns.items():\n",
"    first_replica = fig is None\n",
"    plot_kwargs = dict(hysttype='lines', label=key, color=feprun.color)\n",
"    if not first_replica:\n",
"        # Later replicas are added to the figure and axes created by the first call\n",
"        plot_kwargs.update(fig=fig, axes=axes)\n",
"    fig, axes = safep.plot_general(feprun.cumulative, None, feprun.perWindow, None, RT, **plot_kwargs)\n",
"    if first_replica:\n",
"        axes[1].legend()\n",
"\n",
"# hack to get aggregate data:\n",
"axes[3] = do_agg_data(axes[2], axes[3])\n",
"\n",
"title = str(mean)+r'$\\pm$'+str(sterr)+' kcal/mol'\n",
"axes[0].set_title(title)\n",
"axes[0].legend()\n",
"plt.savefig(dataroot.joinpath('FEP_general_figures.pdf'))"
]
},
{
Expand All @@ -141,28 +273,59 @@
"outputs": [],
"source": [
"fig, convAx = plt.subplots(1,1)\n",
"convAx = safep.convergence_plot(convAx, forward*RT, forward_error*RT, backward*RT, backward_error*RT)\n",
"plt.savefig(f'{path}FEP_convergence.pdf')"
"\n",
"# Overlay the forward and backward convergence curves of every replica on convAx\n",
"for feprun in fepruns.values():\n",
"    replica_color = feprun.color\n",
"    convAx = safep.convergence_plot(\n",
"        convAx,\n",
"        feprun.forward*RT,\n",
"        feprun.forward_error*RT,\n",
"        feprun.backward*RT,\n",
"        feprun.backward_error*RT,\n",
"        fwd_color=replica_color,\n",
"        bwd_color=replica_color,\n",
"        errorbars=False,\n",
"    )\n",
"    # Remove the legend after each call; a single shared legend is built below\n",
"    convAx.get_legend().remove()\n",
"\n",
"# Empty black lines act as legend handles for the two line styles\n",
"forward_line, = convAx.plot([],[],linestyle='-', color='black', label='Forward Time Sampling')\n",
"backward_line, = convAx.plot([],[],linestyle='--', color='black', label='Backward Time Sampling')\n",
"convAx.legend(handles=[forward_line, backward_line])\n",
"# Limit the y axis to 1 kcal/mol beyond the smallest and largest replica estimates\n",
"ymin, ymax = np.min(dGs)-1, np.max(dGs)+1\n",
"convAx.set_ylim((ymin,ymax))\n",
"plt.savefig(dataroot.joinpath('FEP_convergence.pdf'))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1bf6ac36-f102-4da3-82d2-36a7741584e5",
"id": "e6ab4621",
"metadata": {},
"outputs": [],
"source": [
"fig, ax = plt.subplots()\n",
"ax.errorbar(per_lambda_convergence.index, per_lambda_convergence.BAR.df*RT)\n",
"ax.set_xlabel(r\"$\\lambda$\")\n",
"ax.set_ylabel(r\"$D_{last-first}$ (kcal/mol)\")\n",
"plt.savefig(f\"{path}FEP_perLambda_convergence.pdf\")"
"genfig = None\n",
"for key, feprun in fepruns.items():\n",
"    if genfig is None:\n",
"        genfig, genaxes = safep.plot_general(feprun.cumulative, None, feprun.perWindow, None, RT, hysttype='lines', label=key, color=feprun.color)\n",
"    else:\n",
"        genfig, genaxes = safep.plot_general(feprun.cumulative, None, feprun.perWindow, None, RT, fig=genfig, axes=genaxes, hysttype='lines', label=key, color=feprun.color)\n",
"# Remove the first two panels from this figure\n",
"plt.delaxes(genaxes[0])\n",
"plt.delaxes(genaxes[1])\n",
"\n",
"# Use this figure's own axes: passing `axes` from the earlier cell would read from and\n",
"# annotate the previous figure, and would depend on that cell having been run.\n",
"genaxes[3] = do_agg_data(genaxes[2], genaxes[3])\n",
"genaxes[2].set_title(str(mean)+r'$\\pm$'+str(sterr)+' kcal/mol')\n",
"\n",
"# Hide and clear every figure-level text element\n",
"for txt in genfig.texts:\n",
"    txt.set_visible(False)\n",
"    txt.set_text(\"\")\n",
"# Save before plt.show(): once the figure has been shown it may be released, and a\n",
"# later plt.savefig can then write an empty canvas.\n",
"genfig.savefig(dataroot.joinpath('FEP_perLambda_convergence.pdf'))\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b237a8ee-f33e-4c05-b4ed-856bb4f34ca2",
"id": "96418460",
"metadata": {},
"outputs": [],
"source": []
Expand All @@ -184,7 +347,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
"version": "3.12.3"
}
},
"nbformat": 4,
Expand Down
1 change: 1 addition & 0 deletions Sample_Notebooks/Replica1
1 change: 1 addition & 0 deletions Sample_Notebooks/Replica2
1 change: 1 addition & 0 deletions Sample_Notebooks/Replica3
6 changes: 3 additions & 3 deletions safep/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,9 @@ def plot_general(cumulative,

# Per-window change in kcal/mol
if errorbars:
each_ax.errorbar(per_window.index, per_window.BAR.df*RT, yerr=per_window.BAR.ddf, marker=None, linewidth=1, color=color)
each_ax.errorbar(per_window.index, -per_window.EXP.dG_b*RT, marker=None, linewidth=1, alpha=0.5, linestyle='--', color=color)
each_ax.plot(per_window.index, per_window.EXP.dG_f*RT, marker=None, linewidth=1, alpha=0.5, color=color)
each_ax.errorbar(per_window.index, per_window.BAR.df*RT, yerr=per_window.BAR.ddf, marker=None, linewidth=1, alpha=0.5, color=color, label="forward EXP")
each_ax.errorbar(per_window.index, -per_window.EXP.dG_b*RT, marker=None, linewidth=1, alpha=0.5, linestyle='--', color=color, label="backward EXP")
each_ax.plot(per_window.index, per_window.EXP.dG_f*RT, marker=None, linewidth=1, color=color, label="BAR")

each_ax.set(ylabel=r'$\mathrm{\Delta} G_\lambda$'+'\n'+r'$\left(\mathrm{kcal}/\mathrm{mol}\right)$', ylim=per_window_ylim)

Expand Down
Loading

0 comments on commit caa2a3a

Please sign in to comment.