diff --git a/libs/visualization/pil_utils.py b/libs/visualization/pil_utils.py index e7665af..b5a0c0b 100644 --- a/libs/visualization/pil_utils.py +++ b/libs/visualization/pil_utils.py @@ -50,8 +50,7 @@ def draw_bbox(step, image, name='', image_height=1, image_width=1, bbox=None, la return source_img.save(FLAGS.train_dir + '/est_imgs/test_' + name + '_' + str(step) +'.jpg', 'JPEG') -def cat_id_to_cls_name(catId): - cls_name = np.array([ 'background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', +_CLS_NAMES = np.array(['background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', @@ -65,4 +64,7 @@ def cat_id_to_cls_name(catId): 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']) - return cls_name[catId] \ No newline at end of file + +def cat_id_to_cls_name(catId): + global _CLS_NAMES + return _CLS_NAMES[catId]