-
Notifications
You must be signed in to change notification settings - Fork 43
/
utils.py
92 lines (69 loc) · 2.37 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
'''
@Author: Zhou Kai
@GitHub: https://github.com/athon2
@Date: 2018-11-30 09:53:44
'''
import pickle
import torch
import tensorboardX
def pickle_load(in_file):
    """Deserialize and return the object stored in the pickle file ``in_file``.

    Security note: ``pickle.load`` can execute arbitrary code while
    unpickling, so only call this on files from a trusted source.
    """
    with open(in_file, "rb") as handle:
        obj = pickle.load(handle)
    return obj
class AverageMeter(object):
    """Track the most recent value and a running, count-weighted mean.

    Each ``update(val, n)`` call treats ``val`` as the value of ``n``
    observations: ``sum`` grows by ``val * n``, ``count`` by ``n``, and
    ``avg`` is recomputed as ``sum / count``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set ``val``, ``avg``, ``sum`` and ``count`` back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` observations."""
        self.val = val
        weighted = val * n
        self.sum += weighted
        self.count += n
        # Raises ZeroDivisionError if the accumulated count is still zero.
        self.avg = self.sum / self.count
class Logger(object):
    """Write per-phase scalar metrics to a TensorBoard run directory.

    The run directory is ``./runs/<name>``, where ``<name>`` is the last
    ``/``-separated component of ``model_name`` with everything from the
    first ``.h5`` onward removed.
    """

    def __init__(self, model_name, header):
        """
        Args:
            model_name: model path/name used to derive the run directory.
            header: sequence of column names; ``log`` writes every column
                except ``header[0]`` (presumably the epoch column — confirm
                against callers).
        """
        self.header = header
        self.writer = tensorboardX.SummaryWriter(
            "./runs/" + model_name.split("/")[-1].split(".h5")[0])

    def close(self):
        """Close the underlying writer. Safe to call more than once.

        Also tolerates an instance whose ``__init__`` failed before
        ``writer`` was assigned. ``log`` must not be called after this.
        """
        writer = getattr(self, "writer", None)
        if writer is not None:
            writer.close()
            self.writer = None

    def __del__(self):
        # Spelled with both leading and trailing double underscores so the
        # interpreter actually invokes it at finalization; a method named
        # ``__del`` is just a name-mangled private method that nothing calls.
        self.close()

    def log(self, phase, values):
        """Write one scalar per column in ``header[1:]``.

        Args:
            phase: tag prefix, e.g. ``"train"``; tags are ``"<phase>/<col>"``.
            values: mapping containing ``'epoch'`` (converted with ``int``,
                used as the global step) and every ``header[1:]`` column
                (each converted with ``float``).
        """
        epoch = values['epoch']
        for col in self.header[1:]:
            self.writer.add_scalar(phase + "/" + col, float(values[col]), int(epoch))
def load_value_file(file_path):
    """Read a text file and return its contents parsed as a single float.

    Trailing newline / carriage-return characters are stripped before
    parsing; ``float`` raises ``ValueError`` if the rest is not a number.
    """
    with open(file_path, 'r') as handle:
        raw = handle.read()
        parsed = float(raw.rstrip('\n\r'))
    return parsed
def calculate_accuracy(outputs, targets):
    """Return the accuracy metric used for this task: the Dice coefficient.

    Thin alias for ``dice_coefficient(outputs, targets)`` using that
    function's default ``threshold`` and ``eps``.
    """
    return dice_coefficient(outputs, targets)
def dice_coefficient(outputs, targets, threshold=0.5, eps=1e-8):
    """Return the batch-mean Dice coefficient of thresholded predictions.

    Only channel 0 of ``outputs`` and ``targets`` is scored. Dice is computed
    independently for each sample and the per-sample scores are averaged.

    Args:
        outputs: prediction tensor indexed as ``(batch, channel, ...)``;
            channel 0 is binarized with ``> threshold`` (presumably
            probabilities — confirm against the model head).
        targets: ground-truth tensor with the same layout; channel 0 is used
            as-is (assumed to be binary {0, 1} — TODO confirm).
        threshold: cut-off applied to ``outputs``.
        eps: smoothing term. ``eps / 2`` is added to the intersection and
            ``eps`` to the denominator, so an empty prediction scored against
            an empty target yields 1.

    Returns:
        0-dim tensor holding the mean per-sample Dice score.

    Raises:
        ValueError: if the batch is empty.
    """
    batch_size = targets.size(0)
    if batch_size == 0:
        raise ValueError("dice_coefficient: received an empty batch")
    # ``.float()`` keeps the tensor on its current device; converting via
    # ``.type(torch.FloatTensor)`` forced a CPU copy and broke CUDA inputs.
    y_pred = (outputs[:, 0] > threshold).float()
    y_truth = targets[:, 0]
    # Flatten everything except the batch axis so Dice is computed per sample.
    y_pred = y_pred.reshape(batch_size, -1)
    y_truth = y_truth.reshape(batch_size, -1)
    intersection = torch.sum(y_pred * y_truth, dim=1) + eps / 2
    union = torch.sum(y_pred, dim=1) + torch.sum(y_truth, dim=1) + eps
    dice = 2 * intersection / union
    # Average the per-sample scores. Dividing one batch-pooled score by
    # batch_size would understate it by a factor of batch_size.
    return dice.mean()
def load_old_model(model, optimizer, saved_model_path, map_location=None):
    """Restore model and optimizer state from a checkpoint file.

    The checkpoint must be a mapping with keys ``"epoch"``, ``"state_dict"``
    (model weights) and ``"optimizer"`` (optimizer state).

    Args:
        model: module whose ``load_state_dict`` receives ``checkpoint["state_dict"]``.
        optimizer: optimizer whose ``load_state_dict`` receives
            ``checkpoint["optimizer"]``.
        saved_model_path: path to the file written by ``torch.save``.
        map_location: forwarded to ``torch.load``; e.g. ``"cpu"`` lets a
            checkpoint saved on GPU be loaded on a CPU-only machine. The
            default ``None`` keeps ``torch.load``'s default behavior.

    Returns:
        ``(model, epoch, optimizer)``. The same model and optimizer objects
        that were passed in are returned, now updated in place.
    """
    print("Constructing model from saved file... ")
    checkpoint = torch.load(saved_model_path, map_location=map_location)
    epoch = checkpoint["epoch"]
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    return model, epoch, optimizer
def normalize_data(data, mean, std):
    """Placeholder for data normalization; currently a no-op.

    Not implemented: all arguments are ignored and ``None`` is returned.
    TODO: implement or remove.
    """
    return None