-
Notifications
You must be signed in to change notification settings - Fork 143
/
eval_ggcnn.py
111 lines (87 loc) · 4.85 KB
/
eval_ggcnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import argparse
import logging
import torch.utils.data
from models.common import post_process_output
from utils.dataset_processing import evaluation, grasp
from utils.data import get_dataset
# Emit INFO-level log messages (dataset loading, per-image progress, results).
logging.basicConfig(level=logging.INFO)


def parse_args(argv=None):
    """Parse and validate the command-line options for GG-CNN evaluation.

    :param argv: Optional list of argument strings to parse. ``None`` (the
        default) parses ``sys.argv[1:]``, i.e. the normal command line.
    :return: ``argparse.Namespace`` holding the parsed options.
    :raises ValueError: if ``--jacquard-output`` is combined with a dataset
        other than ``jacquard`` or with ``--augment``.
    :raises SystemExit: raised by argparse for unknown options or when a
        required option (``--network``, ``--dataset``) is missing.
    """
    parser = argparse.ArgumentParser(description='Evaluate GG-CNN')

    # Network
    # Required: the evaluation script always loads this path.
    parser.add_argument('--network', type=str, required=True,
                        help='Path to saved network to evaluate')

    # Dataset & Data & Training
    # Required: the evaluation script always looks the dataset up by name.
    parser.add_argument('--dataset', type=str, required=True,
                        help='Dataset Name ("cornell" or "jacquard")')
    parser.add_argument('--dataset-path', type=str, help='Path to dataset')
    parser.add_argument('--use-depth', type=int, default=1,
                        help='Use Depth image for evaluation (1/0)')
    parser.add_argument('--use-rgb', type=int, default=0,
                        help='Use RGB image for evaluation (0/1)')
    parser.add_argument('--augment', action='store_true',
                        help='Whether data augmentation should be applied')
    parser.add_argument('--split', type=float, default=0.9,
                        help='Fraction of data for training (remainder is validation)')
    parser.add_argument('--ds-rotate', type=float, default=0.0,
                        help='Shift the start point of the dataset to use a different test/train split')
    parser.add_argument('--num-workers', type=int, default=8, help='Dataset workers')
    parser.add_argument('--n-grasps', type=int, default=1,
                        help='Number of grasps to consider per image')
    parser.add_argument('--iou-eval', action='store_true',
                        help='Compute success based on IoU metric.')
    parser.add_argument('--jacquard-output', action='store_true',
                        help='Jacquard-dataset style output')
    parser.add_argument('--vis', action='store_true', help='Visualise the network output')

    args = parser.parse_args(argv)

    # Jacquard-style output is only supported for the Jacquard dataset and
    # without augmentation; reject incompatible combinations up front.
    if args.jacquard_output and args.dataset != 'jacquard':
        raise ValueError('--jacquard-output can only be used with the --dataset jacquard option.')
    if args.jacquard_output and args.augment:
        raise ValueError('--jacquard-output can not be used with data augmentation.')

    return args
if __name__ == '__main__':
    # Evaluate a saved GG-CNN network on the test split of a grasping dataset,
    # optionally scoring IoU matches, writing Jacquard-style output, and
    # visualising predictions.
    import inspect

    args = parse_args()

    # Prefer the first GPU, but fall back to CPU so evaluation still runs on
    # machines without CUDA (this used to be hard-coded to cuda:0).
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Load Network
    # The file is a whole pickled model (the script calls net.compute_loss on
    # it), not a state_dict. map_location puts its tensors on `device`,
    # whichever device the model was saved from, so they match the inputs.
    # Newer torch versions default to weights_only=True, which refuses to
    # unpickle a full model. Request full unpickling only when the argument
    # exists, because older torch versions do not accept it.
    # NOTE(review): full unpickling can execute arbitrary code; only load
    # network files from trusted sources.
    load_kwargs = {'map_location': device}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    net = torch.load(args.network, **load_kwargs)
    # Inference mode: affects dropout / batch-norm layers, if the model has any.
    net.eval()

    # Load Dataset
    logging.info('Loading {} Dataset...'.format(args.dataset.title()))
    Dataset = get_dataset(args.dataset)
    # The test split is the fraction [split, 1.0] of the (optionally rotated) dataset.
    test_dataset = Dataset(args.dataset_path, start=args.split, end=1.0, ds_rotate=args.ds_rotate,
                           random_rotate=args.augment, random_zoom=args.augment,
                           include_depth=args.use_depth, include_rgb=args.use_rgb)
    test_data = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=args.num_workers
    )
    logging.info('Done')

    results = {'correct': 0, 'failed': 0}

    if args.jacquard_output:
        jo_fn = args.network + '_jacquard_output.txt'
        # Truncate any previous output; grasps are appended per image below.
        with open(jo_fn, 'w') as f:
            pass

    n_batches = len(test_data)
    with torch.no_grad():
        for idx, (x, y, didx, rot, zoom) in enumerate(test_data):
            logging.info('Processing {}/{}'.format(idx + 1, n_batches))
            xc = x.to(device)
            yc = [yi.to(device) for yi in y]
            lossd = net.compute_loss(xc, yc)

            q_img, ang_img, width_img = post_process_output(lossd['pred']['pos'], lossd['pred']['cos'],
                                                            lossd['pred']['sin'], lossd['pred']['width'])

            if args.iou_eval:
                s = evaluation.calculate_iou_match(q_img, ang_img, test_data.dataset.get_gtbb(didx, rot, zoom),
                                                   no_grasps=args.n_grasps,
                                                   grasp_width=width_img,
                                                   )
                if s:
                    results['correct'] += 1
                else:
                    results['failed'] += 1

            if args.jacquard_output:
                grasps = grasp.detect_grasps(q_img, ang_img, width_img=width_img, no_grasps=1)
                with open(jo_fn, 'a') as f:
                    for g in grasps:
                        f.write(test_data.dataset.get_jname(didx) + '\n')
                        # NOTE(review): 1024 / 300 presumably rescales from the
                        # network's 300 px output to 1024 px Jacquard images; confirm.
                        f.write(g.to_jacquard(scale=1024 / 300) + '\n')

            if args.vis:
                evaluation.plot_output(test_data.dataset.get_rgb(didx, rot, zoom, normalise=False),
                                       test_data.dataset.get_depth(didx, rot, zoom), q_img,
                                       ang_img, no_grasps=args.n_grasps, grasp_width_img=width_img)

    if args.iou_eval:
        total = results['correct'] + results['failed']
        if total:
            logging.info('IOU Results: %d/%d = %f', results['correct'], total, results['correct'] / total)
        else:
            # An empty test split would otherwise raise ZeroDivisionError here.
            logging.warning('IOU Results: no images were evaluated.')

    if args.jacquard_output:
        logging.info('Jacquard output saved to {}'.format(jo_fn))