-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
110 lines (92 loc) · 4.28 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import argparse
import datetime
import json
import os
import warnings
from pathlib import Path
import torch
from torch.utils.data import DataLoader
from lib.dataloader import TestDataset
from lib.model import configure_model
from lib.utils import get_transform
from lib.visual import Visual
warnings.filterwarnings("ignore", category=UserWarning)


def parse_args(argv=None):
    """Build the command-line interface and parse ``argv``.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the raw command-line options. ``img_size`` is
        ``None`` unless ``--img-size`` is given explicitly, in which case it
        overrides the value stored in the run's CONFIG.json.
    """
    parser = argparse.ArgumentParser(description='Evaluation of Faster R-CNN.')
    # Required: every later step derives paths from the checkpoint location.
    parser.add_argument('--model-checkpoint', type=str, required=True,
                        help='Path to pretrained model.')
    parser.add_argument('--dataset', default='../data-faster',
                        type=str, help='Path to dataset.')
    parser.add_argument('--output', default='tests',
                        type=str, help='Path to output the test results (default: tests).')
    parser.add_argument('--batch-size', default=4, type=int,
                        help='Batch size (default: 4).')
    parser.add_argument('--img-size', default=None, type=int,
                        help='Image size (default: value stored in the run CONFIG.json).')
    parser.add_argument('--num-workers', default=8, type=int, metavar='N',
                        help='Number of data loading workers (default: 8).')
    parser.add_argument('--conf-threshold', default=0.1, type=float,
                        help='Confidence threshold in prediction (default: 0.1).')
    # `--no-visual` has always defaulted to True, which made the flag a no-op
    # and left no way to turn visualization on. Keep that default and add an
    # explicit `--visual` switch writing the same destination.
    parser.add_argument('--no-visual', default=True, action='store_true',
                        help='Disable visualization software in test mode (default).')
    parser.add_argument('--visual', dest='no_visual', action='store_false',
                        help='Enable visualization software in test mode.')
    parser.add_argument('--no-save', default=False, action='store_true',
                        help='Disable results export software.')
    return parser.parse_args(argv)


def run_dir_of(checkpoint):
    """Return the directory two levels above ``checkpoint`` as an absolute path.

    This script reads CONFIG.json from, and writes test results under, that
    directory.
    """
    return Path(checkpoint).absolute().parent.parent


def apply_run_config(args, config):
    """Copy model/dataset settings from a parsed CONFIG.json dict onto ``args``.

    Sets ``backbone``, ``anchor_sizes``, ``aspect_ratios``, ``epochs`` and
    ``num_classes``. ``img_size`` is taken from the config only when the user
    did not pass ``--img-size`` (``args.img_size is None``).

    Args:
        args: Namespace updated in place.
        config: Dict with ``'model'`` and ``'dataset'`` sections.

    Returns:
        The same ``args`` object.

    Raises:
        KeyError: If a required config entry is missing.
    """
    model_cfg = config['model']
    dataset_cfg = config['dataset']
    args.backbone = model_cfg['backbone']
    # Anchor sizes / ratios may be stored as strings; cast to float before
    # they are handed to the model factory.
    args.anchor_sizes = [float(anchor) for anchor in model_cfg['anchors']]
    args.aspect_ratios = [float(ratio) for ratio in model_cfg['ratios']]
    args.epochs = model_cfg['epochs']
    args.num_classes = dataset_cfg['classes']
    if args.img_size is None:
        args.img_size = dataset_cfg['img_size']
    return args


def main(argv=None):
    """Evaluate a trained Faster R-CNN checkpoint on ``<dataset>/test``.

    Validates the input paths, restores the model configuration from the run's
    CONFIG.json, loads the checkpoint weights and passes a test DataLoader to
    ``Visual.test_model``.

    Args:
        argv: Optional argument list; ``None`` parses ``sys.argv[1:]``.

    Raises:
        ValueError: If the checkpoint file or the ``test`` split directory
            does not exist.
    """
    args = parse_args(argv)
    checkpoint_path = Path(args.model_checkpoint)
    test_dir = Path(args.dataset).joinpath("test")
    if not checkpoint_path.is_file():
        raise ValueError(
            f"Path to pretrained model weights is invalid. Value parsed {args.model_checkpoint}.")
    if not test_dir.is_dir():
        raise ValueError(
            f"Path to dataset is invalid. Value parsed {test_dir}.")
    run_dir = run_dir_of(args.model_checkpoint)

    # initialize the computation device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device utilized:\t[{device}]\n")

    with open(run_dir.joinpath('CONFIG.json'), encoding='utf-8') as f:
        apply_run_config(args, json.load(f))

    # model init
    model = configure_model(
        backbone_name=args.backbone,
        anchor_sizes=args.anchor_sizes,
        aspect_ratios=args.aspect_ratios,
        num_classes=args.num_classes,
    )

    # test dataset / dataloader
    test_data = TestDataset(
        root_dir=os.path.join(args.dataset, "test"),
        transforms=get_transform(transform_class="test", img_size=args.img_size),
    )
    dataloader_test = DataLoader(
        dataset=test_data,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        # Pinned host memory only benefits host-to-GPU copies.
        pin_memory=device.type == 'cuda',
    )

    # load model to device and restore weights
    model = model.to(device)
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(args.model_checkpoint, map_location=device)
    model.load_state_dict(checkpoint['model'])
    print(f"Loaded model checkpoint at {args.model_checkpoint} successfully")

    # Create the dated results folder only after everything above succeeded,
    # so failed runs do not leave empty directories behind.
    test_results_dir = None
    if not args.no_save:
        datetime_tag = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
        test_results_dir = run_dir / Path(args.output) / datetime_tag
        test_results_dir.mkdir(parents=True, exist_ok=True)

    # test model
    visualize = Visual(model=model, root_dir=args.dataset,
                       device=device, conf_threshold=args.conf_threshold)
    visualize.test_model(dataloader=dataloader_test, results_dir=test_results_dir,
                         no_visual=args.no_visual, no_save=args.no_save)


if __name__ == '__main__':
    main()