diff --git a/Sample_Notebooks/BAR_Estimator_Basic.ipynb b/Sample_Notebooks/BAR_Estimator_Basic.ipynb index abe4375..7896a58 100644 --- a/Sample_Notebooks/BAR_Estimator_Basic.ipynb +++ b/Sample_Notebooks/BAR_Estimator_Basic.ipynb @@ -15,12 +15,35 @@ "import os\n", "from alchemlyb.parsing import namd\n", "from IPython.display import display, Markdown\n", - "\n", + "from pathlib import Path\n", + "from dataclasses import dataclass\n", + "import scipy as sp\n", "from alchemlyb.estimators import BAR\n", "import warnings\n", "warnings.simplefilter(action='ignore', category=FutureWarning)" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "bf843fad", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "@dataclass\n", + "class FepRun:\n", + " u_nk: pd.DataFrame\n", + " perWindow: pd.DataFrame\n", + " cumulative: pd.DataFrame\n", + " forward: pd.DataFrame\n", + " forward_error: pd.DataFrame\n", + " backward: pd.DataFrame\n", + " backward_error: pd.DataFrame\n", + " per_lambda_convergence: pd.DataFrame\n", + " color: str" + ] + }, { "cell_type": "markdown", "id": "a9d06104", @@ -38,19 +61,25 @@ "metadata": {}, "outputs": [], "source": [ - "path='/path/to/data/'\n", - "filename='*.fepout'\n", + "dataroot = Path('.')\n", + "replica_pattern='Replica?'\n", + "replicas = dataroot.glob(replica_pattern)\n", + "filename_pattern='*.fepout'\n", "\n", "temperature = 303.15\n", "RT = 0.00198720650096 * temperature\n", - "decorrelate = True #Flag for decorrelation of samples\n", - "detectEQ = True #Flag for automated equilibrium detection\n", - "\n", - "fepoutFiles = glob(path+filename)\n", - "totalSize = 0\n", - "for file in fepoutFiles:\n", - " totalSize += os.path.getsize(file)\n", - "print(f\"Will process {len(fepoutFiles)} fepout files.\\nTotal size:{np.round(totalSize/10**9, 2)}GB\")" + "detectEQ = True #Flag for automated equilibrium detection" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a6864313", + 
"metadata": {}, + "outputs": [], + "source": [ + "colors = ['blue', 'red', 'green', 'purple', 'orange', 'violet', 'cyan']\n", + "itcolors = iter(colors)" ] }, { @@ -64,6 +93,14 @@ "Note: alchemlyb operates in units of kT by default. We multiply by RT to convert to units of kcal/mol." ] }, + { + "cell_type": "markdown", + "id": "80b1034d", + "metadata": {}, + "source": [ + "# Read and plot number of samples after detecting EQ" + ] + }, { "cell_type": "code", "execution_count": null, @@ -71,20 +108,44 @@ "metadata": {}, "outputs": [], "source": [ - "fig, ax = plt.subplots()\n", - "u_nk = namd.extract_u_nk(fepoutFiles, temperature)\n", - "safep.plot_samples(ax, u_nk, color='blue', label='Raw Data')\n", - "\n", - "if detectEQ:\n", - " print(\"Detecting equilibrium\")\n", - " u_nk = safep.detect_equilibrium_u_nk(u_nk)\n", - " safep.plot_samples(ax, u_nk, color='orange', label='Equilibrium-Detected')\n", - "if decorrelate and not detectEQ:\n", - " print(\"Decorrelating\")\n", - " u_nk = safep.decorrelate_u_nk(u_nk)\n", - " safep.plot_samples(ax, u_nk, color='green', label='Decorrelated')\n", + "fepruns = {}\n", + "for replica in replicas:\n", + " print(f\"Reading {replica}\")\n", + " unkpath = replica.joinpath('decorrelated.csv')\n", + " u_nk = None\n", + " if unkpath.is_file():\n", + " print(f\"Found existing dataframe. Reading.\")\n", + " u_nk = safep.read_UNK(unkpath)\n", + " else:\n", + " print(f\"Didn't find existing dataframe at {unkpath}. 
Checking for raw fepout files.\")\n", + " fepoutFiles = list(replica.glob(filename_pattern))\n", + " totalSize = 0\n", + " for file in fepoutFiles:\n", + " totalSize += os.path.getsize(file)\n", + " print(f\"Will process {len(fepoutFiles)} fepout files.\\nTotal size:{np.round(totalSize/10**9, 2)}GB\")\n", + "\n", + " if len(list(fepoutFiles))>0:\n", + " print(\"Reading fepout files\")\n", + " fig, ax = plt.subplots()\n", + "\n", + " u_nk = namd.extract_u_nk(fepoutFiles, temperature)\n", + " u_nk = u_nk.sort_index(axis=0, level=1).sort_index(axis=1)\n", + " safep.plot_samples(ax, u_nk, color='blue', label='Raw Data')\n", + "\n", + " if detectEQ:\n", + " print(\"Detecting equilibrium\")\n", + " u_nk = safep.detect_equilibrium_u_nk(u_nk)\n", + " safep.plot_samples(ax, u_nk, color='orange', label='Equilibrium-Detected')\n", + "\n", + " plt.savefig(f\"./{str(replica)}_FEP_number_of_samples.pdf\")\n", + " plt.show()\n", + " safep.save_UNK(u_nk, unkpath)\n", + " else:\n", + " print(f\"WARNING: no fepout files found for {replica}. 
Skipping.\")\n", " \n", - "plt.savefig(f\"{path}FEP_number_of_samples.pdf\")" + " if u_nk is not None:\n", + " fepruns[str(replica)] = FepRun(u_nk, None, None, None, None, None, None, None, next(itcolors))\n", + " " ] }, { @@ -94,9 +155,20 @@ "metadata": {}, "outputs": [], "source": [ - "perWindow, cumulative = safep.do_estimation(u_nk) #Run the BAR estimator on the fep data\n", - "forward, forward_error, backward, backward_error = safep.do_convergence(u_nk) #Used later in the convergence plot'\n", - "per_lambda_convergence = safep.do_per_lambda_convergence(u_nk)" + "for key, feprun in fepruns.items():\n", + " u_nk = feprun.u_nk\n", + " feprun.perWindow, feprun.cumulative = safep.do_estimation(u_nk) #Run the BAR estimator on the fep data\n", + " feprun.forward, feprun.forward_error, feprun.backward, feprun.backward_error = safep.do_convergence(u_nk) #Used later in the convergence plot'\n", + " feprun.per_lambda_convergence = safep.do_per_lambda_convergence(u_nk)\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "caae40d3", + "metadata": {}, + "source": [ + "# Plot data" ] }, { @@ -106,11 +178,59 @@ "metadata": {}, "outputs": [], "source": [ - "dG = np.round(cumulative.BAR.f.iloc[-1]*RT, 1)\n", - "error = np.round(cumulative.BAR.errors.iloc[-1]*RT, 1)\n", + "toprint = \"\"\n", + "dGs = []\n", + "errors = []\n", + "for key, feprun in fepruns.items():\n", + " cumulative = feprun.cumulative\n", + " dG = np.round(cumulative.BAR.f.iloc[-1]*RT, 1)\n", + " error = np.round(cumulative.BAR.errors.iloc[-1]*RT, 1)\n", + " dGs.append(dG)\n", + " errors.append(error)\n", + "\n", + " changeAndError = f'{key}: \\u0394G = {dG}\\u00B1{error} kcal/mol'\n", + " toprint += '{}
<br/>'.format(changeAndError)\n", + "\n", + "toprint += '{}
<br/>'.format('__________________')\n", + "mean = np.average(dGs)\n", + "\n", + "#If there are only a few replicas, the MBAR estimated error will be more reliable, albeit underestimated\n", + "if len(dGs)<3:\n", + " sterr = np.sqrt(np.sum(np.square(errors)))\n", + "else:\n", + " sterr = np.round(np.std(dGs),1)\n", + "toprint += '{}
<br/>'.format(f'mean: {mean}')\n", + "toprint += '{}
'.format(f'sterr: {sterr}')\n", + "Markdown(toprint)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d1fde633", + "metadata": {}, + "outputs": [], + "source": [ + "def do_agg_data(dataax, plotax):\n", + " agg_data = []\n", + " lines = dataax.lines\n", + " for line in lines:\n", + " agg_data.append(line.get_ydata())\n", + " flat = np.array(agg_data).flatten()\n", + " kernel = sp.stats.gaussian_kde(flat)\n", + " pdfX = np.linspace(-1, 1, 1000)\n", + " pdfY = kernel(pdfX)\n", + " std = np.std(flat)\n", + " mean = np.average(flat)\n", + " temp = pd.Series(pdfY, index=pdfX)\n", + " mode = temp.idxmax()\n", + "\n", + " textstr = r\"$\\rm mode=$\"+f\"{np.round(mode,2)}\"+\"\\n\"+fr\"$\\mu$={np.round(mean,2)}\"+\"\\n\"+fr\"$\\sigma$={np.round(std,2)}\"\n", + " props = dict(boxstyle='square', facecolor='white', alpha=1)\n", + " plotax.text(0.175, 0.95, textstr, transform=plotax.transAxes, fontsize=14,\n", + " verticalalignment='top', bbox=props)\n", "\n", - "changeAndError = f'\\u0394G = {dG}\\u00B1{error} kcal/mol'\n", - "Markdown('{}
'.format(changeAndError))" + " return plotax" ] }, { @@ -120,9 +240,21 @@ "metadata": {}, "outputs": [], "source": [ - "fig, axes = safep.plot_general(cumulative, None, perWindow, None, RT)\n", - "fig.suptitle(changeAndError)\n", - "plt.savefig(f'{path}FEP_general_figures.pdf')" + "fig = None\n", + "for key, feprun in fepruns.items():\n", + " if fig is None:\n", + " fig, axes = safep.plot_general(feprun.cumulative, None, feprun.perWindow, None, RT, hysttype='lines', label=key, color=feprun.color)\n", + " axes[1].legend()\n", + " else:\n", + " fig, axes = safep.plot_general(feprun.cumulative, None, feprun.perWindow, None, RT, fig=fig, axes=axes, hysttype='lines', label=key, color=feprun.color)\n", + " #fig.suptitle(changeAndError)\n", + "\n", + "# hack to get aggregate data:\n", + "axes[3] = do_agg_data(axes[2], axes[3])\n", + "\n", + "axes[0].set_title(str(mean)+r'$\\pm$'+str(sterr)+' kcal/mol')\n", + "axes[0].legend()\n", + "plt.savefig(dataroot.joinpath('FEP_general_figures.pdf'))" ] }, { @@ -141,28 +273,59 @@ "outputs": [], "source": [ "fig, convAx = plt.subplots(1,1)\n", - "convAx = safep.convergence_plot(convAx, forward*RT, forward_error*RT, backward*RT, backward_error*RT)\n", - "plt.savefig(f'{path}FEP_convergence.pdf')" + "\n", + "for key, feprun in fepruns.items():\n", + " convAx = safep.convergence_plot(convAx, \n", + " feprun.forward*RT, \n", + " feprun.forward_error*RT, \n", + " feprun.backward*RT,\n", + " feprun.backward_error*RT,\n", + " fwd_color=feprun.color,\n", + " bwd_color=feprun.color,\n", + " errorbars=False\n", + " )\n", + " convAx.get_legend().remove()\n", + "\n", + "forward_line, = convAx.plot([],[],linestyle='-', color='black', label='Forward Time Sampling')\n", + "backward_line, = convAx.plot([],[],linestyle='--', color='black', label='Backward Time Sampling')\n", + "convAx.legend(handles=[forward_line, backward_line])\n", + "ymin = np.min(dGs)-1\n", + "ymax = np.max(dGs)+1\n", + "convAx.set_ylim((ymin,ymax))\n", + 
"plt.savefig(dataroot.joinpath('FEP_convergence.pdf'))" ] }, { "cell_type": "code", "execution_count": null, - "id": "1bf6ac36-f102-4da3-82d2-36a7741584e5", + "id": "e6ab4621", "metadata": {}, "outputs": [], "source": [ - "fig, ax = plt.subplots()\n", - "ax.errorbar(per_lambda_convergence.index, per_lambda_convergence.BAR.df*RT)\n", - "ax.set_xlabel(r\"$\\lambda$\")\n", - "ax.set_ylabel(r\"$D_{last-first}$ (kcal/mol)\")\n", - "plt.savefig(f\"{path}FEP_perLambda_convergence.pdf\")" + "genfig = None\n", + "for key, feprun in fepruns.items():\n", + " if genfig is None:\n", + " genfig, genaxes = safep.plot_general(feprun.cumulative, None, feprun.perWindow, None, RT, hysttype='lines', label=key, color=feprun.color)\n", + " else:\n", + " genfig, genaxes = safep.plot_general(feprun.cumulative, None, feprun.perWindow, None, RT, fig=genfig, axes=genaxes, hysttype='lines', label=key, color=feprun.color)\n", + "plt.delaxes(genaxes[0])\n", + "plt.delaxes(genaxes[1])\n", + "\n", + "genaxes[3] = do_agg_data(axes[2], axes[3])\n", + "genaxes[2].set_title(str(mean)+r'$\\pm$'+str(sterr)+' kcal/mol')\n", + "\n", + "for txt in genfig.texts:\n", + " print(1)\n", + " txt.set_visible(False)\n", + " txt.set_text(\"\")\n", + "plt.show()\n", + "plt.savefig(dataroot.joinpath('FEP_perLambda_convergence.pdf'))" ] }, { "cell_type": "code", "execution_count": null, - "id": "b237a8ee-f33e-4c05-b4ed-856bb4f34ca2", + "id": "96418460", "metadata": {}, "outputs": [], "source": [] @@ -184,7 +347,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.12" + "version": "3.12.3" } }, "nbformat": 4, diff --git a/Sample_Notebooks/Replica1 b/Sample_Notebooks/Replica1 new file mode 120000 index 0000000..df9f7a8 --- /dev/null +++ b/Sample_Notebooks/Replica1 @@ -0,0 +1 @@ +Sample_Data \ No newline at end of file diff --git a/Sample_Notebooks/Replica2 b/Sample_Notebooks/Replica2 new file mode 120000 index 0000000..df9f7a8 --- /dev/null +++ 
b/Sample_Notebooks/Replica2 @@ -0,0 +1 @@ +Sample_Data \ No newline at end of file diff --git a/Sample_Notebooks/Replica3 b/Sample_Notebooks/Replica3 new file mode 120000 index 0000000..df9f7a8 --- /dev/null +++ b/Sample_Notebooks/Replica3 @@ -0,0 +1 @@ +Sample_Data \ No newline at end of file diff --git a/safep/plotting.py b/safep/plotting.py index e377f19..a3471cf 100644 --- a/safep/plotting.py +++ b/safep/plotting.py @@ -187,9 +187,9 @@ def plot_general(cumulative, # Per-window change in kcal/mol if errorbars: - each_ax.errorbar(per_window.index, per_window.BAR.df*RT, yerr=per_window.BAR.ddf, marker=None, linewidth=1, color=color) - each_ax.errorbar(per_window.index, -per_window.EXP.dG_b*RT, marker=None, linewidth=1, alpha=0.5, linestyle='--', color=color) - each_ax.plot(per_window.index, per_window.EXP.dG_f*RT, marker=None, linewidth=1, alpha=0.5, color=color) + each_ax.errorbar(per_window.index, per_window.BAR.df*RT, yerr=per_window.BAR.ddf, marker=None, linewidth=1, alpha=0.5, color=color, label="forward EXP") + each_ax.errorbar(per_window.index, -per_window.EXP.dG_b*RT, marker=None, linewidth=1, alpha=0.5, linestyle='--', color=color, label="backward EXP") + each_ax.plot(per_window.index, per_window.EXP.dG_f*RT, marker=None, linewidth=1, color=color, label="BAR") each_ax.set(ylabel=r'$\mathrm{\Delta} G_\lambda$'+'\n'+r'$\left(kcal/mol\right)$', ylim=per_window_ylim) diff --git a/safep/processing.py b/safep/processing.py index 1c6c58a..c9222d6 100644 --- a/safep/processing.py +++ b/safep/processing.py @@ -195,12 +195,13 @@ def alt_convergence(u_nk, nbins): return np.array(forward), np.array(forward_error) -def do_convergence(u_nk, tau=1, num_points=10): +def do_convergence(u_nk, tau=1, num_points=10, estimator='BAR'): """ Convergence calculation. Incrementally adds data from either the start or the end of each windows simulation and calculates the resulting change in free energy. 
Arguments: u_nk, tau (an error scaling factor), num_points (number of chunks) Returns: forward-sampled estimate (starting from t=start), forward-sampled error, backward-sampled estimate (from t=end), backward-sampled error """ + assert estimator in ['BAR', 'EXP'], "ERROR: I only know BAR and EXP estimators." groups = u_nk.groupby("fep-lambda") forward = [] @@ -210,18 +211,33 @@ def do_convergence(u_nk, tau=1, num_points=10): for i in range(1, num_points + 1): # forward partial = subsample(groups, 0, 100 * i / num_points) - estimate = BAR().fit(partial) - l, l_mid, f, df, ddf, errors = get_BAR(estimate) - - forward.append(f.iloc[-1]) - forward_error.append(errors[-1]) - + if estimator=='BAR': + estimate = BAR().fit(partial) + l, l_mid, f, df, ddf, errors = get_BAR(estimate) + f_append = f.iloc[-1] + err_append = errors[-1] + else: + expl, expmid, dG_fs, dG_bs = get_exponential(partial) + f_append = np.average([dG_fs[-1], dG_bs[-1]]) + err_append = np.std([dG_fs[-1], dG_bs[-1]]) + + forward.append(f_append) + forward_error.append(err_append) + + # backward partial = subsample(groups, 100 * (1 - i / num_points), 100) - estimate = BAR().fit(partial) - l, l_mid, f, df, ddf, errors = get_BAR(estimate) + if estimator=='BAR': + estimate = BAR().fit(partial) + l, l_mid, f, df, ddf, errors = get_BAR(estimate) + f_append = f.iloc[-1] + err_append = errors[-1] + else: + expl, expmid, dG_fs, dG_bs = get_exponential(partial) + f_append = np.average([dG_fs[-1], dG_bs[-1]]) + err_append = np.std([dG_fs[-1], dG_bs[-1]]) - backward.append(f.iloc[-1]) - backward_error.append(errors[-1]) + backward.append(f_append) + backward_error.append(err_append) return ( np.array(forward),