# vit_explain.py (forked from jacobgil/vit-explain)
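"""Visualize Vision Transformer attention with attention rollout.

Loads a pretrained DeiT-Tiny model from torch.hub, computes a class-agnostic
attention rollout mask (or a gradient rollout mask for a chosen class) over
an input image, overlays it as a heatmap, and shows and saves the result.

Example invocations (a sketch; the class index is illustrative):

    python vit_explain.py --image_path ./examples/both.png --head_fusion mean
    python vit_explain.py --image_path ./examples/both.png --category_index 243
"""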
import argparse

import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from vit_rollout import VITAttentionRollout
from vit_grad_rollout import VITAttentionGradRollout

def get_args():
    """Parse command-line arguments for the attention visualization."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--use_cuda', action='store_true', default=False,
                        help='Use NVIDIA GPU acceleration')
    parser.add_argument('--image_path', type=str, default='./examples/both.png',
                        help='Input image path')
    parser.add_argument('--head_fusion', type=str, default='max',
                        help='How to fuse the attention heads for attention rollout: '
                             'mean, max, or min')
    parser.add_argument('--discard_ratio', type=float, default=0.9,
                        help='Fraction of the lowest attention paths to discard (0.0-1.0)')
    parser.add_argument('--category_index', type=int, default=None,
                        help='The category index for gradient rollout')
    args = parser.parse_args()

    # Fall back to CPU when CUDA was requested but is unavailable.
    args.use_cuda = args.use_cuda and torch.cuda.is_available()
    if args.use_cuda:
        print("Using GPU")
    else:
        print("Using CPU")
    return args

def show_mask_on_image(img, mask):
    """Overlay a [0, 1] float mask on a uint8 BGR image as a JET heatmap."""
    img = np.float32(img) / 255
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255
    # Blend the heatmap with the image, then renormalize to [0, 1].
    cam = heatmap + np.float32(img)
    cam = cam / np.max(cam)
    return np.uint8(255 * cam)

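# A minimal sanity check of the overlay (hypothetical inputs, kept as a
# comment so nothing runs on import; the real mask comes from the rollout):
#
#   dummy_img = np.zeros((224, 224, 3), dtype=np.uint8)       # black BGR image
#   dummy_mask = np.random.rand(224, 224).astype(np.float32)  # fake attention
#   overlay = show_mask_on_image(dummy_img, dummy_mask)       # uint8 heatmap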
if __name__ == '__main__':
    args = get_args()

    # DeiT-Tiny: a small ViT with 16x16 patches on 224x224 inputs.
    model = torch.hub.load('facebookresearch/deit:main',
                           'deit_tiny_patch16_224', pretrained=True)
    model.eval()
    if args.use_cuda:
        model = model.cuda()

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    # Convert to RGB so images with an alpha channel don't break Normalize.
    img = Image.open(args.image_path).convert('RGB')
    img = img.resize((224, 224))
    input_tensor = transform(img).unsqueeze(0)
    if args.use_cuda:
        input_tensor = input_tensor.cuda()

    if args.category_index is None:
        # Class-agnostic attention rollout.
        print("Doing Attention Rollout")
        attention_rollout = VITAttentionRollout(model, head_fusion=args.head_fusion,
                                                discard_ratio=args.discard_ratio)
        mask = attention_rollout(input_tensor)
        name = "attention_rollout_{:.3f}_{}.png".format(args.discard_ratio, args.head_fusion)
    else:
        # Gradient-weighted rollout targeted at a specific class index.
        print("Doing Gradient Attention Rollout")
        grad_rollout = VITAttentionGradRollout(model, discard_ratio=args.discard_ratio)
        mask = grad_rollout(input_tensor, args.category_index)
        name = "grad_rollout_{}_{:.3f}_{}.png".format(args.category_index,
                                                      args.discard_ratio, args.head_fusion)
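
    # The rollout mask comes back at the patch-grid resolution (14x14 for
    # 16-pixel patches on a 224x224 input) and is upscaled to the image below.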
    # PIL gives RGB; reverse the channel order to BGR for OpenCV.
    np_img = np.array(img)[:, :, ::-1]
    mask = cv2.resize(mask, (np_img.shape[1], np_img.shape[0]))
    mask = show_mask_on_image(np_img, mask)
    cv2.imshow("Input Image", np_img)
    cv2.imshow(name, mask)
    cv2.imwrite("input.png", np_img)
    cv2.imwrite(name, mask)
    cv2.waitKey(-1)
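    # Note: cv2.imshow needs a display server; on a headless machine the
    # cv2.imwrite calls above still save input.png and the heatmap to disk.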