check_model.py
Forked from tse-chunchen/model_error_correction. The repository was archived by the owner on Jun 9, 2023 and is now read-only.

import os
import sys
import logging
from time import time
from glob import glob
from joblib import Parallel, delayed
#logging.basicConfig(level=logging.DEBUG)
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
main_dir = "./model_error_correction"
python_exe = "./python"

def int_float_str(s):
    '''
    Convert a string to int or float if possible.
    '''
    try:
        return int(s)
    except ValueError:
        try:
            return float(s)
        except ValueError:
            return s

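# For illustration, on hypothetical filename tokens:
#   int_float_str('64') -> 64, int_float_str('0.001') -> 0.001,
#   int_float_str('conv2d') -> 'conv2d'
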
def count_parameters(model):
    '''Count the number of trainable parameters in the NN model.'''
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def get_train_param_name(model_type):
    '''Introspect the training function to recover its argument names.'''
    logging.info('################################################')
    logging.info('## get_train_param_name ')
    logging.info('################################################')
    if model_type == 'conv2d':
        from training import _train_ as Train
        arg_count = Train.__code__.co_argcount - 1  # count training-code arguments (excluding the first)
        arg_names = Train.__code__.co_varnames[1:arg_count+1]  # get the names of those arguments
    else:
        logging.error(model_type + " not supported!")
        sys.exit(model_type + " not supported!")  # arg_count/arg_names would be undefined past this point
    return arg_count, arg_names

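# Introspection sketch (hypothetical signature; the first parameter is skipped):
# if training._train_ were defined as _train_(rank, bs, lr, testset, ...),
# then arg_count = co_argcount - 1 and arg_names = ('bs', 'lr', 'testset', ...).
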
def get_test_dataset(hyperparam, num_workers=0):
    '''Build a DataLoader over the testing index range.'''
    from training import Dataset_np as Dataset
    logging.info('################################################')
    logging.info('## get_test_dataset ')
    logging.info('################################################')
    testset = hyperparam['testset']
    # define the testing index range
    if testset == 0:
        test_slice = slice(40+1460, None)
        #train_valid_slice = slice(40, 40+1460)
    elif testset == 1:
        test_slice = slice(40, 40+367)
        #train_valid_slice = slice(40+368, None)
    elif testset == 2:
        test_slice = slice(None, None)  # for sample use
    else:
        logging.error("testset value {} not supported".format(testset))
        sys.exit()
    test_set = Dataset(idx_include=test_slice, **hyperparam)  # initiate dataset object
    test_Loader = DataLoader(test_set, batch_size=len(test_set), num_workers=num_workers)  # set up data loader (whole test set in one batch)
    return test_Loader

def get_train_dataset(hyperparam, num_workers=0):
    '''Build a DataLoader over the training/validation index range.'''
    from training import Dataset_np as Dataset
    logging.info('################################################')
    logging.info('## get_train_dataset ')
    logging.info('################################################')
    testset = hyperparam['testset']
    # define the training and validation index range
    if testset == 0:
        train_valid_slice = slice(40, 40+1460)
    elif testset == 1:
        train_valid_slice = slice(40+368, None)
    elif testset == 2:
        train_valid_slice = slice(None, None)  # for sample use
    else:
        logging.error("testset value {} not supported".format(testset))
        sys.exit()
    train_valid_set = Dataset(idx_include=train_valid_slice, **hyperparam)  # initiate dataset object
    train_valid_Loader = DataLoader(train_valid_set, batch_size=hyperparam['bs'], num_workers=num_workers)  # set up data loader
    return train_valid_Loader

def get_norm(filename):
    '''Return the output mean/std used to normalize the dataset.'''
    logging.info('################################################')
    logging.info('## get_norm ')
    logging.info('################################################')
    hyperparam = read_hyperparam(filename)  # get hyperparameters from filename
    test_Loader = get_test_dataset(hyperparam, 4)
    return test_Loader.dataset.mean_out.numpy(), test_Loader.dataset.std_out.numpy()

def read_hyperparam(filename):
    '''Parse the hyperparameters encoded in a checkpoint filename.'''
    logging.info('################################################')
    logging.info('## read_hyperparam ')
    logging.info('################################################')
    logging.info(filename)  # full filename
    name = filename.split('/')[-1]  # get rid of parent directory
    hyper = [int_float_str(i) for i in name.split('_')]  # get hyperparameters from underscore-separated tokens
    model_type = hyper[0]
    arg_count, arg_names = get_train_param_name(model_type)
    if len(hyper) != arg_count + 1:  # check that the argument count and filename match
        sys.exit("format incorrect!!: {}".format(name))
    hyperparam = dict(zip(('model_type',) + arg_names, hyper))
    return hyperparam

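# Sketch of the filename convention this implies (hypothetical values): a
# checkpoint named 'conv2d_0_64_0.001' splits on '_' and parses to
# {'model_type': 'conv2d', arg_names[0]: 0, arg_names[1]: 64, arg_names[2]: 0.001},
# with the actual keys taken from the signature of training._train_.
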
def read_checkfile(filename):
    '''Read a checkpoint file for hyperparameters and training status.'''
    logging.info('################################################')
    logging.info('## read_checkfile ')
    logging.info('################################################')
    hyperparam = read_hyperparam(filename)  # get hyperparameters from filename
    try:
        checkfile = torch.load(filename, map_location=torch.device('cpu'))  # load checkpoint from file onto cpu
        valid_min = min(checkfile['valid_loss'])
        epoches = len(checkfile['valid_loss'])
        impatience = checkfile['impatience']
        add_param = dict(zip(['filename', 'epoches', 'impatience', 'valid_min'],
                             [filename, epoches, impatience, valid_min]))
        hyperparam.update(add_param)
        return hyperparam
    except (RuntimeError, EOFError):
        logging.error("Failed reading: " + filename)  # falls through to return None; filtered out in collect_models
    except ValueError:
        logging.error("OOM: " + filename)

def read_model(filename, if_hyperparam=False, if_iosize=False):
    '''Initialize a model from a checkpoint file.'''
    logging.info('################################################')
    logging.info('## read_model ')
    logging.info('################################################')
    hyperparam = read_hyperparam(filename)  # get hyperparameters from filename
    checkfile = torch.load(filename, map_location=torch.device('cpu'))  # load checkpoint from file onto cpu
    # infer input/output channel counts from the first and last conv layers
    input_size = checkfile['model_state_dict']['convs.0.weight'].shape[1]
    output_size = checkfile['model_state_dict']['convs.{}.weight'.format(hyperparam['n_conv']-1)].shape[0]
    if hyperparam['model_type'] == 'conv2d':
        from model import CONV2D as NN
    else:
        sys.exit(hyperparam['model_type'] + " not supported!")  # NN would be undefined past this point
    model = NN(input_size=input_size, output_size=output_size, **hyperparam)  # initialize model object
    model.load_state_dict(checkfile['model_state_dict'])  # load state from checkpoint file into the model object
    if if_hyperparam and if_iosize:
        return model, hyperparam, (input_size, output_size)
    elif if_hyperparam:
        return model, hyperparam
    elif if_iosize:
        return model, (input_size, output_size)
    else:
        return model

def eval_model(filename):
    '''Evaluate model skill on the testing data.'''
    logging.info('################################################')
    logging.info('## eval_model ')
    logging.info('################################################')
    model, hyperparam = read_model(filename, True)  # get model
    test_Loader = get_test_dataset(hyperparam, 4)  # get testing data loader
    # name the truth and prediction files
    name = filename.split('/')[-1]
    y_pred_file = main_dir + '/npys/ypred_' + name + '.npy'  # prediction from model
    y_file = main_dir + '/npys/y_' + name + '.npy'  # truth
    logging.info(y_pred_file)
    logging.info(y_file)
    t0 = time()
    # run the model through the testing dataset in evaluation mode
    with torch.set_grad_enabled(False):
        model.eval()
        for X, y in test_Loader:
            y_pred = model(X)
    logging.info('took {}s'.format(time() - t0))
    np.save(y_pred_file, y_pred)
    np.save(y_file, y)

def sub_eval_model(filename, if_get_norm=False, if_renew=True, if_wait=True):
    '''Submit an eval_model job to slurm.'''
    name = filename.split('/')[-1]
    y_pred_file = main_dir + '/npys/ypred_' + name + '.npy'
    y_file = main_dir + '/npys/y_' + name + '.npy'
    if_file_exist = os.path.isfile(y_pred_file) and os.path.isfile(y_file)
    if_submit = not if_file_exist or if_renew  # submit a job (first time or renewal) or read previous results
    if if_submit:
        logging.info(filename)
        logging.info('model eval 1st time. will take longer.')
        # write a tmp.py file for job submission
        os.system("echo from check_model import eval_model > {}/tmp.py".format(main_dir))
        os.system('''echo "eval_model(\'{}\')" >> {}/tmp.py'''.format(filename, main_dir))
        # put together the submit command
        prefix = 'sbatch'
        if if_wait:  # block the current process until the job completes
            prefix = 'sbatch --wait'
        submitline = prefix + f' -t 30:0 -A gsienkf -p bigmem -N 1 --output {main_dir}/eval_model.out --wrap "{python_exe} -u {main_dir}/tmp.py " '
        os.system(submitline)
    else:
        logging.info('model evaled before. reading from file.')
    if_output = if_wait or not if_submit  # results exist if we waited for the job or are reading previous results
    if if_output:
        y_pred = np.load(y_pred_file)
        y = np.load(y_file)
        logging.info('finished')
        logging.info('Learned percentage: {}'.format(1 - np.mean((y_pred - y)**2) / np.mean(y**2)))
        logging.info(f'R2: {1 - np.mean((y_pred - y)**2) / np.mean((y - y.mean())**2)}')
        logging.info(f'MSE: {np.mean((y_pred - y)**2)}')
    if if_get_norm and if_output:
        mean_out, std_out = get_norm(filename)
        return y_pred, y, mean_out, std_out
    elif if_output:
        return y_pred, y
    else:
        logging.info('not waiting for sbatch. keep going')

def collect_models():
    '''Collect the hyperparameters and training status of all trained models into a dataframe.'''
    logging.info('################################################')
    logging.info('## collect_models ')
    logging.info('################################################')
    t0 = time()
    checks = sorted(glob(f'{main_dir}/checks/conv*'))
    dicts = Parallel(n_jobs=10, verbose=10)(delayed(read_checkfile)(filename) for filename in checks)  # read the checkpoint files in parallel
    dicts = [i for i in dicts if i]  # get rid of None entries (unreadable checkpoints)
    df = pd.DataFrame(dicts)  # convert to dataframe
    df.to_pickle(f'{main_dir}/checks/df_low-res-config')  # save the dataframe to file
    logging.info('took {} to finish building dataframe'.format(time() - t0))

def sub_collect_models(if_renew=True, df='df_low-res-config', if_wait=True):
    '''Submit collect_models to slurm and return the resulting dataframe.'''
    if if_renew:
        prefix = 'sbatch'
        if if_wait:
            prefix = 'sbatch --wait'
        submitline = prefix + f''' -t 120:0 -A gsienkf -p hera -N 1 --output {main_dir}/collect_models.out --wrap "{python_exe} -u -c 'from check_model import collect_models
collect_models()' " '''
        os.system(submitline)
        if not if_wait:
            logging.info('not waiting for sbatch. keep going')
            return None  # dataframe not ready yet
    else:
        logging.info('reading df from file.')
    df = pd.read_pickle('{}/checks/{}'.format(main_dir, df))
    return df

def saliency(filename):
    '''Compute input gradients averaged over the training dataset and save to file.'''
    name = filename.split('/')[-1]  # derive saliency filename from checkpoint filename
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # use gpu if available; cpu is significantly slower
    model, hyperparam, iosize = read_model(filename, if_hyperparam=True, if_iosize=True)  # get trained model
    model.to(device)  # put model on device (gpu or cpu)
    model.eval()  # set model to evaluation mode
    hyperparam['bs'] = 1  # avoid a mismatched batch size at the end
    train_valid_Loader = get_train_dataset(hyperparam, num_workers=8)  # set up training data loader
    J_tmp = np.zeros(iosize[::-1])  # shape (output_size, input_size)
    J = np.zeros(iosize[::-1])
    for j, (x, y) in enumerate(train_valid_Loader):
        logging.info(j)
        x = x.to(device)
        x.requires_grad = True
        y = model(x)
        # pytorch backward was designed for scalar outputs,
        # so it doesn't provide the full Jacobian matrix, only the product vector^T J;
        # loop through the output dimension to assemble the full matrix
        for i in range(iosize[1]):  # loop over output channels (originally hardcoded as range(127))
            ext = torch.zeros((hyperparam['bs'], iosize[1], 32, 64), device=device)  # hardcoded 32x64 spatial grid
            ext[:, i] = 1.
            y.backward(gradient=ext, retain_graph=True)
            J_tmp[i, :] = x.grad.data.cpu().mean(axis=(0, 2, 3))  # average over batch and grid dimensions
            _ = x.grad.data.zero_()
        J += J_tmp
    J /= (j + 1)  # average over batches (enumerate starts at 0)
    np.save('{}/npys/J_{}'.format(main_dir, name), J)

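# Standalone sketch of the vector-Jacobian trick used above (kept as a comment
# so it does not run on import): for y = model(x), calling
# y.backward(gradient=e_i) accumulates e_i^T J into x.grad, i.e. row i of the
# Jacobian dy/dx. For a linear map the rows are exactly the weight rows:
#
#   x = torch.randn(3, requires_grad=True)
#   W = torch.randn(2, 3)
#   y = W @ x                       # Jacobian dy/dx is W
#   e0 = torch.zeros(2); e0[0] = 1.
#   y.backward(gradient=e0)         # x.grad now holds e0^T W = W[0]
#   assert torch.allclose(x.grad, W[0])
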
def sub_saliency(filename, if_renew=False):
    '''Submit saliency to slurm.'''
    name = filename.split('/')[-1]
    J_file = '{}/npys/J_{}.npy'.format(main_dir, name)
    if not os.path.isfile(J_file) or if_renew:
        logging.info('saliency eval 1st time. will take longer.')
        os.system(f"echo from check_model import saliency > {main_dir}/tmp.py")
        os.system('''echo "saliency(\'{}\')" >> {}/tmp.py'''.format(filename, main_dir))
        submitline = f'sbatch --wait -t 30:0:0 -A rda-ddbcufs -p fge -N 1 --output {main_dir}/eval_saliency.out --wrap "{python_exe} -u {main_dir}/tmp.py " '
        os.system(submitline)
    else:
        logging.info('saliency evaled before. reading from file.')
    J = np.load(J_file)
    logging.info('finished')
    return J

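# Usage sketch (hypothetical workflow; assumes the dataframe and cached
# prediction files already exist under main_dir):
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    df = sub_collect_models(if_renew=False)                       # load model summary dataframe
    best = df.sort_values('valid_min').iloc[0]                    # best model by minimum validation loss
    y_pred, y = sub_eval_model(best['filename'], if_renew=False)  # read cached predictions and truth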