-
Notifications
You must be signed in to change notification settings - Fork 1
/
eval.py
64 lines (47 loc) · 1.51 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import argparse
import os
import torch
from torch import nn
from config_reader.parser import ConfigParser
from model.get_model import get_model
from reader.reader import init_dataset
from model.work import valid_net
from utils.util import print_info
# Evaluate a saved checkpoint on the validation split.
#
# Usage: python eval.py --config CONFIG --model CHECKPOINT [--gpu IDS]
#   --config/-c  config file path, handed to ConfigParser
#   --model/-m   checkpoint path, loaded with torch.load and applied via
#                net.load_state_dict
#   --gpu/-g     comma-separated GPU ids, exported as CUDA_VISIBLE_DEVICES;
#                omit to run on CPU
parser = argparse.ArgumentParser()
parser.add_argument('--config', '-c')
parser.add_argument('--gpu', '-g')
parser.add_argument('--model', '-m')
args = parser.parse_args()

configFilePath = args.config
if configFilePath is None:
    print("python *.py\t--config/-c\tconfigfile")
    # Stop here: continuing would only fail later inside ConfigParser(None).
    raise SystemExit(1)
if args.model is None:
    print("python *.py\t--model/-m\tmodelfile")
    # Fail fast instead of building the whole net and then crashing in torch.load(None).
    raise SystemExit(1)

use_gpu = args.gpu is not None
if use_gpu:
    # Set before any CUDA use so that only the requested GPUs are visible.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    if not torch.cuda.is_available():
        # Without this, valid_net would be told to use the GPU while the net stays on CPU.
        print_info("CUDA is not available, falling back to CPU")
        use_gpu = False

config = ConfigParser(configFilePath)

print_info("Start to build Net")
model_name = config.get("model", "name")
net = get_model(model_name, config)

device = []
if use_gpu:
    # CUDA_VISIBLE_DEVICES renumbers the selected GPUs from 0, so the device
    # list is the indices 0..k-1, not the raw ids given on the command line.
    device = list(range(len(args.gpu.split(","))))
    net = net.cuda()
    try:
        net.init_multi_gpu(device)
    except Exception as e:
        # Best-effort: any error from init_multi_gpu is logged and evaluation
        # continues with the net as-is (presumably for models without
        # multi-GPU support — TODO confirm).
        print_info(str(e))

# Deserialize onto CPU first: load_state_dict copies each tensor into the
# existing parameter's device, and this also lets a GPU-saved checkpoint be
# evaluated on a CPU-only machine.
net.load_state_dict(torch.load(args.model, map_location="cpu"))
print_info("Net build done")

print_info("Start to prepare Data")
train_dataset, valid_dataset = init_dataset(config)
print_info("Data preparation Done")

try:
    # The trailing 0 is presumably an epoch/step index — TODO confirm against
    # model.work.valid_net.
    valid_net(net, valid_dataset, use_gpu, config, 0)
finally:
    # Terminate the dataset reader processes even if validation raises, so
    # they are not left running behind a failed evaluation.
    for dataset in (train_dataset, valid_dataset):
        for process in dataset.read_process:
            process.terminate()