Skip to content

Commit

Permalink
Update CI, fix black and isort issues
Browse files Browse the repository at this point in the history
  • Loading branch information
SpirinEgor committed Jul 8, 2024
1 parent 947a16e commit fec64c0
Show file tree
Hide file tree
Showing 11 changed files with 255 additions and 343 deletions.
6 changes: 1 addition & 5 deletions .github/workflows/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,8 @@ jobs:
cache: "pip"
- name: "installation"
run: |
pip install -r requirements.txt -r requirements.dev.txt
pip install -r requirements.dev.txt
- name: "black"
run: black . --check --diff --color
- name: "isort"
run: isort . --check --diff
- name: "mypy"
run: mypy
- name: "pytests"
run: pytest
105 changes: 63 additions & 42 deletions plots/create_csvs/for_paper.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"import os\n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"sns.set_style(\"white\", {\"grid.color\": \".6\", \"grid.linestyle\": \":\"})"
]
},
Expand All @@ -23,10 +24,10 @@
"metadata": {},
"outputs": [],
"source": [
"def get_runs_df_stable_sig(project, entity=\"jurujin\", runtime_limit=6*3600):\n",
" '''\n",
"def get_runs_df_stable_sig(project, entity=\"jurujin\", runtime_limit=6 * 3600):\n",
" \"\"\"\n",
" Returns df with data from wandb project for stable-sig\n",
" '''\n",
" \"\"\"\n",
" api = wandb.Api()\n",
" runs = api.runs(entity + \"/\" + project)\n",
"\n",
Expand All @@ -36,7 +37,6 @@
" config_list.append({k: v for k, v in run.config.items() if not k.startswith(\"_\")})\n",
" name_list.append(run.name)\n",
"\n",
"\n",
" summary_df = pd.DataFrame(summary_list)\n",
" config_df = pd.DataFrame(config_list)\n",
"\n",
Expand All @@ -55,10 +55,11 @@
"source": [
"# Чтобы выбрать долгие FID запуски надо: df[\"_runtime\"] > 7200\n",
"\n",
"def get_runs_df(project, entity=\"jurujin\", runtime_limit=6*3600, resolution=False):\n",
" '''\n",
"\n",
"def get_runs_df(project, entity=\"jurujin\", runtime_limit=6 * 3600, resolution=False):\n",
" \"\"\"\n",
" Returns df with data from wandb project\n",
" '''\n",
" \"\"\"\n",
" df = get_runs_df_stable_sig(project, entity, runtime_limit)\n",
"\n",
" if resolution:\n",
Expand All @@ -75,67 +76,79 @@
"source": [
"detection_projects = [\n",
" \"detect_msg_all_att_vae\",\n",
" \"detect_msg_all_att_no_vae\", \n",
"\n",
" \"detect_msg_all_att_no_vae\",\n",
" # \"clip_different_msg\" # Testing CLIP quality for different message\n",
"]\n",
"\n",
"stable_signature_detection_projects = [\n",
" \"eval_stable_tree_all_attacks\"\n",
"]\n",
"stable_signature_detection_projects = [\"eval_stable_tree_all_attacks\"]\n",
"\n",
"fid_projects = [\n",
" # \"fid_gt_msg_all_att_vae\",\n",
" # \"fid_gt_msg_all_att_no_vae\",\n",
"\n",
" # \"fid_gen_msg_all_att_vae\",\n",
" # \"fid_gen_msg_all_att_no_vae\",\n",
"\n",
" \"fid_gen_message_dependency\",\n",
" \"fid_gt_message_dependency\",\n",
" \n",
"]\n",
"\n",
"detection_cols = [\n",
" \"name\",\n",
" \"TPR@1%FPR\", \"auc\", \"acc\",\n",
" \"Bit_acc\", \"Word_acc\",\n",
" \"TPR@1%FPR\",\n",
" \"auc\",\n",
" \"acc\",\n",
" \"Bit_acc\",\n",
" \"Word_acc\",\n",
" \"det_resol\",\n",
"\n",
" \"w_clip_score_mean\",\n",
"\n",
" \"w_det_dist_mean\",\n",
" \"no_w_det_dist_mean\",\n",
"\n",
" \"w_det_dist_std\",\n",
" \"no_w_det_dist_std\",\n",
" \n",
" \"msg\", \"w_radius\", \"msg_scaler\",\n",
"\n",
" \"jpeg_ratio\", \"crop_scale\", \"crop_ratio\", \"gaussian_blur_r\", \"gaussian_std\", \"brightness_factor\", \"r_degree\"\n",
" \"msg\",\n",
" \"w_radius\",\n",
" \"msg_scaler\",\n",
" \"jpeg_ratio\",\n",
" \"crop_scale\",\n",
" \"crop_ratio\",\n",
" \"gaussian_blur_r\",\n",
" \"gaussian_std\",\n",
" \"brightness_factor\",\n",
" \"r_degree\",\n",
"]\n",
"\n",
"stable_signature_detection_cols = [\n",
" \"name\", \"Bit_acc\", \"Word_acc\"\n",
"]\n",
"stable_signature_detection_cols = [\"name\", \"Bit_acc\", \"Word_acc\"]\n",
"\n",
"fid_cols = [\n",
" \"name\",\n",
" \"psnr_w\", \"ssim_w\",\n",
" \"psnr_no_w\", \"ssim_no_w\",\n",
" \"fid_w\", \"fid_no_w\",\n",
" \"msg\", \"w_radius\", \"msg_scaler\",\n",
"\n",
" \"psnr_w\",\n",
" \"ssim_w\",\n",
" \"psnr_no_w\",\n",
" \"ssim_no_w\",\n",
" \"fid_w\",\n",
" \"fid_no_w\",\n",
" \"msg\",\n",
" \"w_radius\",\n",
" \"msg_scaler\",\n",
"]\n",
"\n",
"fid_att_cols = [\n",
" \"name\",\n",
" \"psnr_w\", \"ssim_w\",\n",
" \"psnr_no_w\", \"ssim_no_w\",\n",
" \"fid_w\", \"fid_no_w\",\n",
" \"msg\", \"w_radius\", \"msg_scaler\",\n",
"\n",
" \"jpeg_ratio\", \"crop_scale\", \"crop_ratio\", \"gaussian_blur_r\", \"gaussian_std\", \"brightness_factor\", \"r_degree\"\n",
" \"psnr_w\",\n",
" \"ssim_w\",\n",
" \"psnr_no_w\",\n",
" \"ssim_no_w\",\n",
" \"fid_w\",\n",
" \"fid_no_w\",\n",
" \"msg\",\n",
" \"w_radius\",\n",
" \"msg_scaler\",\n",
" \"jpeg_ratio\",\n",
" \"crop_scale\",\n",
" \"crop_ratio\",\n",
" \"gaussian_blur_r\",\n",
" \"gaussian_std\",\n",
" \"brightness_factor\",\n",
" \"r_degree\",\n",
"]"
]
},
Expand All @@ -147,7 +160,9 @@
"source": [
"for project in detection_projects:\n",
" os.makedirs(\"./detection\", exist_ok=True)\n",
" get_runs_df(project, resolution=True, runtime_limit=4 * 3600).to_csv(f\"./detection/{project}.csv\", index=False, columns=detection_cols)"
" get_runs_df(project, resolution=True, runtime_limit=4 * 3600).to_csv(\n",
" f\"./detection/{project}.csv\", index=False, columns=detection_cols\n",
" )"
]
},
{
Expand All @@ -158,7 +173,9 @@
"source": [
"for project in fid_projects:\n",
" os.makedirs(\"./fid\", exist_ok=True)\n",
" get_runs_df(project).sort_values(by=\"name\", ascending=False).to_csv(f\"./fid/{project}.csv\", index=False, columns=fid_cols)"
" get_runs_df(project).sort_values(by=\"name\", ascending=False).to_csv(\n",
" f\"./fid/{project}.csv\", index=False, columns=fid_cols\n",
" )"
]
},
{
Expand All @@ -169,7 +186,9 @@
"source": [
"for project in stable_signature_detection_projects:\n",
" os.makedirs(\"./detection\", exist_ok=True)\n",
" get_runs_df_stable_sig(project, runtime_limit=0).to_csv(f\"./detection/{project}.csv\", index=False, columns=stable_signature_detection_cols)"
" get_runs_df_stable_sig(project, runtime_limit=0).to_csv(\n",
" f\"./detection/{project}.csv\", index=False, columns=stable_signature_detection_cols\n",
" )"
]
},
{
Expand All @@ -180,7 +199,9 @@
"source": [
"clip_different_message = \"clip_different_msg\"\n",
"os.makedirs(\"./detection\", exist_ok=True)\n",
"get_runs_df(clip_different_message, runtime_limit=0, resolution=True).sort_values(by=\"name\", ascending=False).to_csv(f\"./detection/{clip_different_message}.csv\", index=False, columns=detection_cols)"
"get_runs_df(clip_different_message, runtime_limit=0, resolution=True).sort_values(by=\"name\", ascending=False).to_csv(\n",
" f\"./detection/{clip_different_message}.csv\", index=False, columns=detection_cols\n",
")"
]
}
],
Expand Down
58 changes: 24 additions & 34 deletions plots/fid_dispersion/fid_dispersion.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"import numpy as np\n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"sns.set_style(\"white\", {\"grid.color\": \".6\", \"grid.linestyle\": \":\"})"
]
},
Expand All @@ -22,10 +23,10 @@
"metadata": {},
"outputs": [],
"source": [
"def get_runs_df_stable_sig(project, entity=\"jurujin\", runtime_limit=6*3600):\n",
" '''\n",
"def get_runs_df_stable_sig(project, entity=\"jurujin\", runtime_limit=6 * 3600):\n",
" \"\"\"\n",
" Returns df with data from wandb project for stable-sig\n",
" '''\n",
" \"\"\"\n",
" api = wandb.Api()\n",
" runs = api.runs(entity + \"/\" + project)\n",
"\n",
Expand All @@ -35,7 +36,6 @@
" config_list.append({k: v for k, v in run.config.items() if not k.startswith(\"_\")})\n",
" name_list.append(run.name)\n",
"\n",
"\n",
" summary_df = pd.DataFrame(summary_list)\n",
" config_df = pd.DataFrame(config_list)\n",
"\n",
Expand All @@ -45,12 +45,14 @@
"\n",
" return df\n",
"\n",
"\n",
"# Чтобы выбрать долгие FID запуски надо: df[\"_runtime\"] > 7200\n",
"\n",
"def get_runs_df(project, entity=\"jurujin\", runtime_limit=6*3600, resolution=False):\n",
" '''\n",
"\n",
"def get_runs_df(project, entity=\"jurujin\", runtime_limit=6 * 3600, resolution=False):\n",
" \"\"\"\n",
" Returns df with data from wandb project\n",
" '''\n",
" \"\"\"\n",
" df = get_runs_df_stable_sig(project, entity, runtime_limit)\n",
"\n",
" if resolution:\n",
Expand All @@ -69,24 +71,15 @@
" # Все месседжи кроме 00..0 для всех S кроме 100\n",
" \"fid_gеn_message_s\",\n",
" \"fid_gt_message_s\",\n",
"\n",
" # Все месседжи для S=100\n",
" \"fid_gen_message_dependency\",\n",
" \"fid_gt_message_dependency\",\n",
"\n",
" # Месседжи 00..0 кроме S=100\n",
" \"worst_message_fid_gen\",\n",
" \"worst_message_fid_gt\"\n",
" \"worst_message_fid_gt\",\n",
"]\n",
"\n",
"fid_cols = [\n",
" \"name\",\n",
" \"msg\",\n",
" \"msg_scaler\",\n",
" \"fid_w\",\n",
" \"target_clean_generated\"\n",
"\n",
"]"
"fid_cols = [\"name\", \"msg\", \"msg_scaler\", \"fid_w\", \"target_clean_generated\"]"
]
},
{
Expand Down Expand Up @@ -168,12 +161,9 @@
" stds = [df[df[\"msg_scaler\"] == S][\"fid_w\"].std() for S in range(60, 110, 10)]\n",
" means = [df[df[\"msg_scaler\"] == S][\"fid_w\"].mean() for S in range(60, 110, 10)]\n",
"\n",
" stats_dict = {\n",
" \"mean\":means,\n",
" \"std\":stds\n",
" }\n",
" \n",
" return pd.DataFrame(data=stats_dict, index=range(60, 110, 10))\n"
" stats_dict = {\"mean\": means, \"std\": stds}\n",
"\n",
" return pd.DataFrame(data=stats_dict, index=range(60, 110, 10))"
]
},
{
Expand All @@ -200,7 +190,7 @@
"metadata": {},
"outputs": [],
"source": [
"markersize=1\n",
"markersize = 1\n",
"ticks_font = 16\n",
"label_font = 20\n",
"legend_font = 20"
Expand All @@ -226,8 +216,10 @@
"fig, axs = plt.subplots(2, 1, figsize=(8, 6))\n",
"\n",
"S = range(60, 110, 10)\n",
"axs[0].errorbar(S, gt_stat[\"mean\"], yerr=gt_stat[\"std\"], label='FID gt', fmt='-o', markersize=markersize, c=\"#1f77b4\")\n",
"axs[1].errorbar(S, gen_stat[\"mean\"], yerr=gen_stat[\"std\"], label='FID gen', fmt='-o', markersize=markersize, c=\"#ff7f0e\")\n",
"axs[0].errorbar(S, gt_stat[\"mean\"], yerr=gt_stat[\"std\"], label=\"FID gt\", fmt=\"-o\", markersize=markersize, c=\"#1f77b4\")\n",
"axs[1].errorbar(\n",
" S, gen_stat[\"mean\"], yerr=gen_stat[\"std\"], label=\"FID gen\", fmt=\"-o\", markersize=markersize, c=\"#ff7f0e\"\n",
")\n",
"\n",
"axs[0].set_ylim(25, 30.5)\n",
"axs[0].set_yticks(np.arange(25, 31, 1))\n",
Expand All @@ -236,19 +228,17 @@
"axs[1].set_yticks(np.arange(9, 19, 2))\n",
"\n",
"for ax in axs:\n",
" ax.legend(loc=\"upper left\" ,fontsize=legend_font)\n",
" ax.patch.set_edgecolor('black') \n",
" ax.legend(loc=\"upper left\", fontsize=legend_font)\n",
" ax.patch.set_edgecolor(\"black\")\n",
" ax.patch.set_linewidth(1)\n",
" ax.grid(alpha=0.5)\n",
" ax.set_xlim(58, 102)\n",
" ax.set_xticks(\n",
" np.arange(60, 110, 10)\n",
" )\n",
" ax.tick_params(axis='both', which='major', labelsize=ticks_font)\n",
" ax.set_xticks(np.arange(60, 110, 10))\n",
" ax.tick_params(axis=\"both\", which=\"major\", labelsize=ticks_font)\n",
"\n",
"ax.set_xlabel(\"$S$ hyperparameter\", fontsize=label_font)\n",
"\n",
"plt.savefig(\"fid_dispersion.png\", bbox_inches='tight')"
"plt.savefig(\"fid_dispersion.png\", bbox_inches=\"tight\")"
]
}
],
Expand Down
Loading

0 comments on commit fec64c0

Please sign in to comment.