-
Hi @kiashann These are toy examples that visualize the whole attention map and the attention map for the class token only (see here for more information).

```python
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
from PIL import Image
from timm.models import create_model
from torchvision.transforms import Compose, Resize, CenterCrop, Normalize, ToTensor

def to_tensor(img):
    transform_fn = Compose([
        Resize(249, 3),  # 3 = bicubic interpolation
        CenterCrop(224),
        ToTensor(),
        Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    return transform_fn(img)

def show_img(img):
    img = np.asarray(img)
    plt.figure(figsize=(10, 10))
    plt.imshow(img)
    plt.axis('off')
    plt.show()

def show_img2(img1, img2, alpha=0.8):
    img1 = np.asarray(img1)
    img2 = np.asarray(img2)
    plt.figure(figsize=(10, 10))
    plt.imshow(img1)
    plt.imshow(img2, alpha=alpha)
    plt.axis('off')
    plt.show()

def my_forward_wrapper(attn_obj):
    def my_forward(x):
        B, N, C = x.shape
        qkv = attn_obj.qkv(x).reshape(B, N, 3, attn_obj.num_heads, C // attn_obj.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * attn_obj.scale
        attn = attn.softmax(dim=-1)
        attn = attn_obj.attn_drop(attn)

        # stash the maps on the module so they can be read out after forward()
        attn_obj.attn_map = attn                   # (B, heads, N, N)
        attn_obj.cls_attn_map = attn[:, :, 0, 2:]  # cls-token attention to the patch tokens (0 = cls, 1 = distill)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = attn_obj.proj(x)
        x = attn_obj.proj_drop(x)
        return x
    return my_forward

img = Image.open('n02102480_Sussex_spaniel.JPEG')
x = to_tensor(img)

model = create_model('deit_small_distilled_patch16_224', pretrained=True)
model.blocks[-1].attn.forward = my_forward_wrapper(model.blocks[-1].attn)
y = model(x.unsqueeze(0))

attn_map = model.blocks[-1].attn.attn_map.mean(dim=1).squeeze(0).detach()
cls_weight = model.blocks[-1].attn.cls_attn_map.mean(dim=1).view(14, 14).detach()

img_resized = x.permute(1, 2, 0) * 0.5 + 0.5  # undo the normalization for display
cls_resized = F.interpolate(cls_weight.view(1, 1, 14, 14), (224, 224), mode='bilinear').view(224, 224, 1)

show_img(img)
show_img(attn_map)
show_img(cls_weight)
show_img(img_resized)
show_img2(img_resized, cls_resized, alpha=0.8)
```

The attention map for the last layer is 198 × 198 (= 196 image patch tokens + 1 cls + 1 distill).
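Note that `attn_map` above is the head-averaged map (deit_small has 6 heads); if you want to look at a single head instead of the mean, you can index the head dimension directly, e.g.:

```python
# inspect each head separately instead of the head-averaged map
attn_per_head = model.blocks[-1].attn.attn_map.squeeze(0).detach()  # (heads, 198, 198)
for h in range(attn_per_head.shape[0]):
    show_img(attn_per_head[h])
```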
-
@kiashann
-
I would like to apply this code to the 'vit_small_patch16_384' model from timm. How should I modify the code for this purpose?
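A sketch of the changes that should suffice (untested, assuming the wrapper above): `vit_small_patch16_384` takes 384 × 384 inputs and has no distillation token, so the class-token slice starts at index 1 and the patch grid is 24 × 24 (384 / 16):

```python
# untested sketch for 'vit_small_patch16_384'
transform_fn = Compose([Resize(384, 3), CenterCrop(384), ToTensor(),
                        Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

model = create_model('vit_small_patch16_384', pretrained=True)
model.blocks[-1].attn.forward = my_forward_wrapper(model.blocks[-1].attn)

# inside my_forward there is no distill token, so skip only the cls token:
#   attn_obj.cls_attn_map = attn[:, :, 0, 1:]

# 384 / 16 = 24 patches per side -> 24 x 24 grid, upsampled back to 384
cls_weight = model.blocks[-1].attn.cls_attn_map.mean(dim=1).view(24, 24).detach()
cls_resized = F.interpolate(cls_weight.view(1, 1, 24, 24), (384, 384), mode='bilinear').view(384, 384, 1)
```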
-
So, this doesn't include the visualization helpers yet, but I've added a simpler extraction helper to get the attention activations via one of two methods, fx or hooks. It's WIP but can be seen at https://github.com/huggingface/pytorch-image-models/pull/2168/files#diff-358e0d5feb2c109ff53d21bc4fa8a6af94566be622b0f1167316216b0036b8b3
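For reference, a minimal hook-based sketch of the same idea (my illustration, not the helper from the PR; it assumes a recent timm where `Attention` exposes a `fused_attn` flag and an `attn_drop` submodule):

```python
import torch
from timm import create_model

model = create_model('deit_small_distilled_patch16_224', pretrained=True).eval()
for block in model.blocks:
    block.attn.fused_attn = False  # force the math path so attn_drop sees the attention matrix

attn_maps = {}

def save_attn(name):
    def hook(module, args, output):
        attn_maps[name] = output.detach()  # (B, heads, N, N), post-softmax
    return hook

for i, block in enumerate(model.blocks):
    block.attn.attn_drop.register_forward_hook(save_attn(f'blocks.{i}.attn'))

with torch.no_grad():
    model(torch.randn(1, 3, 224, 224))

print({name: a.shape for name, a in attn_maps.items()})
```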
-
FYI, there's a fix on main for the node/module matching so that outputs remain in order of traversal (which usually matches the order of the forward pass, at least for timm models) regardless of how many matching names/wildcards are specified.
-
Hi! @hankyul2 Thanks for your excellent explanation above. I understood most of it but was still confused about why […]. Thanks!
-
This might be helpful: https://github.com/facebookresearch/dino/blob/main/visualize_attention.py
-
Hi, I want to extract the attention map from a pretrained vision transformer for a specific image.
How can I do that?