import os
import random
import logging
from typing import Optional, Tuple

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from models.mvtec_model import MVTec_Encoder
def get_out_dir(args, pretrain: bool, aelr: float, dset_name: str = "cifar10", training_strategy: Optional[str] = None) -> Tuple[str, str]:
    """Create the training output directory.

    Parameters
    ----------
    args :
        Command line arguments
    pretrain : bool
        True if pretraining the model
    aelr : float
        Full AutoEncoder learning rate
    dset_name : str
        Dataset name
    training_strategy : str
        Training strategy (currently unused)

    Returns
    -------
    out_dir : str
        Path to the output folder
    tmp : str
        String describing the current experiment setup
    """
    if dset_name == "ShanghaiTech":
        if pretrain:
            tmp = f"pretrain-mn_{dset_name}-cl_{args.code_length}-lr_{args.ae_learning_rate}"
            out_dir = os.path.join(args.output_path, dset_name, 'pretrain', tmp)
        else:
            tmp = (
                f"train-mn_{dset_name}-cl_{args.code_length}-bs_{args.batch_size}-nu_{args.nu}-lr_{args.learning_rate}-"
                f"bd_{args.boundary}-sl_{args.use_selectors}-ile_{'.'.join(map(str, args.idx_list_enc))}-lstm_{args.load_lstm}-"
                f"bidir_{args.bidirectional}-hs_{args.hidden_size}-nl_{args.num_layers}-dp_{args.dropout}"
            )
            out_dir = os.path.join(args.output_path, dset_name, 'train', tmp)
            if args.end_to_end_training:
                out_dir = os.path.join(args.output_path, dset_name, 'train_end_to_end', tmp)
    else:
        if pretrain:
            tmp = f"pretrain-mn_{dset_name}-nc_{args.normal_class}-cl_{args.code_length}-lr_{args.ae_learning_rate}-awd_{args.ae_weight_decay}"
            out_dir = os.path.join(args.output_path, dset_name, str(args.normal_class), 'pretrain', tmp)
        else:
            tmp = (
                f"train-mn_{dset_name}-nc_{args.normal_class}-cl_{args.code_length}-bs_{args.batch_size}-nu_{args.nu}-lr_{args.learning_rate}-"
                f"wd_{args.weight_decay}-bd_{args.boundary}-alr_{aelr}-sl_{args.use_selectors}-ep_{args.epochs}-ile_{'.'.join(map(str, args.idx_list_enc))}"
            )
            out_dir = os.path.join(args.output_path, dset_name, str(args.normal_class), 'train', tmp)

    os.makedirs(out_dir, exist_ok=True)
    return out_dir, tmp
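
# Usage sketch (hypothetical args object; the attribute values below are illustrative only):
#
#   out_dir, cfg = get_out_dir(args, pretrain=True, aelr=0.0, dset_name="cifar10")
#   # -> '<args.output_path>/cifar10/<args.normal_class>/pretrain/pretrain-mn_cifar10-nc_...-cl_...-lr_...-awd_...'
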
def set_seeds(seed: int) -> None:
    """Set all seeds.

    Parameters
    ----------
    seed : int
        Seed
    """
    # Set the seed only if the user specified it
    if seed != -1:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
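
# Usage sketch: the convention in this file is that seed == -1 means "do not seed", e.g.:
#
#   set_seeds(42)   # fix the Python, NumPy and PyTorch RNGs for reproducibility
#   set_seeds(-1)   # leave the RNGs in their default state
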
def purge_params(encoder_net, ae_net_checkpoint: str) -> None:
    """Load Encoder pretrained weights from the full AutoEncoder checkpoint.

    After the pretraining phase we don't need the full AutoEncoder parameters anymore, only the Encoder's.

    Parameters
    ----------
    encoder_net :
        The Encoder network
    ae_net_checkpoint : str
        Path to the full AutoEncoder checkpoint
    """
    # Load the full AutoEncoder checkpoint dict
    ae_net_dict = torch.load(ae_net_checkpoint, map_location=lambda storage, loc: storage)['ae_state_dict']
    # Load the Encoder state_dict
    net_dict = encoder_net.state_dict()
    # Filter out decoder network keys
    st_dict = {k: v for k, v in ae_net_dict.items() if k in net_dict}
    # Overwrite values in the existing state_dict
    net_dict.update(st_dict)
    # Load the new state_dict
    encoder_net.load_state_dict(net_dict)
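
# Usage sketch (hypothetical path; constructor values are illustrative only): the AutoEncoder
# state_dict contains both encoder and decoder parameters; only the keys that also exist in
# the Encoder's own state_dict are kept.
#
#   encoder = MVTec_Encoder(input_shape=(3, 256, 256), code_length=64,
#                           idx_list_enc=[3, 6, 9], use_selectors=False)
#   purge_params(encoder_net=encoder, ae_net_checkpoint='output/mvtec/.../ae_ckp.pth')
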
def load_mvtec_model_from_checkpoint(input_shape: tuple, code_length: int, idx_list_enc: list, use_selectors: bool, net_checkpoint: str, purge_ae_params: bool = False) -> torch.nn.Module:
    """Load the Encoder from a checkpoint.

    Parameters
    ----------
    input_shape : tuple
        Input data shape
    code_length : int
        Latent code size
    idx_list_enc : list
        Indices of the layers from which features are extracted
    use_selectors : bool
        True if the model has to use Selector modules
    net_checkpoint : str
        Path to the model checkpoint
    purge_ae_params : bool
        True if the checkpoint belongs to the full AutoEncoder; in that case only the Encoder weights are loaded

    Returns
    -------
    encoder_net : torch.nn.Module
        The Encoder network
    """
    logger = logging.getLogger()

    encoder_net = MVTec_Encoder(
        input_shape=input_shape,
        code_length=code_length,
        idx_list_enc=idx_list_enc,
        use_selectors=use_selectors
    )

    if purge_ae_params:
        # Load Encoder parameters from the pretrained full AutoEncoder
        logger.info(f"Loading encoder from: {net_checkpoint}")
        purge_params(encoder_net=encoder_net, ae_net_checkpoint=net_checkpoint)
    else:
        st_dict = torch.load(net_checkpoint)
        encoder_net.load_state_dict(st_dict['net_state_dict'])
        logger.info(f"Loaded model from: {net_checkpoint}")
    return encoder_net
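
# Usage sketch (illustrative values; the checkpoint path is hypothetical):
#
#   encoder = load_mvtec_model_from_checkpoint(
#       input_shape=(3, 256, 256), code_length=64, idx_list_enc=[3, 6, 9],
#       use_selectors=False, net_checkpoint='output/mvtec/.../ae_ckp.pth',
#       purge_ae_params=True,  # the checkpoint holds the full AutoEncoder
#   )
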
def extract_arguments_from_checkpoint(net_checkpoint: str):
    """Parse a checkpoint file path to extract the training parameters and
    architectural specifications encoded in the checkpoint folder name.

    Parameters
    ----------
    net_checkpoint : str
        File path of the checkpoint

    Returns
    -------
    code_length : int
        Latent code size
    batch_size : int
        Batch size
    boundary : str
        Soft or hard boundary
    use_selectors : bool
        True if Selector modules were used
    idx_list_enc : list
        Indices of the exploited layers
    load_lstm : bool
        True if the LSTM was used
    hidden_size : int
        Hidden size of the LSTM
    num_layers : int
        Number of layers of the LSTM
    dropout : float
        Dropout probability
    bidirectional : bool
        True if the LSTM is bidirectional
    dataset_name : str
        Name of the dataset
    train_type : str
        One of end-to-end, train, or pretrain
    """
    # The experiment folder name (see get_out_dir) encodes the parameters as
    # '-'-separated 'key_value' tokens, e.g. 'train-mn_...-cl_...-bs_...-...'
    path_parts = net_checkpoint.split(os.sep)
    tokens = path_parts[-2].split('-')
    code_length = int(tokens[2].split('_')[-1])
    batch_size = int(tokens[3].split('_')[-1])
    boundary = tokens[6].split('_')[-1]
    use_selectors = tokens[7].split('_')[-1] == "True"
    idx_list_enc = [int(i) for i in tokens[8].split('_')[-1].split('.')]
    load_lstm = tokens[9].split('_')[-1] == "True"
    bidirectional = tokens[10].split('_')[-1] == "True"
    hidden_size = int(tokens[11].split('_')[-1])
    num_layers = int(tokens[12].split('_')[-1])
    dropout = float(tokens[13].split('_')[-1])
    dataset_name = path_parts[-4]
    train_type = path_parts[-3]
    return (code_length, batch_size, boundary, use_selectors, idx_list_enc, load_lstm,
            hidden_size, num_layers, dropout, bidirectional, dataset_name, train_type)
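
# Worked example on a hypothetical path that follows the get_out_dir naming scheme:
#
#   ckp = ('output/ShanghaiTech/train/'
#          'train-mn_ShanghaiTech-cl_64-bs_4-nu_0.1-lr_0.001-bd_soft-sl_False-'
#          'ile_3.6.9-lstm_True-bidir_False-hs_100-nl_1-dp_0.3/ckp.pth')
#   extract_arguments_from_checkpoint(ckp)
#   # -> (64, 4, 'soft', False, [3, 6, 9], True, 100, 1, 0.3, False, 'ShanghaiTech', 'train')
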
def eval_spheres_centers(train_loader: DataLoader, encoder_net: torch.nn.Module, ae_net_checkpoint: str, use_selectors: bool, device: str, debug: bool) -> dict:
    """Evaluate the centers of the hyperspheres at each chosen layer.

    Parameters
    ----------
    train_loader : DataLoader
        DataLoader for training data
    encoder_net : torch.nn.Module
        Encoder network
    ae_net_checkpoint : str
        Checkpoint of the full AutoEncoder
    use_selectors : bool
        True if we want to use Selector modules
    device : str
        Device on which to run the computations
    debug : bool
        Activate debug mode

    Returns
    -------
    centers : dict
        Dictionary with k='layer name'; v='features vector representing the hypersphere center'
    """
    logger = logging.getLogger()
    centers_file = ae_net_checkpoint[:-len('.pth')] + f'_w_centers_{use_selectors}.pth'

    if os.path.exists(centers_file):
        # Centers were already evaluated for this checkpoint: load and return them
        logger.info("Found hyperspheres centers")
        ae_net_ckp = torch.load(centers_file, map_location=lambda storage, loc: storage)
        centers = {k: v.to(device) for k, v in ae_net_ckp['centers'].items()}
    else:
        logger.info("Hyperspheres centers not found... evaluating...")
        centers = init_center_c(train_loader=train_loader, encoder_net=encoder_net, device=device, debug=debug)
        logger.info("Hyperspheres centers evaluated!!!")
        # Save the centers next to the AutoEncoder weights so they can be reused
        torch.save({
                'ae_state_dict': torch.load(ae_net_checkpoint)['ae_state_dict'],
                'centers': centers
            }, centers_file)
        logger.info(f"New AE dict saved at: {centers_file}!!!")
    return centers
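
# Usage sketch (hypothetical objects and path): loads the cached centers if
# '<checkpoint>_w_centers_<use_selectors>.pth' exists, otherwise computes and caches them.
#
#   centers = eval_spheres_centers(train_loader, encoder, 'output/.../ae_ckp.pth',
#                                  use_selectors=False, device='cuda', debug=False)
#   # centers == {'layer_name': tensor with the per-feature center coordinates, ...}
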
@torch.no_grad()
def init_center_c(train_loader: DataLoader, encoder_net: torch.nn.Module, device: str, debug: bool, eps: float = 0.1) -> dict:
    """Initialize the hypersphere centers as the mean of the features from an initial forward pass on the data.

    Parameters
    ----------
    train_loader : DataLoader
        DataLoader for training data
    encoder_net : torch.nn.Module
        Encoder network
    device : str
        Device on which to run the computations
    debug : bool
        Activate debug mode (only a few batches are processed)
    eps : float
        Minimum absolute value allowed for each center coordinate

    Returns
    -------
    c : dict
        Dictionary with k='layer name'; v='center features'
    """
    n_samples = 0
    encoder_net.eval().to(device)

    for idx, (data, _) in enumerate(tqdm(train_loader, desc='Init hyperspheres centers', total=len(train_loader), leave=False)):
        if debug and idx == 5:
            break
        data = data.to(device)
        n_samples += data.shape[0]
        # Materialize the encoder output in case it is a single-use iterator (e.g. a zip object),
        # which would otherwise be exhausted by the dict comprehension on the first batch
        zipped = list(encoder_net(data))
        if idx == 0:
            c = {item[0]: torch.zeros_like(item[1][-1], device=device) for item in zipped}
        for item in zipped:
            c[item[0]] += torch.sum(item[1], dim=0)

    for k in c.keys():
        c[k] /= n_samples
        # If a coordinate of c is too close to 0, set it to +-eps.
        # Reason: a zero unit can be trivially matched with zero weights.
        c[k][(abs(c[k]) < eps) & (c[k] < 0)] = -eps
        c[k][(abs(c[k]) < eps) & (c[k] > 0)] = eps
    return c
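
# In formula form: for each monitored layer k, the center is the mean feature vector
#
#   c_k = (1 / N) * sum_i phi_k(x_i)
#
# where phi_k(x_i) are the layer-k features of sample x_i and N is the number of samples;
# afterwards every coordinate with |c_k| < eps is pushed to +-eps so that a unit cannot be
# trivially matched with zero weights.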