Skip to content

Commit

Permalink
[WIP] Add chapter on results / feature importances / application stud…
Browse files Browse the repository at this point in the history
…y 👨‍💻 (#403)
  • Loading branch information
KarelZe authored Jun 11, 2023
1 parent 508cded commit 41e5c04
Show file tree
Hide file tree
Showing 33 changed files with 2,744 additions and 902 deletions.
385 changes: 297 additions & 88 deletions notebooks/4.0c-mb-feature-importances.ipynb

Large diffs are not rendered by default.

1,138 changes: 1,138 additions & 0 deletions notebooks/4.0f-mb-results-own-rule.ipynb

Large diffs are not rendered by default.

161 changes: 152 additions & 9 deletions notebooks/6.0a-mb-visualizations.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
"import optuna\n",
"import wandb\n",
"\n",
"import seaborn as sns\n",
"\n",
"os.environ[\"GCLOUD_PROJECT\"] = \"flowing-mantis-239216\""
]
},
Expand Down Expand Up @@ -72,6 +74,137 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Confusion Matrices"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"\n",
"def plt_cm(pos, clf, clf_name, cbar_ax=None):\n",
" # https://medium.com/@dtuk81/confusion-matrix-visualization-fc31e3f30fea\n",
" \n",
" cf_matrix = confusion_matrices_cboe.iloc[clf].values[0]\n",
" \n",
" labels_ax = [\"-1 (Sell)\", \"1 (Buy)\"]\n",
" group_names = [\"True Neg\",\"False Pos\",\"False Neg\",\"True Pos\"]\n",
" group_counts = [\"{0:,}\".format(value) for value in\n",
" cf_matrix.flatten()]\n",
"\n",
" # https://github.com/scikit-learn/scikit-learn/blob/364c77e04/sklearn/metrics/_classification.py#L232\n",
" perc = cf_matrix/cf_matrix.sum(axis=1, keepdims=True)\n",
"\n",
" group_percentages = [f\"{value*100:.2f}\\%\" for value in perc.flatten()]\n",
" labels = [f\"{v3} \\n ({v2})\" for v1, v2, v3 in\n",
" zip(group_names,group_counts,group_percentages)]\n",
"\n",
" labels = np.asarray(labels).reshape(2,2)\n",
"\n",
" norm = plt.Normalize(0,1)\n",
"\n",
" s = sns.heatmap(perc, annot=labels, fmt=\"\", cmap='Blues', xticklabels=labels_ax, yticklabels=labels_ax, norm=norm, ax=ax[pos], cbar=cbar_ax is not None, annot_kws={\"fontsize\":8}, square=True, vmin=0, vmax=1,cbar_ax=cbar_ax)\n",
" # ax[pos].set_xlabel('Predicted Label')\n",
" # ax[pos].set_ylabel('True Label')\n",
" s.set_title(clf_name)\n",
" s.set(xlabel=\"\", ylabel=\"\")\n",
" \n",
" \n",
" # s.xaxis.tick_bottom()\n",
" \n",
" return s\n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"exchange = \"ise\"\n",
"mode = \"supervised\"\n",
"\n",
"confusion_matrices_cboe = pd.read_pickle(f\"gs://thesis-bucket-option-trade-classification/data/results/{exchange}_{mode}_test-confusion-matrices.pickle\")\n",
"\n",
"fig, ax = plt.subplots(2,3,figsize=(14*CM,10*CM), sharey=True, sharex=True, tight_layout=True)\n",
"cbar_ax = fig.add_axes([0.97, .3, .03, .4])\n",
"# cbar_ax.xaxis.set_major_formatter(PercentFormatter(100.0, 2))\n",
"\n",
"plt_cm((0,0), 0, \"Transformer (FS 1)\" , cbar_ax)\n",
"# if cbar_ax is not None:\n",
"ax[(0,0)].collections[0].colorbar.ax.yaxis.set_major_formatter(PercentFormatter(1, 0))\n",
"\n",
"plt_cm((0, 1), 1, \"Transformer (FS 2)\" )\n",
"plt_cm((0, 2), 2, \"Transformer (FS 3)\" )\n",
"\n",
"plt_cm((1,0), 3, \"GBRT (FS 1)\" )\n",
"plt_cm((1, 1), 4, \"GBRT (FS 2)\" )\n",
"plt_cm((1, 2), 5, \"GBRT (FS 3)\" )\n",
"\n",
"# plt.yaxis(\"Predicted Label\")\n",
"# plt.xlabel(\"True Label\")\n",
"\n",
"plt.tight_layout()\n",
"\n",
"# ax[(0,0)].cax.colorbar(s)\n",
"# ax[(0,0)].cax.toggle_label(True)\n",
"\n",
"# fig.subplots_adjust(right=0.8)\n",
"# cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])\n",
"# fig.colorbar(s, cax=cbar_ax)\n",
"fig.supxlabel('Predicted Label')\n",
"fig.supylabel('True Label')\n",
"\n",
"plt.tight_layout(pad=0.7)\n",
"plt.savefig(f\"../reports/Graphs/confusion_matrix_{exchange}.pdf\", bbox_inches=\"tight\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"cf_matrix = confusion_matrix(X_print[\"buy_sell\"].astype(\"int8\"), X_print[(\"fttransformer\", \"fttransformer(classical)\")], labels=[-1,1])\n",
"\n",
"# https://medium.com/@dtuk81/confusion-matrix-visualization-fc31e3f30fea\n",
"labels_ax = [\"-1 (sell)\", \"1 (buy)\"]\n",
"group_names = [\"True Neg\",\"False Pos\",\"False Neg\",\"True Pos\"]\n",
"group_counts = [\"{0:,}\".format(value) for value in\n",
" cf_matrix.flatten()]\n",
"\n",
"# https://github.com/scikit-learn/scikit-learn/blob/364c77e04/sklearn/metrics/_classification.py#L232\n",
"perc = cf_matrix/cf_matrix.sum(axis=1, keepdims=True)\n",
"\n",
"group_percentages = [\"{0:.2%}\".format(value) for value in perc.flatten()]\n",
"labels = [f\"{v3} \\n ({v2})\" for v1, v2, v3 in\n",
" zip(group_names,group_counts,group_percentages)]\n",
"\n",
"labels = np.asarray(labels).reshape(2,2)\n",
"\n",
"norm = plt.Normalize(0,np.max(perc))\n",
"\n",
"sns.heatmap(perc, annot=labels, fmt=\"\", cmap='Blues', xticklabels=labels_ax, yticklabels=labels_ax, norm=norm)\n",
"plt.xlabel('Predicted Label')\n",
"plt.ylabel('True Label')\n",
"plt.title(\"Transformer (FS Classical)\")\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
Expand Down Expand Up @@ -103,17 +236,17 @@
"\n",
"ax[0].plot(accuracies_over_time_ise[\"tick(all)\"], label=\"$\\operatorname{tick}_{\\mathrm{all}}$\", lw=1)\n",
"ax[0].plot(accuracies_over_time_ise[\"quote(best)\"], label=\"$\\operatorname{quote}_{\\mathrm{nbbo}}$\", lw=1, zorder=20)\n",
"ax[0].plot(accuracies_over_time_ise[\"quote(best)->quote(ex)->rev_tick(all)\"], label=r\"$\\operatorname{quote}_{\\mathrm{nbbo}} \\to \\operatorname{quote}_{\\mathrm{ex}} \\to \\operatorname{rtick}_{\\mathrm{all}}$\", lw=1, zorder=50)\n",
"ax[0].plot(accuracies_over_time_ise[\"trade_size(ex)->quote(best)->quote(ex)->depth(best)->depth(ex)->rev_tick(all)\"], label=r\"$\\operatorname{gsu}$\", lw=1, zorder=100)\n",
"ax[0].plot(accuracies_over_time_ise[\"quote(best)->quote(ex)->rev_tick(all)\"], label=r\"$\\operatorname{gsu}_{\\mathrm{small}}$\", lw=1, zorder=50)\n",
"ax[0].plot(accuracies_over_time_ise[\"trade_size(ex)->quote(best)->quote(ex)->depth(best)->depth(ex)->rev_tick(all)\"], label=r\"$\\operatorname{gsu}_{\\mathrm{large}}$\", lw=1, zorder=100)\n",
"\n",
"ax[0].axvline(x=pd.Timestamp('2013-10-24'), linestyle='--', color='grey', linewidth=0.5)\n",
"ax[0].axvline(x=pd.Timestamp('2015-11-05'), linestyle='--', color='grey', linewidth=0.5)\n",
"\n",
"# ax[1].s\n",
"ax[1].plot(accuracies_over_time_cboe[\"tick(all)\"], label=\"$\\operatorname{tick}_{\\mathrm{all}}$\", lw=1)\n",
"ax[1].plot(accuracies_over_time_cboe[\"quote(best)\"], label=\"$\\operatorname{quote}_{\\mathrm{nbbo}}$\", lw=1, zorder=20)\n",
"ax[1].plot(accuracies_over_time_cboe[\"quote(best)->quote(ex)->rev_tick(all)\"], label=r\"$\\operatorname{quote}_{\\mathrm{nbbo}} \\to \\operatorname{quote}_{\\mathrm{ex}} \\to \\operatorname{rtick}_{\\mathrm{all}}$\", lw=1, zorder=50)\n",
"ax[1].plot(accuracies_over_time_cboe[\"trade_size(ex)->quote(best)->quote(ex)->depth(best)->depth(ex)->rev_tick(all)\"], label=r\"$\\operatorname{gsu}$\", lw=1, zorder=100)\n",
"ax[1].plot(accuracies_over_time_cboe[\"quote(best)->quote(ex)->rev_tick(all)\"], label=r\"$\\operatorname{gsu}_{\\mathrm{small}}$\", lw=1, zorder=50)\n",
"ax[1].plot(accuracies_over_time_cboe[\"trade_size(ex)->quote(best)->quote(ex)->depth(best)->depth(ex)->rev_tick(all)\"], label=r\"$\\operatorname{gsu}_{\\mathrm{large}}$\", lw=1, zorder=100)\n",
"\n",
"ax[1].axvline(x=pd.Timestamp('2015-11-05'), linestyle='--', color='grey', linewidth=0.5)\n",
"\n",
Expand All @@ -134,7 +267,7 @@
"\n",
"handles, labels = ax[1].get_legend_handles_labels()\n",
"order = [0, 1, 2, 3]\n",
"ax[1].legend([handles[idx] for idx in order],[labels[idx] for idx in order], frameon=False, loc=\"lower center\", ncols=2, bbox_to_anchor=(0.5, -1))\n",
"ax[1].legend([handles[idx] for idx in order],[labels[idx] for idx in order], frameon=False, loc=\"lower center\", ncols=4, bbox_to_anchor=(0.5, -0.5))\n",
"\n",
"ax[0].set_title('ISE')\n",
"ax[1].set_title('CBOE')\n",
Expand All @@ -148,10 +281,11 @@
"# plt.ylabel(\"Accuracy\")\n",
"\n",
"plt.tight_layout()\n",
"plt.savefig(\"../reports/Graphs/accuracies_over_time.pdf\", bbox_inches=\"tight\")"
"plt.savefig(\"../reports/Graphs/classical_accuracies_over_time.pdf\", bbox_inches=\"tight\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
Expand Down Expand Up @@ -200,6 +334,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "kNTG2a_kf5gS"
Expand Down Expand Up @@ -251,6 +386,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "vVE2JK9Af5gW"
Expand Down Expand Up @@ -291,6 +427,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "h9mAHJU1f5gX"
Expand Down Expand Up @@ -342,6 +479,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "rdBVk3fyf5gZ"
Expand Down Expand Up @@ -487,6 +625,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "SyA46Ie6f5gc"
Expand Down Expand Up @@ -593,6 +732,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "KLKHwCjOf5gg"
Expand Down Expand Up @@ -716,6 +856,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "jGL-HbYlf5gi"
Expand Down Expand Up @@ -1240,6 +1381,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "roRmlg_nf5gl"
Expand Down Expand Up @@ -1590,6 +1732,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "r5ZnoZIG26K_"
Expand Down Expand Up @@ -1701,9 +1844,9 @@
"provenance": []
},
"kernelspec": {
"display_name": "thesis",
"display_name": "Python 3",
"language": "python",
"name": "thesis"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -1715,7 +1858,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
"version": "3.9.4"
}
},
"nbformat": 4,
Expand Down
Loading

0 comments on commit 41e5c04

Please sign in to comment.