diff --git a/Sample_Notebooks/BAR_Estimator_Basic.ipynb b/Sample_Notebooks/BAR_Estimator_Basic.ipynb
index abe4375..7896a58 100644
--- a/Sample_Notebooks/BAR_Estimator_Basic.ipynb
+++ b/Sample_Notebooks/BAR_Estimator_Basic.ipynb
@@ -15,12 +15,35 @@
"import os\n",
"from alchemlyb.parsing import namd\n",
"from IPython.display import display, Markdown\n",
- "\n",
+ "from pathlib import Path\n",
+ "from dataclasses import dataclass\n",
+ "import scipy as sp\n",
"from alchemlyb.estimators import BAR\n",
"import warnings\n",
"warnings.simplefilter(action='ignore', category=FutureWarning)"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "bf843fad",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "@dataclass\n",
+ "class FepRun:\n",
+ " u_nk: pd.DataFrame\n",
+ " perWindow: pd.DataFrame\n",
+ " cumulative: pd.DataFrame\n",
+ " forward: pd.DataFrame\n",
+ " forward_error: pd.DataFrame\n",
+ " backward: pd.DataFrame\n",
+ " backward_error: pd.DataFrame\n",
+ " per_lambda_convergence: pd.DataFrame\n",
+ " color: str"
+ ]
+ },
{
"cell_type": "markdown",
"id": "a9d06104",
@@ -38,19 +61,25 @@
"metadata": {},
"outputs": [],
"source": [
- "path='/path/to/data/'\n",
- "filename='*.fepout'\n",
+ "dataroot = Path('.')\n",
+ "replica_pattern='Replica?'\n",
+ "replicas = dataroot.glob(replica_pattern)\n",
+ "filename_pattern='*.fepout'\n",
"\n",
"temperature = 303.15\n",
"RT = 0.00198720650096 * temperature\n",
- "decorrelate = True #Flag for decorrelation of samples\n",
- "detectEQ = True #Flag for automated equilibrium detection\n",
- "\n",
- "fepoutFiles = glob(path+filename)\n",
- "totalSize = 0\n",
- "for file in fepoutFiles:\n",
- " totalSize += os.path.getsize(file)\n",
- "print(f\"Will process {len(fepoutFiles)} fepout files.\\nTotal size:{np.round(totalSize/10**9, 2)}GB\")"
+ "detectEQ = True #Flag for automated equilibrium detection"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a6864313",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "colors = ['blue', 'red', 'green', 'purple', 'orange', 'violet', 'cyan']\n",
+ "itcolors = iter(colors)"
]
},
{
@@ -64,6 +93,14 @@
"Note: alchemlyb operates in units of kT by default. We multiply by RT to convert to units of kcal/mol."
]
},
+ {
+ "cell_type": "markdown",
+ "id": "80b1034d",
+ "metadata": {},
+ "source": [
+ "# Read and plot number of samples after detecting EQ"
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
@@ -71,20 +108,44 @@
"metadata": {},
"outputs": [],
"source": [
- "fig, ax = plt.subplots()\n",
- "u_nk = namd.extract_u_nk(fepoutFiles, temperature)\n",
- "safep.plot_samples(ax, u_nk, color='blue', label='Raw Data')\n",
- "\n",
- "if detectEQ:\n",
- " print(\"Detecting equilibrium\")\n",
- " u_nk = safep.detect_equilibrium_u_nk(u_nk)\n",
- " safep.plot_samples(ax, u_nk, color='orange', label='Equilibrium-Detected')\n",
- "if decorrelate and not detectEQ:\n",
- " print(\"Decorrelating\")\n",
- " u_nk = safep.decorrelate_u_nk(u_nk)\n",
- " safep.plot_samples(ax, u_nk, color='green', label='Decorrelated')\n",
+ "fepruns = {}\n",
+ "for replica in replicas:\n",
+ " print(f\"Reading {replica}\")\n",
+ " unkpath = replica.joinpath('decorrelated.csv')\n",
+ " u_nk = None\n",
+ " if unkpath.is_file():\n",
+ " print(f\"Found existing dataframe. Reading.\")\n",
+ " u_nk = safep.read_UNK(unkpath)\n",
+ " else:\n",
+ " print(f\"Didn't find existing dataframe at {unkpath}. Checking for raw fepout files.\")\n",
+ " fepoutFiles = list(replica.glob(filename_pattern))\n",
+ " totalSize = 0\n",
+ " for file in fepoutFiles:\n",
+ " totalSize += os.path.getsize(file)\n",
+ " print(f\"Will process {len(fepoutFiles)} fepout files.\\nTotal size:{np.round(totalSize/10**9, 2)}GB\")\n",
+ "\n",
+ " if len(list(fepoutFiles))>0:\n",
+ " print(\"Reading fepout files\")\n",
+ " fig, ax = plt.subplots()\n",
+ "\n",
+ " u_nk = namd.extract_u_nk(fepoutFiles, temperature)\n",
+ " u_nk = u_nk.sort_index(axis=0, level=1).sort_index(axis=1)\n",
+ " safep.plot_samples(ax, u_nk, color='blue', label='Raw Data')\n",
+ "\n",
+ " if detectEQ:\n",
+ " print(\"Detecting equilibrium\")\n",
+ " u_nk = safep.detect_equilibrium_u_nk(u_nk)\n",
+ " safep.plot_samples(ax, u_nk, color='orange', label='Equilibrium-Detected')\n",
+ "\n",
+ " plt.savefig(f\"./{str(replica)}_FEP_number_of_samples.pdf\")\n",
+ " plt.show()\n",
+ " safep.save_UNK(u_nk, unkpath)\n",
+ " else:\n",
+ " print(f\"WARNING: no fepout files found for {replica}. Skipping.\")\n",
" \n",
- "plt.savefig(f\"{path}FEP_number_of_samples.pdf\")"
+ " if u_nk is not None:\n",
+ " fepruns[str(replica)] = FepRun(u_nk, None, None, None, None, None, None, None, next(itcolors))\n",
+ " "
]
},
{
@@ -94,9 +155,20 @@
"metadata": {},
"outputs": [],
"source": [
- "perWindow, cumulative = safep.do_estimation(u_nk) #Run the BAR estimator on the fep data\n",
- "forward, forward_error, backward, backward_error = safep.do_convergence(u_nk) #Used later in the convergence plot'\n",
- "per_lambda_convergence = safep.do_per_lambda_convergence(u_nk)"
+ "for key, feprun in fepruns.items():\n",
+ " u_nk = feprun.u_nk\n",
+ " feprun.perWindow, feprun.cumulative = safep.do_estimation(u_nk) #Run the BAR estimator on the fep data\n",
+ " feprun.forward, feprun.forward_error, feprun.backward, feprun.backward_error = safep.do_convergence(u_nk) #Used later in the convergence plot'\n",
+ " feprun.per_lambda_convergence = safep.do_per_lambda_convergence(u_nk)\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "caae40d3",
+ "metadata": {},
+ "source": [
+ "# Plot data"
]
},
{
@@ -106,11 +178,59 @@
"metadata": {},
"outputs": [],
"source": [
- "dG = np.round(cumulative.BAR.f.iloc[-1]*RT, 1)\n",
- "error = np.round(cumulative.BAR.errors.iloc[-1]*RT, 1)\n",
+ "toprint = \"\"\n",
+ "dGs = []\n",
+ "errors = []\n",
+ "for key, feprun in fepruns.items():\n",
+ " cumulative = feprun.cumulative\n",
+ " dG = np.round(cumulative.BAR.f.iloc[-1]*RT, 1)\n",
+ " error = np.round(cumulative.BAR.errors.iloc[-1]*RT, 1)\n",
+ " dGs.append(dG)\n",
+ " errors.append(error)\n",
+ "\n",
+ " changeAndError = f'{key}: \\u0394G = {dG}\\u00B1{error} kcal/mol'\n",
+ "    toprint += '{}<br/>'.format(changeAndError)\n",
+ "\n",
+ "toprint += '{}<br/>'.format('__________________')\n",
+ "mean = np.average(dGs)\n",
+ "\n",
+ "#If there are only a few replicas, the BAR-estimated error will be more reliable, albeit underestimated\n",
+ "if len(dGs)<3:\n",
+ "    sterr = np.round(np.sqrt(np.sum(np.square(errors))),1)\n",
+ "else:\n",
+ " sterr = np.round(np.std(dGs),1)\n",
+ "toprint += '{}<br/>'.format(f'mean: {mean}')\n",
+ "toprint += '{}<br/>'.format(f'sterr: {sterr}')\n",
+ "Markdown(toprint)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d1fde633",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def do_agg_data(dataax, plotax):\n",
+ " agg_data = []\n",
+ " lines = dataax.lines\n",
+ " for line in lines:\n",
+ " agg_data.append(line.get_ydata())\n",
+ " flat = np.array(agg_data).flatten()\n",
+ " kernel = sp.stats.gaussian_kde(flat)\n",
+ " pdfX = np.linspace(-1, 1, 1000)\n",
+ " pdfY = kernel(pdfX)\n",
+ " std = np.std(flat)\n",
+ " mean = np.average(flat)\n",
+ " temp = pd.Series(pdfY, index=pdfX)\n",
+ " mode = temp.idxmax()\n",
+ "\n",
+ " textstr = r\"$\\rm mode=$\"+f\"{np.round(mode,2)}\"+\"\\n\"+fr\"$\\mu$={np.round(mean,2)}\"+\"\\n\"+fr\"$\\sigma$={np.round(std,2)}\"\n",
+ " props = dict(boxstyle='square', facecolor='white', alpha=1)\n",
+ " plotax.text(0.175, 0.95, textstr, transform=plotax.transAxes, fontsize=14,\n",
+ " verticalalignment='top', bbox=props)\n",
"\n",
- "changeAndError = f'\\u0394G = {dG}\\u00B1{error} kcal/mol'\n",
- "Markdown('{}<br/>'.format(changeAndError))"
+ " return plotax"
]
},
{
@@ -120,9 +240,21 @@
"metadata": {},
"outputs": [],
"source": [
- "fig, axes = safep.plot_general(cumulative, None, perWindow, None, RT)\n",
- "fig.suptitle(changeAndError)\n",
- "plt.savefig(f'{path}FEP_general_figures.pdf')"
+ "fig = None\n",
+ "for key, feprun in fepruns.items():\n",
+ " if fig is None:\n",
+ " fig, axes = safep.plot_general(feprun.cumulative, None, feprun.perWindow, None, RT, hysttype='lines', label=key, color=feprun.color)\n",
+ " axes[1].legend()\n",
+ " else:\n",
+ " fig, axes = safep.plot_general(feprun.cumulative, None, feprun.perWindow, None, RT, fig=fig, axes=axes, hysttype='lines', label=key, color=feprun.color)\n",
+ " #fig.suptitle(changeAndError)\n",
+ "\n",
+ "# hack to get aggregate data:\n",
+ "axes[3] = do_agg_data(axes[2], axes[3])\n",
+ "\n",
+ "axes[0].set_title(str(mean)+r'$\\pm$'+str(sterr)+' kcal/mol')\n",
+ "axes[0].legend()\n",
+ "plt.savefig(dataroot.joinpath('FEP_general_figures.pdf'))"
]
},
{
@@ -141,28 +273,59 @@
"outputs": [],
"source": [
"fig, convAx = plt.subplots(1,1)\n",
- "convAx = safep.convergence_plot(convAx, forward*RT, forward_error*RT, backward*RT, backward_error*RT)\n",
- "plt.savefig(f'{path}FEP_convergence.pdf')"
+ "\n",
+ "for key, feprun in fepruns.items():\n",
+ " convAx = safep.convergence_plot(convAx, \n",
+ " feprun.forward*RT, \n",
+ " feprun.forward_error*RT, \n",
+ " feprun.backward*RT,\n",
+ " feprun.backward_error*RT,\n",
+ " fwd_color=feprun.color,\n",
+ " bwd_color=feprun.color,\n",
+ " errorbars=False\n",
+ " )\n",
+ " convAx.get_legend().remove()\n",
+ "\n",
+ "forward_line, = convAx.plot([],[],linestyle='-', color='black', label='Forward Time Sampling')\n",
+ "backward_line, = convAx.plot([],[],linestyle='--', color='black', label='Backward Time Sampling')\n",
+ "convAx.legend(handles=[forward_line, backward_line])\n",
+ "ymin = np.min(dGs)-1\n",
+ "ymax = np.max(dGs)+1\n",
+ "convAx.set_ylim((ymin,ymax))\n",
+ "plt.savefig(dataroot.joinpath('FEP_convergence.pdf'))"
]
},
{
"cell_type": "code",
"execution_count": null,
- "id": "1bf6ac36-f102-4da3-82d2-36a7741584e5",
+ "id": "e6ab4621",
"metadata": {},
"outputs": [],
"source": [
- "fig, ax = plt.subplots()\n",
- "ax.errorbar(per_lambda_convergence.index, per_lambda_convergence.BAR.df*RT)\n",
- "ax.set_xlabel(r\"$\\lambda$\")\n",
- "ax.set_ylabel(r\"$D_{last-first}$ (kcal/mol)\")\n",
- "plt.savefig(f\"{path}FEP_perLambda_convergence.pdf\")"
+ "genfig = None\n",
+ "for key, feprun in fepruns.items():\n",
+ " if genfig is None:\n",
+ " genfig, genaxes = safep.plot_general(feprun.cumulative, None, feprun.perWindow, None, RT, hysttype='lines', label=key, color=feprun.color)\n",
+ " else:\n",
+ " genfig, genaxes = safep.plot_general(feprun.cumulative, None, feprun.perWindow, None, RT, fig=genfig, axes=genaxes, hysttype='lines', label=key, color=feprun.color)\n",
+ "plt.delaxes(genaxes[0])\n",
+ "plt.delaxes(genaxes[1])\n",
+ "\n",
+ "genaxes[3] = do_agg_data(genaxes[2], genaxes[3])\n",
+ "genaxes[2].set_title(str(mean)+r'$\\pm$'+str(sterr)+' kcal/mol')\n",
+ "\n",
+ "for txt in genfig.texts:\n",
+ "    # hide leftover figure-level text before saving\n",
+ " txt.set_visible(False)\n",
+ " txt.set_text(\"\")\n",
+ "plt.savefig(dataroot.joinpath('FEP_perLambda_convergence.pdf'))\n",
+ "plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
- "id": "b237a8ee-f33e-4c05-b4ed-856bb4f34ca2",
+ "id": "96418460",
"metadata": {},
"outputs": [],
"source": []
@@ -184,7 +347,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.9.12"
+ "version": "3.12.3"
}
},
"nbformat": 4,
diff --git a/Sample_Notebooks/Replica1 b/Sample_Notebooks/Replica1
new file mode 120000
index 0000000..df9f7a8
--- /dev/null
+++ b/Sample_Notebooks/Replica1
@@ -0,0 +1 @@
+Sample_Data
\ No newline at end of file
diff --git a/Sample_Notebooks/Replica2 b/Sample_Notebooks/Replica2
new file mode 120000
index 0000000..df9f7a8
--- /dev/null
+++ b/Sample_Notebooks/Replica2
@@ -0,0 +1 @@
+Sample_Data
\ No newline at end of file
diff --git a/Sample_Notebooks/Replica3 b/Sample_Notebooks/Replica3
new file mode 120000
index 0000000..df9f7a8
--- /dev/null
+++ b/Sample_Notebooks/Replica3
@@ -0,0 +1 @@
+Sample_Data
\ No newline at end of file
diff --git a/safep/plotting.py b/safep/plotting.py
index e377f19..a3471cf 100644
--- a/safep/plotting.py
+++ b/safep/plotting.py
@@ -187,9 +187,9 @@ def plot_general(cumulative,
# Per-window change in kcal/mol
if errorbars:
- each_ax.errorbar(per_window.index, per_window.BAR.df*RT, yerr=per_window.BAR.ddf, marker=None, linewidth=1, color=color)
- each_ax.errorbar(per_window.index, -per_window.EXP.dG_b*RT, marker=None, linewidth=1, alpha=0.5, linestyle='--', color=color)
- each_ax.plot(per_window.index, per_window.EXP.dG_f*RT, marker=None, linewidth=1, alpha=0.5, color=color)
+ each_ax.errorbar(per_window.index, per_window.BAR.df*RT, yerr=per_window.BAR.ddf, marker=None, linewidth=1, alpha=0.5, color=color, label="BAR")
+ each_ax.errorbar(per_window.index, -per_window.EXP.dG_b*RT, marker=None, linewidth=1, alpha=0.5, linestyle='--', color=color, label="backward EXP")
+ each_ax.plot(per_window.index, per_window.EXP.dG_f*RT, marker=None, linewidth=1, color=color, label="forward EXP")
each_ax.set(ylabel=r'$\mathrm{\Delta} G_\lambda$'+'\n'+r'$\left(kcal/mol\right)$', ylim=per_window_ylim)
diff --git a/safep/processing.py b/safep/processing.py
index 1c6c58a..c9222d6 100644
--- a/safep/processing.py
+++ b/safep/processing.py
@@ -195,12 +195,13 @@ def alt_convergence(u_nk, nbins):
return np.array(forward), np.array(forward_error)
-def do_convergence(u_nk, tau=1, num_points=10):
+def do_convergence(u_nk, tau=1, num_points=10, estimator='BAR'):
"""
Convergence calculation. Incrementally adds data from either the start or the end of each windows simulation and calculates the resulting change in free energy.
Arguments: u_nk, tau (an error scaling factor), num_points (number of chunks)
Returns: forward-sampled estimate (starting from t=start), forward-sampled error, backward-sampled estimate (from t=end), backward-sampled error
"""
+ assert estimator in ['BAR', 'EXP'], "ERROR: I only know BAR and EXP estimators."
groups = u_nk.groupby("fep-lambda")
forward = []
@@ -210,18 +211,33 @@ def do_convergence(u_nk, tau=1, num_points=10):
for i in range(1, num_points + 1):
# forward
partial = subsample(groups, 0, 100 * i / num_points)
- estimate = BAR().fit(partial)
- l, l_mid, f, df, ddf, errors = get_BAR(estimate)
-
- forward.append(f.iloc[-1])
- forward_error.append(errors[-1])
-
+ if estimator=='BAR':
+ estimate = BAR().fit(partial)
+ l, l_mid, f, df, ddf, errors = get_BAR(estimate)
+ f_append = f.iloc[-1]
+ err_append = errors[-1]
+ else:
+ expl, expmid, dG_fs, dG_bs = get_exponential(partial)
+ f_append = np.average([dG_fs[-1], dG_bs[-1]])
+ err_append = np.std([dG_fs[-1], dG_bs[-1]])
+
+ forward.append(f_append)
+ forward_error.append(err_append)
+
+ # backward
partial = subsample(groups, 100 * (1 - i / num_points), 100)
- estimate = BAR().fit(partial)
- l, l_mid, f, df, ddf, errors = get_BAR(estimate)
+ if estimator=='BAR':
+ estimate = BAR().fit(partial)
+ l, l_mid, f, df, ddf, errors = get_BAR(estimate)
+ f_append = f.iloc[-1]
+ err_append = errors[-1]
+ else:
+ expl, expmid, dG_fs, dG_bs = get_exponential(partial)
+ f_append = np.average([dG_fs[-1], dG_bs[-1]])
+ err_append = np.std([dG_fs[-1], dG_bs[-1]])
- backward.append(f.iloc[-1])
- backward_error.append(errors[-1])
+ backward.append(f_append)
+ backward_error.append(err_append)
return (
np.array(forward),