From fec64c0e1ff3fe828a1277fa5fc6b398fc097176 Mon Sep 17 00:00:00 2001 From: Egor Spirin Date: Mon, 8 Jul 2024 14:47:32 +0300 Subject: [PATCH] Update CI, fix black and isort issues --- .github/workflows/main.yaml | 6 +- plots/create_csvs/for_paper.ipynb | 105 +++++--- plots/fid_dispersion/fid_dispersion.ipynb | 58 ++--- .../diff_vae_for_paper.ipynb | 107 +++----- plots/grid_search/msg_data_analysis.ipynb | 241 +++++++++--------- pyproject.toml | 26 +- requirements.dev.txt | 36 +-- src/metr/finetune_ldm_decoder.py | 4 +- src/metr/metr_pp_eval_stable_sig.py | 4 +- src/metr/run_metr.py | 8 +- src/metr/run_metr_fid.py | 3 +- 11 files changed, 255 insertions(+), 343 deletions(-) diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index 29efb47..68c5d8f 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -13,12 +13,8 @@ jobs: cache: "pip" - name: "installation" run: | - pip install -r requirements.txt -r requirements.dev.txt + pip install -r requirements.dev.txt - name: "black" run: black . --check --diff --color - name: "isort" run: isort . --check --diff - - name: "mypy" - run: mypy - - name: "pytests" - run: pytest diff --git a/plots/create_csvs/for_paper.ipynb b/plots/create_csvs/for_paper.ipynb index 4aa94dd..aae20fe 100644 --- a/plots/create_csvs/for_paper.ipynb +++ b/plots/create_csvs/for_paper.ipynb @@ -14,6 +14,7 @@ "import os\n", "\n", "import matplotlib.pyplot as plt\n", + "\n", "sns.set_style(\"white\", {\"grid.color\": \".6\", \"grid.linestyle\": \":\"})" ] }, @@ -23,10 +24,10 @@ "metadata": {}, "outputs": [], "source": [ - "def get_runs_df_stable_sig(project, entity=\"jurujin\", runtime_limit=6*3600):\n", - " '''\n", + "def get_runs_df_stable_sig(project, entity=\"jurujin\", runtime_limit=6 * 3600):\n", + " \"\"\"\n", " Returns df with data from wandb project for stable-sig\n", - " '''\n", + " \"\"\"\n", " api = wandb.Api()\n", " runs = api.runs(entity + \"/\" + project)\n", "\n", @@ -36,7 +37,6 @@ " config_list.append({k: v for k, v in run.config.items() if not k.startswith(\"_\")})\n", " name_list.append(run.name)\n", "\n", - "\n", " summary_df = pd.DataFrame(summary_list)\n", " config_df = pd.DataFrame(config_list)\n", "\n", @@ -55,10 +55,11 @@ "source": [ "# Чтобы выбрать долгие FID запуски надо: df[\"_runtime\"] > 7200\n", "\n", - "def get_runs_df(project, entity=\"jurujin\", runtime_limit=6*3600, resolution=False):\n", - " '''\n", + "\n", + "def get_runs_df(project, entity=\"jurujin\", runtime_limit=6 * 3600, resolution=False):\n", + " \"\"\"\n", " Returns df with data from wandb project\n", - " '''\n", + " \"\"\"\n", " df = get_runs_df_stable_sig(project, entity, runtime_limit)\n", "\n", " if resolution:\n", @@ -75,67 +76,79 @@ "source": [ "detection_projects = [\n", " \"detect_msg_all_att_vae\",\n", - " \"detect_msg_all_att_no_vae\", \n", - "\n", + " \"detect_msg_all_att_no_vae\",\n", " # \"clip_different_msg\" # Testing CLIP quality for different message\n", "]\n", "\n", - "stable_signature_detection_projects = [\n", - " \"eval_stable_tree_all_attacks\"\n", - "]\n", + "stable_signature_detection_projects = [\"eval_stable_tree_all_attacks\"]\n", "\n", "fid_projects = [\n", " # \"fid_gt_msg_all_att_vae\",\n", " # \"fid_gt_msg_all_att_no_vae\",\n", - "\n", " # \"fid_gen_msg_all_att_vae\",\n", " # \"fid_gen_msg_all_att_no_vae\",\n", - "\n", " \"fid_gen_message_dependency\",\n", " \"fid_gt_message_dependency\",\n", - " \n", "]\n", "\n", "detection_cols = [\n", " \"name\",\n", - " \"TPR@1%FPR\", \"auc\", \"acc\",\n", - " \"Bit_acc\", \"Word_acc\",\n", + " \"TPR@1%FPR\",\n", + " \"auc\",\n", + " \"acc\",\n", + " \"Bit_acc\",\n", + " \"Word_acc\",\n", " \"det_resol\",\n", - "\n", " \"w_clip_score_mean\",\n", - "\n", " \"w_det_dist_mean\",\n", " \"no_w_det_dist_mean\",\n", - "\n", " \"w_det_dist_std\",\n", " \"no_w_det_dist_std\",\n", - " \n", - " \"msg\", \"w_radius\", \"msg_scaler\",\n", - "\n", - " \"jpeg_ratio\", \"crop_scale\", \"crop_ratio\", \"gaussian_blur_r\", \"gaussian_std\", \"brightness_factor\", \"r_degree\"\n", + " \"msg\",\n", + " \"w_radius\",\n", + " \"msg_scaler\",\n", + " \"jpeg_ratio\",\n", + " \"crop_scale\",\n", + " \"crop_ratio\",\n", + " \"gaussian_blur_r\",\n", + " \"gaussian_std\",\n", + " \"brightness_factor\",\n", + " \"r_degree\",\n", "]\n", "\n", - "stable_signature_detection_cols = [\n", - " \"name\", \"Bit_acc\", \"Word_acc\"\n", - "]\n", + "stable_signature_detection_cols = [\"name\", \"Bit_acc\", \"Word_acc\"]\n", "\n", "fid_cols = [\n", " \"name\",\n", - " \"psnr_w\", \"ssim_w\",\n", - " \"psnr_no_w\", \"ssim_no_w\",\n", - " \"fid_w\", \"fid_no_w\",\n", - " \"msg\", \"w_radius\", \"msg_scaler\",\n", - "\n", + " \"psnr_w\",\n", + " \"ssim_w\",\n", + " \"psnr_no_w\",\n", + " \"ssim_no_w\",\n", + " \"fid_w\",\n", + " \"fid_no_w\",\n", + " \"msg\",\n", + " \"w_radius\",\n", + " \"msg_scaler\",\n", "]\n", "\n", "fid_att_cols = [\n", " \"name\",\n", - " \"psnr_w\", \"ssim_w\",\n", - " \"psnr_no_w\", \"ssim_no_w\",\n", - " \"fid_w\", \"fid_no_w\",\n", - " \"msg\", \"w_radius\", \"msg_scaler\",\n", - "\n", - " \"jpeg_ratio\", \"crop_scale\", \"crop_ratio\", \"gaussian_blur_r\", \"gaussian_std\", \"brightness_factor\", \"r_degree\"\n", + " \"psnr_w\",\n", + " \"ssim_w\",\n", + " \"psnr_no_w\",\n", + " \"ssim_no_w\",\n", + " \"fid_w\",\n", + " \"fid_no_w\",\n", + " \"msg\",\n", + " \"w_radius\",\n", + " \"msg_scaler\",\n", + " \"jpeg_ratio\",\n", + " \"crop_scale\",\n", + " \"crop_ratio\",\n", + " \"gaussian_blur_r\",\n", + " \"gaussian_std\",\n", + " \"brightness_factor\",\n", + " \"r_degree\",\n", "]" ] }, @@ -147,7 +160,9 @@ "source": [ "for project in detection_projects:\n", " os.makedirs(\"./detection\", exist_ok=True)\n", - " get_runs_df(project, resolution=True, runtime_limit=4 * 3600).to_csv(f\"./detection/{project}.csv\", index=False, columns=detection_cols)" + " get_runs_df(project, resolution=True, runtime_limit=4 * 3600).to_csv(\n", + " f\"./detection/{project}.csv\", index=False, columns=detection_cols\n", + " )" ] }, { @@ -158,7 +173,9 @@ "source": [ "for project in fid_projects:\n", " os.makedirs(\"./fid\", exist_ok=True)\n", - " get_runs_df(project).sort_values(by=\"name\", ascending=False).to_csv(f\"./fid/{project}.csv\", index=False, columns=fid_cols)" + " get_runs_df(project).sort_values(by=\"name\", ascending=False).to_csv(\n", + " f\"./fid/{project}.csv\", index=False, columns=fid_cols\n", + " )" ] }, { @@ -169,7 +186,9 @@ "source": [ "for project in stable_signature_detection_projects:\n", " os.makedirs(\"./detection\", exist_ok=True)\n", - " get_runs_df_stable_sig(project, runtime_limit=0).to_csv(f\"./detection/{project}.csv\", index=False, columns=stable_signature_detection_cols)" + " get_runs_df_stable_sig(project, runtime_limit=0).to_csv(\n", + " f\"./detection/{project}.csv\", index=False, columns=stable_signature_detection_cols\n", + " )" ] }, { @@ -180,7 +199,9 @@ "source": [ "clip_different_message = \"clip_different_msg\"\n", "os.makedirs(\"./detection\", exist_ok=True)\n", - "get_runs_df(clip_different_message, runtime_limit=0, resolution=True).sort_values(by=\"name\", ascending=False).to_csv(f\"./detection/{clip_different_message}.csv\", index=False, columns=detection_cols)" + "get_runs_df(clip_different_message, runtime_limit=0, resolution=True).sort_values(by=\"name\", ascending=False).to_csv(\n", + " f\"./detection/{clip_different_message}.csv\", index=False, columns=detection_cols\n", + ")" ] } ], diff --git a/plots/fid_dispersion/fid_dispersion.ipynb b/plots/fid_dispersion/fid_dispersion.ipynb index 123d67c..554eca0 100644 --- a/plots/fid_dispersion/fid_dispersion.ipynb +++ b/plots/fid_dispersion/fid_dispersion.ipynb @@ -13,6 +13,7 @@ "import numpy as np\n", "\n", "import matplotlib.pyplot as plt\n", + "\n", "sns.set_style(\"white\", {\"grid.color\": \".6\", \"grid.linestyle\": \":\"})" ] }, @@ -22,10 +23,10 @@ "metadata": {}, "outputs": [], "source": [ - "def get_runs_df_stable_sig(project, entity=\"jurujin\", runtime_limit=6*3600):\n", - " '''\n", + "def get_runs_df_stable_sig(project, entity=\"jurujin\", runtime_limit=6 * 3600):\n", + " \"\"\"\n", " Returns df with data from wandb project for stable-sig\n", - " '''\n", + " \"\"\"\n", " api = wandb.Api()\n", " runs = api.runs(entity + \"/\" + project)\n", "\n", @@ -35,7 +36,6 @@ " config_list.append({k: v for k, v in run.config.items() if not k.startswith(\"_\")})\n", " name_list.append(run.name)\n", "\n", - "\n", " summary_df = pd.DataFrame(summary_list)\n", " config_df = pd.DataFrame(config_list)\n", "\n", @@ -45,12 +45,14 @@ "\n", " return df\n", "\n", + "\n", "# Чтобы выбрать долгие FID запуски надо: df[\"_runtime\"] > 7200\n", "\n", - "def get_runs_df(project, entity=\"jurujin\", runtime_limit=6*3600, resolution=False):\n", - " '''\n", + "\n", + "def get_runs_df(project, entity=\"jurujin\", runtime_limit=6 * 3600, resolution=False):\n", + " \"\"\"\n", " Returns df with data from wandb project\n", - " '''\n", + " \"\"\"\n", " df = get_runs_df_stable_sig(project, entity, runtime_limit)\n", "\n", " if resolution:\n", @@ -69,24 +71,15 @@ " # Все месседжи кроме 00..0 для всех S кроме 100\n", " \"fid_gеn_message_s\",\n", " \"fid_gt_message_s\",\n", - "\n", " # Все месседжи для S=100\n", " \"fid_gen_message_dependency\",\n", " \"fid_gt_message_dependency\",\n", - "\n", " # Месседжи 00..0 кроме S=100\n", " \"worst_message_fid_gen\",\n", - " \"worst_message_fid_gt\"\n", + " \"worst_message_fid_gt\",\n", "]\n", "\n", - "fid_cols = [\n", - " \"name\",\n", - " \"msg\",\n", - " \"msg_scaler\",\n", - " \"fid_w\",\n", - " \"target_clean_generated\"\n", - "\n", - "]" + "fid_cols = [\"name\", \"msg\", \"msg_scaler\", \"fid_w\", \"target_clean_generated\"]" ] }, { @@ -168,12 +161,9 @@ " stds = [df[df[\"msg_scaler\"] == S][\"fid_w\"].std() for S in range(60, 110, 10)]\n", " means = [df[df[\"msg_scaler\"] == S][\"fid_w\"].mean() for S in range(60, 110, 10)]\n", "\n", - " stats_dict = {\n", - " \"mean\":means,\n", - " \"std\":stds\n", - " }\n", - " \n", - " return pd.DataFrame(data=stats_dict, index=range(60, 110, 10))\n" + " stats_dict = {\"mean\": means, \"std\": stds}\n", + "\n", + " return pd.DataFrame(data=stats_dict, index=range(60, 110, 10))" ] }, { @@ -200,7 +190,7 @@ "metadata": {}, "outputs": [], "source": [ - "markersize=1\n", + "markersize = 1\n", "ticks_font = 16\n", "label_font = 20\n", "legend_font = 20" @@ -226,8 +216,10 @@ "fig, axs = plt.subplots(2, 1, figsize=(8, 6))\n", "\n", "S = range(60, 110, 10)\n", - "axs[0].errorbar(S, gt_stat[\"mean\"], yerr=gt_stat[\"std\"], label='FID gt', fmt='-o', markersize=markersize, c=\"#1f77b4\")\n", - "axs[1].errorbar(S, gen_stat[\"mean\"], yerr=gen_stat[\"std\"], label='FID gen', fmt='-o', markersize=markersize, c=\"#ff7f0e\")\n", + "axs[0].errorbar(S, gt_stat[\"mean\"], yerr=gt_stat[\"std\"], label=\"FID gt\", fmt=\"-o\", markersize=markersize, c=\"#1f77b4\")\n", + "axs[1].errorbar(\n", + " S, gen_stat[\"mean\"], yerr=gen_stat[\"std\"], label=\"FID gen\", fmt=\"-o\", markersize=markersize, c=\"#ff7f0e\"\n", + ")\n", "\n", "axs[0].set_ylim(25, 30.5)\n", "axs[0].set_yticks(np.arange(25, 31, 1))\n", @@ -236,19 +228,17 @@ "axs[1].set_yticks(np.arange(9, 19, 2))\n", "\n", "for ax in axs:\n", - " ax.legend(loc=\"upper left\" ,fontsize=legend_font)\n", - " ax.patch.set_edgecolor('black') \n", + " ax.legend(loc=\"upper left\", fontsize=legend_font)\n", + " ax.patch.set_edgecolor(\"black\")\n", " ax.patch.set_linewidth(1)\n", " ax.grid(alpha=0.5)\n", " ax.set_xlim(58, 102)\n", - " ax.set_xticks(\n", - " np.arange(60, 110, 10)\n", - " )\n", - " ax.tick_params(axis='both', which='major', labelsize=ticks_font)\n", + " ax.set_xticks(np.arange(60, 110, 10))\n", + " ax.tick_params(axis=\"both\", which=\"major\", labelsize=ticks_font)\n", "\n", "ax.set_xlabel(\"$S$ hyperparameter\", fontsize=label_font)\n", "\n", - "plt.savefig(\"fid_dispersion.png\", bbox_inches='tight')" + "plt.savefig(\"fid_dispersion.png\", bbox_inches=\"tight\")" ] } ], diff --git a/plots/generative_attacks/diff_vae_for_paper.ipynb b/plots/generative_attacks/diff_vae_for_paper.ipynb index e99fb39..d881d27 100644 --- a/plots/generative_attacks/diff_vae_for_paper.ipynb +++ b/plots/generative_attacks/diff_vae_for_paper.ipynb @@ -15,6 +15,7 @@ "import matplotlib.gridspec as gridspec\n", "\n", "import seaborn as sns\n", + "\n", "# sns.set_style(\"darkgrid\", {\"grid.color\": \".6\", \"grid.linestyle\": \":\"})" ] }, @@ -25,9 +26,9 @@ "outputs": [], "source": [ "def get_runs_df(project, entity=\"jurujin\"):\n", - " '''\n", + " \"\"\"\n", " Returns df with data from wandb project for stable-sig\n", - " '''\n", + " \"\"\"\n", " api = wandb.Api()\n", " runs = api.runs(entity + \"/\" + project)\n", "\n", @@ -37,7 +38,6 @@ " config_list.append({k: v for k, v in run.config.items() if not k.startswith(\"_\")})\n", " name_list.append(run.name)\n", "\n", - "\n", " summary_df = pd.DataFrame(summary_list)\n", " config_df = pd.DataFrame(config_list)\n", "\n", @@ -53,7 +53,7 @@ "metadata": {}, "outputs": [], "source": [ - "markersize=6\n", + "markersize = 6\n", "ticks_font = 20\n", "label_font = 26\n", "legend_font = 22" @@ -67,7 +67,11 @@ "source": [ "df_msg = get_runs_df(\"diff_attacks_metr\")\n", "\n", - "df_msg = df_msg.loc[:, [\"Word_acc\", \"Bit_acc\", \"diff_attack_steps\", '''TPR@1%FPR''', \"acc\", \"auc\"]].sort_values(by=\"diff_attack_steps\", ascending=False).iloc[1::2, :]" + "df_msg = (\n", + " df_msg.loc[:, [\"Word_acc\", \"Bit_acc\", \"diff_attack_steps\", \"\"\"TPR@1%FPR\"\"\", \"acc\", \"auc\"]]\n", + " .sort_values(by=\"diff_attack_steps\", ascending=False)\n", + " .iloc[1::2, :]\n", + ")" ] }, { @@ -81,7 +85,7 @@ "acc = df_msg[\"acc\"]\n", "auc = df_msg[\"auc\"]\n", "\n", - "detect_metrics = ['''TPR@1%FPR''', \"acc\", \"auc\"]" + "detect_metrics = [\"\"\"TPR@1%FPR\"\"\", \"acc\", \"auc\"]" ] }, { @@ -107,37 +111,27 @@ "for metric in detect_metrics:\n", " # plt.subplot2grid((4, 4), pos, rowspan=2, colspan=2)\n", "\n", - " ax.plot(\n", - " steps, df_msg[metric], '--o', label=metric, markersize=markersize\n", - " )\n", + " ax.plot(steps, df_msg[metric], \"--o\", label=metric, markersize=markersize)\n", "\n", - "ax.plot(\n", - " steps, df_msg[\"Bit_acc\"], '-o', label=\"Bit acc\", markersize=markersize, linewidth=2.5\n", - ")\n", + "ax.plot(steps, df_msg[\"Bit_acc\"], \"-o\", label=\"Bit acc\", markersize=markersize, linewidth=2.5)\n", "\n", - "ax.plot(\n", - " steps, df_msg[\"Word_acc\"], '-o', label=\"Word acc\", markersize=markersize, linewidth=2.5\n", - ")\n", + "ax.plot(steps, df_msg[\"Word_acc\"], \"-o\", label=\"Word acc\", markersize=markersize, linewidth=2.5)\n", "\n", "\n", - "plt.xticks(\n", - " np.arange(0, 600, 100), fontsize=ticks_font\n", - ")\n", - "plt.yticks(\n", - " np.linspace(0, 1, 11), fontsize=ticks_font\n", - ")\n", + "plt.xticks(np.arange(0, 600, 100), fontsize=ticks_font)\n", + "plt.yticks(np.linspace(0, 1, 11), fontsize=ticks_font)\n", "\n", "plt.ylim(0.0, 1.03)\n", "\n", "ax.set_xlabel(\"steps\", fontsize=label_font)\n", "ax.legend(fontsize=legend_font)\n", - "ax.patch.set_edgecolor('black') \n", + "ax.patch.set_edgecolor(\"black\")\n", "ax.patch.set_linewidth(1)\n", "\n", "ax.grid(alpha=0.5)\n", "# --------\n", "\n", - "plt.savefig(\"detect_diff_metr.png\", bbox_inches='tight')\n", + "plt.savefig(\"detect_diff_metr.png\", bbox_inches=\"tight\")\n", "None" ] }, @@ -147,7 +141,7 @@ "metadata": {}, "outputs": [], "source": [ - "df_detect = get_runs_df(\"new_big_diff_attacks\")[['''TPR@1%FPR''', \"acc\", \"auc\"]].iloc[1::2, :]" + "df_detect = get_runs_df(\"new_big_diff_attacks\")[[\"\"\"TPR@1%FPR\"\"\", \"acc\", \"auc\"]].iloc[1::2, :]" ] }, { @@ -167,16 +161,14 @@ } ], "source": [ - "detect_metrics = ['''TPR@1%FPR''', \"acc\", \"auc\"]\n", + "detect_metrics = [\"\"\"TPR@1%FPR\"\"\", \"acc\", \"auc\"]\n", "\n", "fig, ax = plt.subplots(figsize=(8, 6))\n", "\n", "\n", "for metric in detect_metrics:\n", "\n", - " ax.plot(\n", - " steps, df_detect[metric], '--o', label=metric, markersize=markersize\n", - " )\n", + " ax.plot(steps, df_detect[metric], \"--o\", label=metric, markersize=markersize)\n", "\n", "# ax.plot(\n", "# steps, df_msg[\"Bit_acc\"], '-o', label=\"Bit acc\", markersize=markersize, linewidth=2.5\n", @@ -187,25 +179,20 @@ "# )\n", "\n", "\n", - "\n", - "plt.xticks(\n", - " np.arange(0, 600, 100), fontsize=ticks_font\n", - ")\n", - "plt.yticks(\n", - " np.linspace(0, 1, 11), fontsize=ticks_font\n", - ")\n", + "plt.xticks(np.arange(0, 600, 100), fontsize=ticks_font)\n", + "plt.yticks(np.linspace(0, 1, 11), fontsize=ticks_font)\n", "\n", "plt.ylim(0.0, 1.03)\n", "\n", "ax.set_xlabel(\"steps\", fontsize=label_font)\n", "ax.legend(fontsize=legend_font)\n", - "ax.patch.set_edgecolor('black') \n", + "ax.patch.set_edgecolor(\"black\")\n", "ax.patch.set_linewidth(1)\n", "\n", "ax.grid(alpha=0.5)\n", "# --------\n", "\n", - "plt.savefig(\"detect_diff_treering_metr.png\", bbox_inches='tight')\n", + "plt.savefig(\"detect_diff_treering_metr.png\", bbox_inches=\"tight\")\n", "None" ] }, @@ -225,7 +212,9 @@ "source": [ "df_msg = get_runs_df(\"vae_attacks_metr\")\n", "\n", - "df_msg = df_msg.loc[:, [\"Word_acc\", \"Bit_acc\", \"vae_attack_quality\", '''TPR@1%FPR''', \"acc\", \"auc\"]].sort_values(by=\"vae_attack_quality\", ascending=False)" + "df_msg = df_msg.loc[:, [\"Word_acc\", \"Bit_acc\", \"vae_attack_quality\", \"\"\"TPR@1%FPR\"\"\", \"acc\", \"auc\"]].sort_values(\n", + " by=\"vae_attack_quality\", ascending=False\n", + ")" ] }, { @@ -239,7 +228,7 @@ "acc = df_msg[\"acc\"]\n", "auc = df_msg[\"auc\"]\n", "\n", - "detect_metrics = ['''TPR@1%FPR''', \"acc\", \"auc\"]" + "detect_metrics = [\"\"\"TPR@1%FPR\"\"\", \"acc\", \"auc\"]" ] }, { @@ -263,25 +252,15 @@ "\n", "for metric in detect_metrics:\n", "\n", - " ax.plot(\n", - " quality, df_msg[metric], '--o', label=metric, markersize=markersize\n", - " )\n", + " ax.plot(quality, df_msg[metric], \"--o\", label=metric, markersize=markersize)\n", "\n", - "ax.plot(\n", - " quality, df_msg[\"Bit_acc\"], '-o', label=\"Bit acc\", markersize=markersize, linewidth=2.5\n", - ")\n", + "ax.plot(quality, df_msg[\"Bit_acc\"], \"-o\", label=\"Bit acc\", markersize=markersize, linewidth=2.5)\n", "\n", - "ax.plot(\n", - " quality, df_msg[\"Word_acc\"], '-o', label=\"Word acc\", markersize=markersize, linewidth=2.5\n", - ")\n", + "ax.plot(quality, df_msg[\"Word_acc\"], \"-o\", label=\"Word acc\", markersize=markersize, linewidth=2.5)\n", "\n", "\n", - "plt.xticks(\n", - " np.linspace(1, 8, 8), fontsize=ticks_font\n", - ")\n", - "plt.yticks(\n", - " np.linspace(0.5, 1., 6), fontsize=ticks_font\n", - ")\n", + "plt.xticks(np.linspace(1, 8, 8), fontsize=ticks_font)\n", + "plt.yticks(np.linspace(0.5, 1.0, 6), fontsize=ticks_font)\n", "\n", "ax.grid(alpha=0.5)\n", "\n", @@ -290,7 +269,7 @@ "plt.legend(fontsize=legend_font)\n", "# --------\n", "\n", - "plt.savefig(\"detect_vae_metr.png\", bbox_inches='tight')" + "plt.savefig(\"detect_vae_metr.png\", bbox_inches=\"tight\")" ] }, { @@ -299,7 +278,7 @@ "metadata": {}, "outputs": [], "source": [ - "df_detect = get_runs_df(\"vae_2018_attacks\")[['''TPR@1%FPR''', \"acc\", \"auc\"]]" + "df_detect = get_runs_df(\"vae_2018_attacks\")[[\"\"\"TPR@1%FPR\"\"\", \"acc\", \"auc\"]]" ] }, { @@ -308,7 +287,7 @@ "metadata": {}, "outputs": [], "source": [ - "detect_metrics = ['''TPR@1%FPR''', \"acc\", \"auc\"]" + "detect_metrics = [\"\"\"TPR@1%FPR\"\"\", \"acc\", \"auc\"]" ] }, { @@ -332,9 +311,7 @@ "\n", "for metric in detect_metrics:\n", "\n", - " ax.plot(\n", - " quality, df_detect[metric], '--o', label=metric, markersize=markersize\n", - " )\n", + " ax.plot(quality, df_detect[metric], \"--o\", label=metric, markersize=markersize)\n", "\n", "# ax.plot(\n", "# quality, df_msg[\"Bit_acc\"], '-o', label=\"Bit acc\", markersize=markersize, linewidth=2.5\n", @@ -345,12 +322,8 @@ "# )\n", "\n", "\n", - "plt.xticks(\n", - " np.linspace(1, 8, 8), fontsize=ticks_font\n", - ")\n", - "plt.yticks(\n", - " np.linspace(0.5, 1., 6), fontsize=ticks_font\n", - ")\n", + "plt.xticks(np.linspace(1, 8, 8), fontsize=ticks_font)\n", + "plt.yticks(np.linspace(0.5, 1.0, 6), fontsize=ticks_font)\n", "\n", "ax.grid(alpha=0.5)\n", "\n", @@ -359,7 +332,7 @@ "plt.legend(fontsize=legend_font)\n", "# --------\n", "\n", - "plt.savefig(\"detect_vae_treering.png\", bbox_inches='tight')" + "plt.savefig(\"detect_vae_treering.png\", bbox_inches=\"tight\")" ] } ], diff --git a/plots/grid_search/msg_data_analysis.ipynb b/plots/grid_search/msg_data_analysis.ipynb index f1838b7..6a9ed03 100644 --- a/plots/grid_search/msg_data_analysis.ipynb +++ b/plots/grid_search/msg_data_analysis.ipynb @@ -14,6 +14,7 @@ "import os\n", "\n", "import matplotlib.pyplot as plt\n", + "\n", "sns.set_style(\"white\", {\"grid.color\": \".6\", \"grid.linestyle\": \":\"})" ] }, @@ -23,7 +24,7 @@ "metadata": {}, "outputs": [], "source": [ - "pd.set_option('display.max_rows', 100)" + "pd.set_option(\"display.max_rows\", 100)" ] }, { @@ -32,14 +33,14 @@ "metadata": {}, "outputs": [], "source": [ - "from matplotlib.colors import LinearSegmentedColormap\n", + "from matplotlib.colors import LinearSegmentedColormap\n", "\n", - "c = [\"darkred\", \"firebrick\", \"red\", \"lightcoral\", \"palegreen\", \"lime\", \"green\",\"darkgreen\"]\n", + "c = [\"darkred\", \"firebrick\", \"red\", \"lightcoral\", \"palegreen\", \"lime\", \"green\", \"darkgreen\"]\n", "\n", - "v = [0,.15, 0.25,.4,0.65,0.85,.995,1.]\n", - "l = list(zip(v,c))\n", + "v = [0, 0.15, 0.25, 0.4, 0.65, 0.85, 0.995, 1.0]\n", + "l = list(zip(v, c))\n", "\n", - "CMAP = LinearSegmentedColormap.from_list('rg',l, N=256)" + "CMAP = LinearSegmentedColormap.from_list(\"rg\", l, N=256)" ] }, { @@ -57,10 +58,11 @@ "source": [ "# Чтобы выбрать долгие FID запуски надо: df[\"_runtime\"] > 7200\n", "\n", + "\n", "def get_runs_df(project, entity=\"jurujin\", resolution=False):\n", - " '''\n", + " \"\"\"\n", " Returns df with data from wandb project\n", - " '''\n", + " \"\"\"\n", " api = wandb.Api()\n", " runs = api.runs(entity + \"/\" + project)\n", "\n", @@ -70,7 +72,6 @@ " config_list.append({k: v for k, v in run.config.items() if not k.startswith(\"_\")})\n", " name_list.append(run.name)\n", "\n", - "\n", " summary_df = pd.DataFrame(summary_list)\n", " config_df = pd.DataFrame(config_list)\n", "\n", @@ -89,9 +90,9 @@ "outputs": [], "source": [ "def get_runs_df_stable_sig(project, entity=\"jurujin\"):\n", - " '''\n", + " \"\"\"\n", " Returns df with data from wandb project for stable-sig\n", - " '''\n", + " \"\"\"\n", " api = wandb.Api()\n", " runs = api.runs(entity + \"/\" + project)\n", "\n", @@ -101,7 +102,6 @@ " config_list.append({k: v for k, v in run.config.items() if not k.startswith(\"_\")})\n", " name_list.append(run.name)\n", "\n", - "\n", " summary_df = pd.DataFrame(summary_list)\n", " config_df = pd.DataFrame(config_list)\n", "\n", @@ -119,79 +119,89 @@ "source": [ "detection_projects = [\n", " \"msg_long_detect_no_att\",\n", - " \"msg_long_diff_att\", # V\n", - "\n", - " \"msg_grid_srch_vae\", # V\n", + " \"msg_long_diff_att\", # V\n", + " \"msg_grid_srch_vae\", # V\n", " \"msg_grid_srch_no_vae\",\n", - "\n", " \"detect_msg_all_att_vae\",\n", " \"detect_msg_all_att_no_vae\",\n", "]\n", "\n", - "stable_signature_detection_projects = [\n", - " \"eval_stable_tree_all_attacks\"\n", - "]\n", + "stable_signature_detection_projects = [\"eval_stable_tree_all_attacks\"]\n", "\n", "fid_projects = [\n", " \"fid_msg_grid_srch_gen_vae\",\n", - " \"fid_msg_grid_srch_gt_vae\", \n", - "\n", + " \"fid_msg_grid_srch_gt_vae\",\n", " \"fid_msg_grid_srch_gen_no_vae\",\n", - " \"fid_msg_grid_srch_gt_no_vae\", \n", - "\n", + " \"fid_msg_grid_srch_gt_no_vae\",\n", " \"fid_msg_r_gen\",\n", - " \"fid_msg_r_gt\"\n", - "\n", + " \"fid_msg_r_gt\",\n", "]\n", "\n", "fid_att_projects = [\n", " \"fid_gt_msg_all_att_vae\",\n", " \"fid_gt_msg_all_att_no_vae\",\n", - "\n", " \"fid_gen_msg_all_att_vae\",\n", - " \"fid_gen_msg_all_att_no_vae\"\n", + " \"fid_gen_msg_all_att_no_vae\",\n", "]\n", "\n", "detection_cols = [\n", " \"name\",\n", - " \"TPR@1%FPR\", \"acc\", \"auc\",\n", - " \"Bit_acc\", \"Word_acc\",\n", + " \"TPR@1%FPR\",\n", + " \"acc\",\n", + " \"auc\",\n", + " \"Bit_acc\",\n", + " \"Word_acc\",\n", " \"det_resol\",\n", - "\n", " \"w_clip_score_mean\",\n", - "\n", " \"w_det_dist_mean\",\n", " \"no_w_det_dist_mean\",\n", - "\n", " \"w_det_dist_std\",\n", " \"no_w_det_dist_std\",\n", - " \n", - " \"msg\", \"w_radius\", \"msg_scaler\",\n", - "\n", - " \"jpeg_ratio\", \"crop_scale\", \"crop_ratio\", \"gaussian_blur_r\", \"gaussian_std\", \"brightness_factor\", \"r_degree\"\n", + " \"msg\",\n", + " \"w_radius\",\n", + " \"msg_scaler\",\n", + " \"jpeg_ratio\",\n", + " \"crop_scale\",\n", + " \"crop_ratio\",\n", + " \"gaussian_blur_r\",\n", + " \"gaussian_std\",\n", + " \"brightness_factor\",\n", + " \"r_degree\",\n", "]\n", "\n", - "stable_signature_detection_cols = [\n", - " \"Bit_acc\", \"Word_acc\"\n", - "]\n", + "stable_signature_detection_cols = [\"Bit_acc\", \"Word_acc\"]\n", "\n", "fid_cols = [\n", " \"name\",\n", - " \"psnr_w\", \"ssim_w\",\n", - " \"psnr_no_w\", \"ssim_no_w\",\n", - " \"fid_w\", \"fid_no_w\",\n", - " \"msg\", \"w_radius\", \"msg_scaler\",\n", - "\n", + " \"psnr_w\",\n", + " \"ssim_w\",\n", + " \"psnr_no_w\",\n", + " \"ssim_no_w\",\n", + " \"fid_w\",\n", + " \"fid_no_w\",\n", + " \"msg\",\n", + " \"w_radius\",\n", + " \"msg_scaler\",\n", "]\n", "\n", "fid_att_cols = [\n", " \"name\",\n", - " \"psnr_w\", \"ssim_w\",\n", - " \"psnr_no_w\", \"ssim_no_w\",\n", - " \"fid_w\", \"fid_no_w\",\n", - " \"msg\", \"w_radius\", \"msg_scaler\",\n", - "\n", - " \"jpeg_ratio\", \"crop_scale\", \"crop_ratio\", \"gaussian_blur_r\", \"gaussian_std\", \"brightness_factor\", \"r_degree\"\n", + " \"psnr_w\",\n", + " \"ssim_w\",\n", + " \"psnr_no_w\",\n", + " \"ssim_no_w\",\n", + " \"fid_w\",\n", + " \"fid_no_w\",\n", + " \"msg\",\n", + " \"w_radius\",\n", + " \"msg_scaler\",\n", + " \"jpeg_ratio\",\n", + " \"crop_scale\",\n", + " \"crop_ratio\",\n", + " \"gaussian_blur_r\",\n", + " \"gaussian_std\",\n", + " \"brightness_factor\",\n", + " \"r_degree\",\n", "]" ] }, @@ -216,12 +226,16 @@ "\n", "for project in fid_projects:\n", " get_runs_df(project).to_csv(f\"./fid/{project}.csv\", index=False, columns=fid_cols)\n", - " \n", + "\n", "for project in fid_att_projects:\n", - " get_runs_df(project).sort_values(by=[\"jpeg_ratio\"], na_position='first').to_csv(f\"./fid/{project}.csv\", index=False, columns=fid_att_cols)\n", + " get_runs_df(project).sort_values(by=[\"jpeg_ratio\"], na_position=\"first\").to_csv(\n", + " f\"./fid/{project}.csv\", index=False, columns=fid_att_cols\n", + " )\n", "\n", "for project in stable_signature_detection_projects:\n", - " get_runs_df_stable_sig(project).to_csv(f\"./detection/{project}.csv\", index=False, columns=stable_signature_detection_cols)" + " get_runs_df_stable_sig(project).to_csv(\n", + " f\"./detection/{project}.csv\", index=False, columns=stable_signature_detection_cols\n", + " )" ] }, { @@ -247,11 +261,11 @@ " azim=20,\n", " elev=45,\n", " zoom=0.95,\n", - " reversed_cmap=False\n", - " ):\n", - " '''\n", + " reversed_cmap=False,\n", + "):\n", + " \"\"\"\n", " str metric_name: Name of the target column\n", - " '''\n", + " \"\"\"\n", " if reversed_cmap:\n", " cmap = CMAP.reversed()\n", " else:\n", @@ -259,9 +273,9 @@ "\n", " # df = get_runs_df(project)\n", " if \"fid\" in project.split(\"_\"):\n", - " df = pd.read_csv(f'./fid/{project}.csv')\n", + " df = pd.read_csv(f\"./fid/{project}.csv\")\n", " else:\n", - " df = pd.read_csv(f'./detection/{project}.csv')\n", + " df = pd.read_csv(f\"./detection/{project}.csv\")\n", "\n", " scaler = df[\"msg_scaler\"]\n", " radius = df[\"w_radius\"]\n", @@ -270,7 +284,6 @@ " scaler_num = scaler.nunique()\n", " radius_num = radius.nunique()\n", "\n", - " \n", " fig, ax = plt.subplots(subplot_kw={\"projection\": \"3d\"})\n", " fig.set_size_inches(figsize, forward=True)\n", "\n", @@ -282,8 +295,8 @@ "\n", " surf = ax.plot_surface(X, Y, Z, cmap=cmap, alpha=0.8)\n", "\n", - " ax.set_xlabel('Radius', fontsize=12)\n", - " ax.set_ylabel('Scaler', fontsize=12)\n", + " ax.set_xlabel(\"Radius\", fontsize=12)\n", + " ax.set_ylabel(\"Scaler\", fontsize=12)\n", " ax.set_zlabel(metric_name.replace(\"_\", \" \"), fontsize=12)\n", "\n", " ax.view_init(elev=elev, azim=azim)\n", @@ -295,7 +308,7 @@ " plt.yticks(np.arange(np.min(Y), np.max(Y) + y_ticks_step, y_ticks_step))\n", "\n", " # Add color bar\n", - " fig.colorbar(surf, shrink=0.4, aspect=8)\n" + " fig.colorbar(surf, shrink=0.4, aspect=8)" ] }, { @@ -315,13 +328,7 @@ } ], "source": [ - "plot_grid_search(\n", - " \"msg_grid_srch_vae\",\n", - " metric_name=\"Bit_acc\",\n", - " elev=30,\n", - " azim=45,\n", - " figsize=(20, 10)\n", - ")\n", + "plot_grid_search(\"msg_grid_srch_vae\", metric_name=\"Bit_acc\", elev=30, azim=45, figsize=(20, 10))\n", "\n", "# plt.savefig(\"bit_acc_grid.png\", bbox_inches='tight'))" ] @@ -333,14 +340,19 @@ "outputs": [], "source": [ "def grid_search_table(\n", - " project, metric_name, figsize=(10, 8),\n", - " use_title=True, reversed_cmap=False,\n", - " use_colorbar=True, float_signs=3,\n", - " fontsize_label=14, fontsize_ticks=12\n", - " ):\n", - " '''\n", + " project,\n", + " metric_name,\n", + " figsize=(10, 8),\n", + " use_title=True,\n", + " reversed_cmap=False,\n", + " use_colorbar=True,\n", + " float_signs=3,\n", + " fontsize_label=14,\n", + " fontsize_ticks=12,\n", + "):\n", + " \"\"\"\n", " str metric_name: Name of the target column\n", - " '''\n", + " \"\"\"\n", "\n", " if reversed_cmap:\n", " cmap = CMAP.reversed()\n", @@ -349,9 +361,9 @@ "\n", " # df = get_runs_df(project)\n", " if \"fid\" in project.split(\"_\"):\n", - " df = pd.read_csv(f'./fid/{project}.csv')\n", - " else:df = pd.read_csv(f'./detection/{project}.csv')\n", - "\n", + " df = pd.read_csv(f\"./fid/{project}.csv\")\n", + " else:\n", + " df = pd.read_csv(f\"./detection/{project}.csv\")\n", "\n", " scaler = df[\"msg_scaler\"]\n", " radius = df[\"w_radius\"]\n", @@ -360,7 +372,6 @@ " scaler_num = scaler.nunique()\n", " radius_num = radius.nunique()\n", "\n", - " \n", " fig, ax = plt.subplots()\n", " fig.set_size_inches(figsize, forward=True)\n", "\n", @@ -370,22 +381,22 @@ " Y = scaler.values.reshape(radius_num, scaler_num)\n", " Z = metric.values.reshape(radius_num, scaler_num)\n", "\n", - " ax.spines['top'].set_visible(False)\n", - " ax.spines['right'].set_visible(False)\n", - " ax.spines['bottom'].set_visible(False)\n", - " ax.spines['left'].set_visible(False)\n", + " ax.spines[\"top\"].set_visible(False)\n", + " ax.spines[\"right\"].set_visible(False)\n", + " ax.spines[\"bottom\"].set_visible(False)\n", + " ax.spines[\"left\"].set_visible(False)\n", "\n", " table = plt.pcolor(X, Y, Z, cmap=cmap, alpha=0.7)\n", "\n", " for i in range(radius_num):\n", " for j in range(scaler_num):\n", - " plt.text(X[i, j], Y[i, j], f'{Z[i, j]:.{float_signs}f}', ha='center', va='center', color='black')\n", + " plt.text(X[i, j], Y[i, j], f\"{Z[i, j]:.{float_signs}f}\", ha=\"center\", va=\"center\", color=\"black\")\n", "\n", " plt.xticks(radius.unique(), fontsize=fontsize_ticks)\n", " plt.yticks(scaler.unique(), fontsize=fontsize_ticks)\n", "\n", - " ax.set_xlabel('Radius', fontsize=fontsize_label)\n", - " ax.set_ylabel('Scaler', fontsize=fontsize_label)\n", + " ax.set_xlabel(\"Radius\", fontsize=fontsize_label)\n", + " ax.set_ylabel(\"Scaler\", fontsize=fontsize_label)\n", "\n", " if use_colorbar:\n", " fig.colorbar(table)\n", @@ -403,9 +414,9 @@ " name += \", Default VAE\"\n", " else:\n", " name += \", Stable Signature VAE\"\n", - " \n", + "\n", " if use_title:\n", - " fig.suptitle(f'{name}', x=0.43, fontsize=18, y=0.95)" + " fig.suptitle(f\"{name}\", x=0.43, fontsize=18, y=0.95)" ] }, { @@ -425,11 +436,7 @@ } ], "source": [ - "grid_search_table(\n", - " \"msg_grid_srch_vae\",\n", - " metric_name=\"Bit_acc\",\n", - " figsize=(14, 6)\n", - ")\n", + "grid_search_table(\"msg_grid_srch_vae\", metric_name=\"Bit_acc\", figsize=(14, 6))\n", "\n", "# plt.savefig(\"bit_acc_grid.png\", bbox_inches='tight')" ] @@ -451,11 +458,7 @@ } ], "source": [ - "grid_search_table(\n", - " \"msg_grid_srch_vae\",\n", - " metric_name=\"Word_acc\",\n", - " figsize=(14, 6)\n", - ")\n", + "grid_search_table(\"msg_grid_srch_vae\", metric_name=\"Word_acc\", figsize=(14, 6))\n", "\n", "# plt.savefig(\"bit_acc_grid.png\", bbox_inches='tight')" ] @@ -477,12 +480,7 @@ } ], "source": [ - "grid_search_table(\n", - " \"fid_msg_grid_srch_gen_vae\",\n", - " metric_name=\"fid_w\",\n", - " figsize=(14, 6),\n", - " reversed_cmap=True\n", - ")\n", + "grid_search_table(\"fid_msg_grid_srch_gen_vae\", metric_name=\"fid_w\", figsize=(14, 6), reversed_cmap=True)\n", "\n", "# plt.savefig(\"bit_acc_grid.png\", bbox_inches='tight')" ] @@ -504,12 +502,7 @@ } ], "source": [ - "grid_search_table(\n", - " \"fid_msg_grid_srch_gt_vae\",\n", - " metric_name=\"fid_w\",\n", - " figsize=(14, 6),\n", - " reversed_cmap=True\n", - ")\n", + "grid_search_table(\"fid_msg_grid_srch_gt_vae\", metric_name=\"fid_w\", figsize=(14, 6), reversed_cmap=True)\n", "\n", "# plt.savefig(\"bit_acc_grid.png\", bbox_inches='tight')" ] @@ -546,10 +539,10 @@ " use_title=False,\n", " use_colorbar=False,\n", " fontsize_label=16,\n", - " fontsize_ticks=12\n", + " fontsize_ticks=12,\n", ")\n", "\n", - "plt.savefig(\"bit_acc_grid_default.png\", bbox_inches='tight')" + "plt.savefig(\"bit_acc_grid_default.png\", bbox_inches=\"tight\")" ] }, { @@ -576,10 +569,10 @@ " use_title=False,\n", " use_colorbar=False,\n", " fontsize_label=16,\n", - " fontsize_ticks=12\n", + " fontsize_ticks=12,\n", ")\n", "\n", - "plt.savefig(\"word_acc_grid_default.png\", bbox_inches='tight')" + "plt.savefig(\"word_acc_grid_default.png\", bbox_inches=\"tight\")" ] }, { @@ -608,10 +601,10 @@ " float_signs=1,\n", " use_colorbar=False,\n", " fontsize_label=16,\n", - " fontsize_ticks=12\n", + " fontsize_ticks=12,\n", ")\n", "\n", - "plt.savefig(\"fid_gen_grid_default.png\", bbox_inches='tight')" + "plt.savefig(\"fid_gen_grid_default.png\", bbox_inches=\"tight\")" ] }, { @@ -640,10 +633,10 @@ " float_signs=1,\n", " use_colorbar=False,\n", " fontsize_label=16,\n", - " fontsize_ticks=12\n", + " fontsize_ticks=12,\n", ")\n", "\n", - "plt.savefig(\"fid_gt_grid_default.png\", bbox_inches='tight')" + "plt.savefig(\"fid_gt_grid_default.png\", bbox_inches=\"tight\")" ] }, { @@ -670,10 +663,10 @@ " use_title=False,\n", " use_colorbar=False,\n", " fontsize_label=16,\n", - " fontsize_ticks=12\n", + " fontsize_ticks=12,\n", ")\n", "\n", - "plt.savefig(\"det_resol_grid_default.png\", bbox_inches='tight')" + "plt.savefig(\"det_resol_grid_default.png\", bbox_inches=\"tight\")" ] }, { @@ -700,10 +693,10 @@ " use_title=False,\n", " use_colorbar=False,\n", " fontsize_label=16,\n", - " fontsize_ticks=12\n", + " fontsize_ticks=12,\n", ")\n", "\n", - "plt.savefig(\"w_clip_grid_default.png\", bbox_inches='tight')" + "plt.savefig(\"w_clip_grid_default.png\", bbox_inches=\"tight\")" ] } ], diff --git a/pyproject.toml b/pyproject.toml index 301f8c6..6b91424 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,29 +1,11 @@ [metadata] -name = "py_template" -author = "deepvk" -url = "https://github.com/deepvk/py_template" +name = "METR" +author = "Alexander Varlamov, VK Lab" +url = "https://github.com/deepvk/metr" [tool.isort] profile = "black" line_length = 120 [tool.black] -line-length = 120 - -[tool.mypy] -files = ["src"] -install_types = "True" -non_interactive = "True" -disallow_untyped_defs = "True" -ignore_missing_imports = "True" -show_error_codes = "True" -warn_redundant_casts = "True" -warn_unused_configs = "True" -warn_unused_ignores = "True" -allow_redefinition = "True" -warn_no_return = "False" -no_implicit_optional = "False" - -[tool.pytest.ini_options] -testpaths = ["tests"] -addopts = ["--color=yes", "-s"] \ No newline at end of file +line-length = 120 \ No newline at end of file diff --git a/requirements.dev.txt b/requirements.dev.txt index 8cf9836..9c609bd 100644 --- a/requirements.dev.txt +++ b/requirements.dev.txt @@ -1,34 +1,2 @@ --e . --e ./WatermarkAttacker - -torch==2.1.2 -torchvision==0.16.2 -transformers==4.31.0 -diffusers==0.14.0 -accelerate==0.26.1 -xformers==0.0.23.post1 - -# Stable-Signature dependencies: -einops==0.3.0 -open_clip_torch==2.0.2 -torchmetrics==1.3.0.post0 -augly==1.0.0 -pytorch-fid==0.3.0 -pytorch-lightning==2.1.3 - -# WM-Attacker dependencies: -wandb -datasets -ftfy -omegaconf -opencv-python -scikit-image -bm3d -compressai -torch_fidelity -onnxruntime - -# Development: -black -isort -invisible-watermark \ No newline at end of file +black==24.4.2 +isort==5.13.2 \ No newline at end of file diff --git a/src/metr/finetune_ldm_decoder.py b/src/metr/finetune_ldm_decoder.py index 74f8eb8..eec67d3 100644 --- a/src/metr/finetune_ldm_decoder.py +++ b/src/metr/finetune_ldm_decoder.py @@ -28,7 +28,6 @@ # import .stable_sig.utils_model - def import_from_stable_sig(name): module = importlib.import_module(".stable_sig." + name, package=__package__) return module @@ -38,9 +37,8 @@ def import_from_stable_sig(name): utils_img = import_from_stable_sig("utils_img") utils_model = import_from_stable_sig("utils_model") -from tqdm import tqdm - import wandb +from tqdm import tqdm # sys.path.append('src') from .ldm.models.autoencoder import AutoencoderKL diff --git a/src/metr/metr_pp_eval_stable_sig.py b/src/metr/metr_pp_eval_stable_sig.py index 7e7a4fd..588e8a6 100644 --- a/src/metr/metr_pp_eval_stable_sig.py +++ b/src/metr/metr_pp_eval_stable_sig.py @@ -27,7 +27,6 @@ # import utils_model - def import_from_stable_sig(name): module = importlib.import_module(".stable_sig." + name, package=__package__) return module @@ -37,11 +36,10 @@ def import_from_stable_sig(name): utils_img = import_from_stable_sig("utils_img") utils_model = import_from_stable_sig("utils_model") +import wandb from wm_attacks import ReSDPipeline from wm_attacks.wmattacker_no_saving import DiffWMAttacker, VAEWMAttacker -import wandb - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") diff --git a/src/metr/run_metr.py b/src/metr/run_metr.py index aaefcb4..d64169e 100644 --- a/src/metr/run_metr.py +++ b/src/metr/run_metr.py @@ -12,14 +12,13 @@ import numpy as np import PIL import torch +import wandb from diffusers import DPMSolverMultistepScheduler from sklearn import metrics from tqdm import tqdm from wm_attacks import ReSDPipeline from wm_attacks.wmattacker_no_saving import DiffWMAttacker, VAEWMAttacker -import wandb - from .inverse_stable_diffusion import InversableStableDiffusionPipeline from .io_utils import * from .open_clip import create_model_and_transforms, get_tokenizer @@ -30,11 +29,6 @@ # ------------ - - - - - def main(args): if args.save_locally: if not os.path.exists(args.local_path) and not os.path.exists(args.local_path + f"/imgs_no_w/"): diff --git a/src/metr/run_metr_fid.py b/src/metr/run_metr_fid.py index 10a306f..9064d11 100644 --- a/src/metr/run_metr_fid.py +++ b/src/metr/run_metr_fid.py @@ -9,6 +9,7 @@ import numpy as np import PIL import torch +import wandb from diffusers import DPMSolverMultistepScheduler from PIL import Image, ImageFile from pytorch_msssim import ssim @@ -16,8 +17,6 @@ from wm_attacks import ReSDPipeline from wm_attacks.wmattacker_with_saving import DiffWMAttacker, VAEWMAttacker -import wandb - from .inverse_stable_diffusion import InversableStableDiffusionPipeline from .io_utils import * from .optim_utils import *