-
Notifications
You must be signed in to change notification settings - Fork 54
/
feature_extract.py
77 lines (65 loc) · 2.13 KB
/
feature_extract.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#coding:utf8
'''
利用resnet50提取图片的语义信息
并保存成 results.pth
'''
from config import Config
import tqdm
import torch as t
from torch.autograd import Variable
import torchvision as tv
from torch.utils import data
import os
from PIL import Image
import numpy as np
# Shared run configuration (image directory, caption data path, loader settings).
opt = Config()

# Per-channel RGB mean / std of ImageNet, matching the statistics used to
# train the torchvision pretrained backbone loaded below.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
normalize = tv.transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
class CaptionDataset(data.Dataset):
    """Dataset of normalized images listed in a caption data file.

    Each item is ``(image_tensor, index)``; items are ordered by the integer
    index used in the file's ``'ix2id'`` table, so a non-shuffled loader
    visits images in exactly that order.
    """

    def __init__(self, caption_data_path, img_path=None):
        """
        Args:
            caption_data_path: path of a file readable by ``t.load`` that
                contains an ``'ix2id'`` entry indexable by 0..len-1 and
                yielding image file names (presumably a list or an
                int-keyed dict -- confirm against the preprocessing step).
            img_path: directory holding the image files. Defaults to
                ``opt.img_path`` from the module-level config, which keeps
                the original single-argument call working.
        """
        self.transforms = tv.transforms.Compose([
            # ``Resize`` is the current name of the removed ``Scale``
            # transform: the shorter side becomes 256, aspect ratio kept.
            tv.transforms.Resize(256),
            tv.transforms.CenterCrop(256),
            tv.transforms.ToTensor(),
            normalize,
        ])
        if img_path is None:
            img_path = opt.img_path
        # NOTE(review): t.load unpickles the file -- only load trusted data.
        caption_data = t.load(caption_data_path)
        self.ix2id = caption_data['ix2id']
        # Full paths of all images; position i holds the image with index i.
        self.imgs = [os.path.join(img_path, self.ix2id[ix])
                     for ix in range(len(self.ix2id))]

    def __getitem__(self, index):
        """Return ``(transformed_image, index)`` for the image at *index*."""
        # The context manager releases the file handle. ``convert`` loads the
        # pixel data and returns a new image, so it stays valid after close.
        with Image.open(self.imgs[index]) as raw:
            img = raw.convert('RGB')
        return self.transforms(img), index

    def __len__(self):
        """Number of images listed in the caption data file."""
        return len(self.imgs)
def get_dataloader(opt):
    """Build an in-order DataLoader over the images of ``opt.caption_data_path``.

    Shuffling is disabled on purpose: the caller relies on batches arriving
    in dataset-index order to place each feature in the right output row.
    """
    loader_kwargs = dict(
        batch_size=opt.batch_size,
        shuffle=False,
        num_workers=opt.num_workers,
    )
    return data.DataLoader(CaptionDataset(opt.caption_data_path), **loader_kwargs)
# --- Data ---
opt.batch_size = 256
dataloader = get_dataloader(opt)
# One 2048-d feature row per image, stored in dataset-index order.
results = t.zeros(len(dataloader.dataset), 2048)
batch_size = opt.batch_size

# --- Model ---
# Swapping the final classifier for an identity makes the network return the
# 2048-d pooled features that fed the classifier instead of class scores.
resnet50 = tv.models.resnet50(pretrained=True)
resnet50.fc = t.nn.Identity()
# eval(): BatchNorm must use (and not overwrite) its stored running
# statistics; in train mode the features would depend on batch composition.
resnet50.eval()
device = t.device('cuda' if t.cuda.is_available() else 'cpu')
resnet50.to(device)

# --- Forward pass: compute features ---
# no_grad() replaces the defunct ``Variable(..., volatile=True)``: no autograd
# graph is recorded, keeping memory bounded during inference.
with t.no_grad():
    for ii, (imgs, indexs) in tqdm.tqdm(enumerate(dataloader), total=len(dataloader)):
        # The loader is not shuffled, so batch ii must start at row
        # ii * batch_size; explicit check (unlike assert) survives ``-O``.
        if indexs[0] != batch_size * ii:
            raise RuntimeError(
                'batch %d starts at index %d, expected %d'
                % (ii, int(indexs[0]), batch_size * ii))
        features = resnet50(imgs.to(device))
        results[ii * batch_size:(ii + 1) * batch_size] = features.cpu()

# results: N x 2048 (about 200,000 images x 2048-d feature each)
t.save(results, 'results.pth')