Skip to content

Commit

Permalink
Merge pull request #24 from bowenc0221/boundary_ap
Browse files Browse the repository at this point in the history
Support Boundary AP evaluation
  • Loading branch information
bowenc0221 authored Jun 18, 2021
2 parents 35f09cd + 31ace4a commit 98b5f79
Show file tree
Hide file tree
Showing 5 changed files with 182 additions and 15 deletions.
29 changes: 29 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ When complete, it will feature more than 2 million high-quality instance segment

<img src="images/examples.png"/>

## LVIS challenge 2021
For this release, we replace the old COCO-style Mask AP with the combination of two new metrics: [Boundary AP](https://arxiv.org/abs/2103.16562) and [Fixed AP](https://arxiv.org/abs/2102.01066). The combined metric will be used in the LVIS Challenge to be held at the LVIS Workshop at ICCV 2021.

## LVIS v1.0

For this release, we have annotated 159,623 images (100k train, 20k val, 20k test-dev, 20k test-challenge). Release v1.0 is publicly available at [LVIS website](http://www.lvisdataset.org) and will be used in the second LVIS Challenge to be held at Joint COCO and LVIS Workshop at ECCV 2020.
Expand All @@ -19,6 +22,8 @@ source env/bin/activate # Activate virtual environment

# install COCO API. COCO API requires numpy to install. Ensure that you installed numpy.
pip install 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'
# install OpenCV (optional if you already have OpenCV installed)
pip install -U opencv-python
# install LVIS API
pip install lvis
# Work for a while ...
Expand All @@ -32,6 +37,8 @@ source env/bin/activate # Activate virtual environment

# install COCO API. COCO API requires numpy to install. Ensure that you installed numpy.
pip install 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'
# install OpenCV (optional if you already have OpenCV installed)
pip install -U opencv-python
# install LVIS API
pip install .
# test if the installation was correct
Expand All @@ -50,6 +57,28 @@ If you find this code/data useful in your research then please cite our [paper](
year={2019}
}
```
## Citing Boundary AP

If you find the Boundary AP metric useful in your research then please cite our [paper](https://arxiv.org/abs/2103.16562):
```BibTeX
@inproceedings{cheng2021boundary,
title={Boundary {IoU}: Improving Object-Centric Image Segmentation Evaluation},
author={Bowen Cheng and Ross Girshick and Piotr Doll{\'a}r and Alexander C. Berg and Alexander Kirillov},
booktitle={CVPR},
year={2021}
}
```
## Citing Fixed AP

If you find the Fixed AP metric useful in your research then please cite our [paper](https://arxiv.org/abs/2102.01066):
```BibTeX
@article{dave2021evaluating,
title={Evaluating Large-Vocabulary Object Detectors: The Devil is in the Details},
author={Dave, Achal and Doll{\'a}r, Piotr and Ramanan, Deva and Kirillov, Alexander and Girshick, Ross},
journal={arXiv preprint arXiv:2102.01066},
year={2021}
}
```

## Credit

Expand Down
68 changes: 68 additions & 0 deletions lvis/boundary_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import cv2
import logging
import multiprocessing
import numpy as np

import pycocotools.mask as mask_utils

logger = logging.getLogger(__name__)


# General util function to get the boundary of a binary mask.
def mask_to_boundary(mask, dilation_ratio=0.02):
    """
    Extract the boundary region of a binary mask.

    The mask is eroded with a 3x3 kernel and the eroded interior is subtracted
    from the input, leaving only the pixels close to the mask contour.
    :param mask (numpy array, uint8): binary mask with two dimensions (h, w)
    :param dilation_ratio (float): ratio to calculate dilation = dilation_ratio * image_diagonal
    :return: boundary mask (numpy array) with the same shape as `mask`
    """
    height, width = mask.shape
    diagonal = np.sqrt(height ** 2 + width ** 2)
    # Erosion depth in pixels, never less than one.
    num_iterations = max(1, int(round(dilation_ratio * diagonal)))
    # Surround the mask with a one-pixel ring of zeros so that a mask cut off
    # by the image border is eroded there as well and counts as boundary.
    padded = cv2.copyMakeBorder(mask, 1, 1, 1, 1, cv2.BORDER_CONSTANT, value=0)
    structuring_element = np.ones((3, 3), dtype=np.uint8)
    padded_interior = cv2.erode(padded, structuring_element, iterations=num_iterations)
    # Drop the padding ring again to get back to the input shape.
    interior = padded_interior[1:height + 1, 1:width + 1]
    # G_d intersects G in the paper.
    return mask - interior


# COCO/LVIS related util functions, to get the boundary for every annotations.
def augment_annotations_with_boundary_single_core(proc_id, annotations, ann_to_mask, dilation_ratio=0.02):
    """
    Add an RLE-encoded `boundary` entry to every annotation, sequentially.

    Each annotation dict is modified in place and also collected into the returned list.
    :param proc_id: identifier of the worker handling this chunk (not used in the body)
    :param annotations (iterable of dict): annotations to augment
    :param ann_to_mask (callable): maps one annotation to its binary mask
    :param dilation_ratio (float): ratio to calculate dilation = dilation_ratio * image_diagonal
    :return: list of the augmented annotations, in input order
    """
    augmented = []

    for annotation in annotations:
        # Rasterize the annotation and keep only its boundary region.
        boundary_mask = mask_to_boundary(ann_to_mask(annotation), dilation_ratio)
        # Add a trailing axis and lay the array out in Fortran order as uint8
        # before handing it to the RLE encoder; the first result is the boundary.
        encoder_input = np.array(boundary_mask[:, :, None], order="F", dtype="uint8")
        annotation['boundary'] = mask_utils.encode(encoder_input)[0]
        augmented.append(annotation)

    return augmented


def augment_annotations_with_boundary_multi_core(annotations, ann_to_mask, dilation_ratio=0.02):
    """
    Add a `boundary` entry to every annotation using a pool of worker processes.

    The annotations are split into `multiprocessing.cpu_count()` chunks and each
    chunk is processed by `augment_annotations_with_boundary_single_core`.
    :param annotations (sequence of dict): annotations to augment
    :param ann_to_mask (callable): maps one annotation to its binary mask
    :param dilation_ratio (float): ratio to calculate dilation = dilation_ratio * image_diagonal
    :return: list of the augmented annotations, concatenated in chunk order
    """
    cpu_num = multiprocessing.cpu_count()
    annotations_split = np.array_split(annotations, cpu_num)
    logger.info("Number of cores: {}, annotations per core: {}".format(cpu_num, len(annotations_split[0])))

    new_annotations = []
    # The context manager terminates the pool on exit, so the worker processes
    # are reclaimed even when a task raises while its result is being collected.
    with multiprocessing.Pool(processes=cpu_num) as workers:
        pending = [
            workers.apply_async(
                augment_annotations_with_boundary_single_core,
                (proc_id, annotation_set, ann_to_mask, dilation_ratio),
            )
            for proc_id, annotation_set in enumerate(annotations_split)
        ]

        # Collect in submission order so the output follows the chunk order.
        for result in pending:
            new_annotations.extend(result.get())

        # Success path: shut the pool down gracefully before leaving the block.
        workers.close()
        workers.join()

    return new_annotations
66 changes: 56 additions & 10 deletions lvis/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,22 @@


class LVISEval:
def __init__(self, lvis_gt, lvis_dt, iou_type="segm"):
def __init__(self, lvis_gt, lvis_dt, iou_type="segm", dilation_ratio=0.02):
"""Constructor for LVISEval.
Args:
lvis_gt (LVIS class instance, or str containing path of annotation file)
lvis_dt (LVISResult class instance, or str containing path of result file,
or list of dict)
iou_type (str): segm or bbox evaluation
dilation_ratio (float): ratio to calculate dilation = dilation_ratio * image_diagonal
"""
self.logger = logging.getLogger(__name__)

if iou_type not in ["bbox", "segm"]:
if iou_type not in ["bbox", "segm", "boundary"]:
raise ValueError("iou_type: {} is not supported.".format(iou_type))

self.use_boundary_iou = iou_type == "boundary"

if isinstance(lvis_gt, LVIS):
self.lvis_gt = lvis_gt
elif isinstance(lvis_gt, str):
Expand All @@ -39,6 +42,21 @@ def __init__(self, lvis_gt, lvis_dt, iou_type="segm"):
else:
raise TypeError("Unsupported type {} of lvis_dt.".format(lvis_dt))

# Precompute boundary.
if self.use_boundary_iou:
    # Ground truth and detections get the same treatment. Handling both in
    # one loop makes sure each object is validated against its OWN
    # dilation ratio (the detections were previously checked against the
    # ground-truth ratio by mistake).
    for lvis_obj in (self.lvis_gt, self.lvis_dt):
        if not lvis_obj.precompute_boundary:
            # Boundaries are not available yet: enable them and rebuild the
            # index, which adds the `boundary` field to the annotations.
            lvis_obj.precompute_boundary = self.use_boundary_iou
            lvis_obj.dilation_ratio = dilation_ratio
            lvis_obj._create_index()
        else:
            # Boundaries were already computed; they must have used the
            # ratio requested for this evaluation.
            assert lvis_obj.dilation_ratio == dilation_ratio, "Dilation ratio not consistent"

# per-image per-category evaluation results
self.eval_imgs = defaultdict(list)
self.eval = {} # accumulated evaluation results
Expand Down Expand Up @@ -67,8 +85,8 @@ def _prepare(self):
dts = self.lvis_dt.load_anns(
self.lvis_dt.get_ann_ids(img_ids=self.params.img_ids, cat_ids=cat_ids)
)
# convert ground truth to mask if iou_type == 'segm'
if self.params.iou_type == "segm":
# convert ground truth to mask if iou_type == 'segm' or 'boundary'
if self.params.iou_type == "segm" or self.params.iou_type == "boundary":
self._to_mask(gts, self.lvis_gt)
self._to_mask(dts, self.lvis_dt)

Expand Down Expand Up @@ -180,15 +198,43 @@ def compute_iou(self, img_id, cat_id):
ann_type = "segmentation"
elif self.params.iou_type == "bbox":
ann_type = "bbox"
elif self.params.iou_type == "boundary":
# We need to compute both mask and boundary iou
ann_type = None
else:
raise ValueError("Unknown iou_type for iou computation.")
gt = [g[ann_type] for g in gt]
dt = [d[ann_type] for d in dt]

# compute iou between each dt and gt region
# will return array of shape len(dt), len(gt)
ious = mask_utils.iou(dt, gt, iscrowd)
return ious
if ann_type is not None:
gt = [g[ann_type] for g in gt]
dt = [d[ann_type] for d in dt]
# compute iou between each dt and gt region
# will return array of shape len(dt), len(gt)
ious = mask_utils.iou(dt, gt, iscrowd)
return ious
else:
# combine mask and boundary iou
# Mask
gt_m = [g["segmentation"] for g in gt]
dt_m = [d["segmentation"] for d in dt]
# Boundary
gt_b = [g["boundary"] for g in gt]
dt_b = [d["boundary"] for d in dt]
# compute iou between each dt and gt region
# will return array of shape len(dt), len(gt)
mask_ious = mask_utils.iou(dt_m, gt_m, iscrowd)
boundary_ious = mask_utils.iou(dt_b, gt_b, iscrowd)
# combine mask and boundary iou
mask_ious = np.array(mask_ious)
boundary_ious = np.array(boundary_ious)
iscrowd = np.array(iscrowd)
ious = mask_ious
if len(gt) and len(dt):
# keep "mask iou" for crowd annotation
ious[:, iscrowd == 0] = np.minimum(mask_ious[:, iscrowd == 0], boundary_ious[:, iscrowd == 0])
else:
# corner case, one or both sets are empty
ious = np.minimum(mask_ious, boundary_ious)
return ious

def evaluate_img(self, img_id, cat_id, area_rng):
"""Perform evaluation for single category and image."""
Expand Down
26 changes: 22 additions & 4 deletions lvis/lvis.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,29 @@
import logging
from collections import defaultdict
from urllib.request import urlretrieve
import time

import pycocotools.mask as mask_utils

from lvis.boundary_utils import augment_annotations_with_boundary_multi_core


class LVIS:
def __init__(self, annotation_path):
def __init__(self, annotation_path, precompute_boundary=False, dilation_ratio=0.02):
"""Class for reading and visualizing annotations.
Args:
annotation_path (str): location of annotation file
precompute_boundary (bool): whether to precompute mask boundary before evaluation
dilation_ratio (float): ratio to calculate dilation = dilation_ratio * image_diagonal
"""
self.logger = logging.getLogger(__name__)
self.logger.info("Loading annotations.")

self.dataset = self._load_json(annotation_path)

self.precompute_boundary = precompute_boundary
self.dilation_ratio = dilation_ratio

assert (
type(self.dataset) == dict
), "Annotation file format {} not supported.".format(type(self.dataset))
Expand All @@ -45,13 +53,23 @@ def _create_index(self):
self.cats = {}
self.imgs = {}

for img in self.dataset["images"]:
self.imgs[img["id"]] = img

if self.precompute_boundary:
# add `boundary` to annotation
self.logger.info('Adding `boundary` to annotation.')
tic = time.time()
self.dataset["annotations"] = augment_annotations_with_boundary_multi_core(self.dataset["annotations"],
self.ann_to_mask,
dilation_ratio=self.dilation_ratio)

self.logger.info('`boundary` added! (t={:0.2f}s)'.format(time.time()- tic))

for ann in self.dataset["annotations"]:
self.img_ann_map[ann["image_id"]].append(ann)
self.anns[ann["id"]] = ann

for img in self.dataset["images"]:
self.imgs[img["id"]] = img

for cat in self.dataset["categories"]:
self.cats[cat["id"]] = cat

Expand Down
8 changes: 7 additions & 1 deletion lvis/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,27 @@


class LVISResults(LVIS):
def __init__(self, lvis_gt, results, max_dets=300):
def __init__(self, lvis_gt, results, max_dets=300, precompute_boundary=False, dilation_ratio=0.02):
"""Constructor for LVIS results.
Args:
lvis_gt (LVIS class instance, or str containing path of
annotation file)
results (str containing path of result file or a list of dicts)
max_dets (int): max number of detections per image. The official
value of max_dets for LVIS is 300.
precompute_boundary (bool): whether to precompute mask boundary before evaluation
dilation_ratio (float): ratio to calculate dilation = dilation_ratio * image_diagonal
"""
if isinstance(lvis_gt, LVIS):
self.dataset = deepcopy(lvis_gt.dataset)
precompute_boundary = lvis_gt.precompute_boundary
elif isinstance(lvis_gt, str):
self.dataset = self._load_json(lvis_gt)
else:
raise TypeError("Unsupported type {} of lvis_gt.".format(lvis_gt))

self.precompute_boundary = precompute_boundary
self.dilation_ratio = dilation_ratio

self.logger = logging.getLogger(__name__)
self.logger.info("Loading and preparing results.")
Expand Down

0 comments on commit 98b5f79

Please sign in to comment.