Skip to content

Commit

Permalink
ring-id implementation ideas
Browse files Browse the repository at this point in the history
  • Loading branch information
Alphonsce committed Jul 29, 2024
1 parent b472415 commit fdf70af
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 66 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,7 @@ eval_st_sig_logs

requirements.test.txt

*.egg-info
*.egg-info

test.sh
/tests
66 changes: 33 additions & 33 deletions plots/grid_search/msg_data_analysis.ipynb

Large diffs are not rendered by default.

33 changes: 2 additions & 31 deletions requirements.dev.txt
Original file line number Diff line number Diff line change
@@ -1,34 +1,5 @@
-e .
-e ./WatermarkAttacker

torch==2.1.2
torchvision==0.16.2
transformers==4.31.0
diffusers==0.14.0
accelerate==0.26.1
xformers==0.0.23.post1

# Stable-Signature dependencies:
einops==0.3.0
open_clip_torch==2.0.2
torchmetrics==1.3.0.post0
augly==1.0.0
pytorch-fid==0.3.0
pytorch-lightning==2.1.3

# WM-Attacker dependencies:
wandb
datasets
ftfy
omegaconf
opencv-python
scikit-image
bm3d
compressai
torch_fidelity
onnxruntime

# Development:
black
isort
invisible-watermark
invisible-watermark
jupyter
22 changes: 22 additions & 0 deletions scripts/ring-id/reprod.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
while getopts "r:s:" opt; do
case $opt in
r) r="$OPTARG"
;;
s) s="$OPTARG"
;;
\?) echo "Invalid option -$OPTARG" >&2
;;
esac
done

accelerate launch -m metr.run_metr \
--project_name metr-detect-no-att \
--run_name s=$s-r=$r --w_channel 3 --w_pattern ring \
--start 0 --end 1000 \
--reference_model ViT-g-14 --reference_model_pretrain laion2b_s12b_b42k \
--with_tracking \
--w_radius $r \
--msg_type binary \
--use_random_msgs \
--msg_scaler $s \
--no_stable_sig
37 changes: 37 additions & 0 deletions src/metr/optim_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,33 @@ def circle_mask(size=64, r=10, x_offset=0, y_offset=0):

return ((x - x0) ** 2 + (y - y0) ** 2) <= r**2

def rotate_pixel(size, r):
image = np.zeros((size, size), dtype=np.uint8)
center = size // 2
one_pixel = np.zeros_like(image)
one_pixel[center + r][center] = 1

for rot_angle in np.linspace(0, 360, 360):
pil_pixel = Image.fromarray(one_pixel)
rot_image = np.array(pil_pixel.rotate(rot_angle, center=(center + 0.5, center + 0.5)))
image = image + rot_image

return image.astype(bool).astype(int)

def ring_id_mask(size, r):
mask_full = np.zeros([size, size])
for radius in range(1, r + 1):
mask_full += rotate_pixel(radius, size)

return mask_full.astype(bool)

def ring_id_gt_patch(size, r, args=None):
gt_patch = np.zeros([size, size])

for radius in range(1, r + 1):
gt_patch += rotate_pixel(radius, size) * args.msg_scaler * (-1) ** radius
return gt_patch


def get_watermarking_mask(init_latents_w, args, device):
watermarking_mask = torch.zeros(init_latents_w.shape, dtype=torch.bool).to(device)
Expand All @@ -139,6 +166,16 @@ def get_watermarking_mask(init_latents_w, args, device):
watermarking_mask[:, :] = torch_mask
else:
watermarking_mask[:, args.w_channel] = torch_mask

if args.w_mask_shape == "ring-id":
np_mask = ring_id_mask(init_latents_w.shape[-1], r=args.w_radius)
torch_mask = torch.tensor(np_mask).to(device)
if args.w_channel == -1:
# all channels
watermarking_mask[:, :] = torch_mask
else:
watermarking_mask[:, args.w_channel] = torch_mask

elif args.w_mask_shape == "square":
anchor_p = init_latents_w.shape[-1] // 2
if args.w_channel == -1:
Expand Down
2 changes: 1 addition & 1 deletion src/metr/run_metr.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ def main(args):
parser.add_argument("--w_seed", default=999999, type=int)
parser.add_argument("--w_channel", default=3, type=int)
parser.add_argument("--w_pattern", default="ring")
parser.add_argument("--w_mask_shape", default="circle")
parser.add_argument("--w_mask_shape", default="circle", help="Can be 'ring-id' or 'circle' for default METR / Tree-Ring")
parser.add_argument("--w_radius", default=10, type=int)
parser.add_argument("--w_measurement", default="l1_complex")
parser.add_argument("--w_injection", default="complex")
Expand Down

0 comments on commit fdf70af

Please sign in to comment.