Added a comparative study of the speed of the library
MiXaiLL76 committed Jun 16, 2022
1 parent b3030b2 commit 867cda4
Showing 6 changed files with 1,025 additions and 10 deletions.
28 changes: 18 additions & 10 deletions README.md
@@ -10,19 +10,27 @@ for coco's AP metrics, especially when dealing with a high number of instances i

### Comparison

For our use case with a test dataset of 1500 images that contains up to 2000 instances per image, we saw up to a 100x faster
evaluation using faster-coco-eval (FCE) compared to the original pycocotools code.
````
Seg eval pycocotools: 4 hours
Seg eval FCE: 2.5 min

BBox eval pycocotools: 4 hours
BBox eval FCE: 2 min
````

For our use case with a test dataset of 5000 images from the COCO val dataset, testing was carried out using the
mmdetection framework and the eval_metric.py script. The results are presented below.

A visualization of the comparison, **comparison.ipynb**, is available in the [examples/comparison](./examples/comparison/comparison.ipynb) directory.
Tested with a yolo3 model (bbox eval) and a yolact model (segm eval).

Type | COCOeval | COCOeval_faster | Speedup
-----|----------|-----------------|--------
bbox | 22.854 sec. | 8.714 sec. | ~2.6x
segm | 35.356 sec. | 18.403 sec. | ~1.9x
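The timings above were collected with mmdetection's eval_metric.py script. As a minimal, framework-free illustration of how such a comparison can be run, the sketch below times pycocotools' COCOeval against COCOeval_faster on the same ground truth and detections; the annotation and detection file names are placeholders, not files shipped with this commit.

```python
# Minimal timing sketch. "instances_val2017.json" and "detections.json" are
# placeholder paths for a COCO-format ground truth file and detection results.
import time

from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

from faster_coco_eval import COCOeval_faster

coco_gt = COCO("instances_val2017.json")
coco_dt = coco_gt.loadRes("detections.json")

for name, evaluator_cls in [("COCOeval", COCOeval), ("COCOeval_faster", COCOeval_faster)]:
    # COCOeval_faster keeps the drop-in constructor and evaluate/accumulate/summarize API.
    evaluator = evaluator_cls(coco_gt, coco_dt, "bbox")
    start = time.perf_counter()
    evaluator.evaluate()
    evaluator.accumulate()
    evaluator.summarize()
    print(f"{name}: {time.perf_counter() - start:.3f} sec.")
```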

# Getting started

### Install
## Local build
Build from source
```bash
python3 setup.py sdist
```

## Install
Install from source
```bash
pip3 install git+https://github.com/MiXaiLL76/faster_coco_eval
@@ -55,7 +63,7 @@ For usage, look at the original `COCOEval` [class documentation.](https://github
- [x] Append unittest
- [x] Append ROC / AUC curves
- [x] Check if it works on windows
- [ ] Append fp fn error analysis

# License
The original module was licensed under Apache 2.0, and I will continue with the same license.
6 changes: 6 additions & 0 deletions examples/comparison/README.md
@@ -0,0 +1,6 @@
A visualization of the comparison is available in [comparison.ipynb](./comparison.ipynb).

Type | COCOeval | COCOeval_faster | Speedup
-----|----------|-----------------|--------
bbox | 22.854 sec. | 8.714 sec. | ~2.6x
segm | 35.356 sec. | 18.403 sec. | ~1.9x
225 changes: 225 additions & 0 deletions examples/comparison/coco_fast.py
@@ -0,0 +1,225 @@
# Copyright (c) OpenMMLab. All rights reserved.
import contextlib
import io
import itertools
import logging
import os.path as osp
import tempfile
import warnings
from collections import OrderedDict

import mmcv
import numpy as np
from mmcv.utils import print_log
from terminaltables import AsciiTable

from mmdet.core import eval_recalls
from mmdet.datasets.builder import DATASETS
from mmdet.datasets.coco import CocoDataset
from faster_coco_eval import COCOeval_faster

@DATASETS.register_module()
class FasterCocoDataset(CocoDataset):
def evaluate_det_segm(self,
results,
result_files,
coco_gt,
metrics,
logger=None,
classwise=False,
proposal_nums=(100, 300, 1000),
iou_thrs=None,
metric_items=None):
"""Instance segmentation and object detection evaluation in COCO
protocol.
Args:
results (list[list | tuple | dict]): Testing results of the
dataset.
result_files (dict[str, str]): a dict contains json file path.
coco_gt (COCO): COCO API object with ground truth annotation.
            metrics (str | list[str]): Metrics to be evaluated. Options are
'bbox', 'segm', 'proposal', 'proposal_fast'.
logger (logging.Logger | str | None): Logger used for printing
related information during evaluation. Default: None.
            classwise (bool): Whether to evaluate the AP for each class.
proposal_nums (Sequence[int]): Proposal number used for evaluating
recalls, such as recall@100, recall@1000.
Default: (100, 300, 1000).
iou_thrs (Sequence[float], optional): IoU threshold used for
evaluating recalls/mAPs. If set to a list, the average of all
IoUs will also be computed. If not specified, [0.50, 0.55,
0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90, 0.95] will be used.
Default: None.
metric_items (list[str] | str, optional): Metric items that will
be returned. If not specified, ``['AR@100', 'AR@300',
'AR@1000', 'AR_s@1000', 'AR_m@1000', 'AR_l@1000' ]`` will be
used when ``metric=='proposal'``, ``['mAP', 'mAP_50', 'mAP_75',
'mAP_s', 'mAP_m', 'mAP_l']`` will be used when
``metric=='bbox' or metric=='segm'``.
Returns:
dict[str, float]: COCO style evaluation metric.
"""
if iou_thrs is None:
iou_thrs = np.linspace(
.5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True)
if metric_items is not None:
if not isinstance(metric_items, list):
metric_items = [metric_items]

eval_results = OrderedDict()
for metric in metrics:
msg = f'Evaluating {metric}...'
if logger is None:
msg = '\n' + msg
print_log(msg, logger=logger)

if metric == 'proposal_fast':
if isinstance(results[0], tuple):
raise KeyError('proposal_fast is not supported for '
'instance segmentation result.')
ar = self.fast_eval_recall(
results, proposal_nums, iou_thrs, logger='silent')
log_msg = []
for i, num in enumerate(proposal_nums):
eval_results[f'AR@{num}'] = ar[i]
log_msg.append(f'\nAR@{num}\t{ar[i]:.4f}')
log_msg = ''.join(log_msg)
print_log(log_msg, logger=logger)
continue

iou_type = 'bbox' if metric == 'proposal' else metric
if metric not in result_files:
raise KeyError(f'{metric} is not in results')
try:
predictions = mmcv.load(result_files[metric])
if iou_type == 'segm':
# Refer to https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/coco.py#L331 # noqa
# When evaluating mask AP, if the results contain bbox,
# cocoapi will use the box area instead of the mask area
# for calculating the instance area. Though the overall AP
# is not affected, this leads to different
# small/medium/large mask AP results.
for x in predictions:
x.pop('bbox')
warnings.simplefilter('once')
warnings.warn(
'The key "bbox" is deleted for more accurate mask AP '
'of small/medium/large instances since v2.12.0. This '
'does not change the overall mAP calculation.',
UserWarning)
coco_det = coco_gt.loadRes(predictions)
except IndexError:
print_log(
                    'The testing results of the whole dataset are empty.',
logger=logger,
level=logging.ERROR)
break

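            # Key change from the stock mmdetection CocoDataset: COCOeval_faster
            # is used here as a drop-in replacement for pycocotools' COCOeval.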
cocoEval = COCOeval_faster(coco_gt, coco_det, iou_type)
cocoEval.params.catIds = self.cat_ids
cocoEval.params.imgIds = self.img_ids
cocoEval.params.maxDets = list(proposal_nums)
cocoEval.params.iouThrs = iou_thrs
# mapping of cocoEval.stats
coco_metric_names = {
'mAP': 0,
'mAP_50': 1,
'mAP_75': 2,
'mAP_s': 3,
'mAP_m': 4,
'mAP_l': 5,
'AR@100': 6,
'AR@300': 7,
'AR@1000': 8,
'AR_s@1000': 9,
'AR_m@1000': 10,
'AR_l@1000': 11
}
if metric_items is not None:
for metric_item in metric_items:
if metric_item not in coco_metric_names:
raise KeyError(
f'metric item {metric_item} is not supported')

if metric == 'proposal':
cocoEval.params.useCats = 0
cocoEval.evaluate()
cocoEval.accumulate()

# Save coco summarize print information to logger
redirect_string = io.StringIO()
with contextlib.redirect_stdout(redirect_string):
cocoEval.summarize()
print_log('\n' + redirect_string.getvalue(), logger=logger)

if metric_items is None:
metric_items = [
'AR@100', 'AR@300', 'AR@1000', 'AR_s@1000',
'AR_m@1000', 'AR_l@1000'
]

for item in metric_items:
val = float(
f'{cocoEval.stats[coco_metric_names[item]]:.3f}')
eval_results[item] = val
else:
cocoEval.evaluate()
cocoEval.accumulate()

# Save coco summarize print information to logger
redirect_string = io.StringIO()
with contextlib.redirect_stdout(redirect_string):
cocoEval.summarize()
print_log('\n' + redirect_string.getvalue(), logger=logger)

if classwise: # Compute per-category AP
# Compute per-category AP
# from https://github.com/facebookresearch/detectron2/
precisions = cocoEval.eval['precision']
# precision: (iou, recall, cls, area range, max dets)
assert len(self.cat_ids) == precisions.shape[2]

results_per_category = []
for idx, catId in enumerate(self.cat_ids):
# area range index 0: all area ranges
# max dets index -1: typically 100 per image
nm = self.coco.loadCats(catId)[0]
precision = precisions[:, :, idx, 0, -1]
precision = precision[precision > -1]
if precision.size:
ap = np.mean(precision)
else:
ap = float('nan')
results_per_category.append(
(f'{nm["name"]}', f'{float(ap):0.3f}'))

num_columns = min(6, len(results_per_category) * 2)
results_flatten = list(
itertools.chain(*results_per_category))
headers = ['category', 'AP'] * (num_columns // 2)
results_2d = itertools.zip_longest(*[
results_flatten[i::num_columns]
for i in range(num_columns)
])
table_data = [headers]
table_data += [result for result in results_2d]
table = AsciiTable(table_data)
print_log('\n' + table.table, logger=logger)

if metric_items is None:
metric_items = [
'mAP', 'mAP_50', 'mAP_75', 'mAP_s', 'mAP_m', 'mAP_l'
]

for metric_item in metric_items:
key = f'{metric}_{metric_item}'
val = float(
f'{cocoEval.stats[coco_metric_names[metric_item]]:.3f}'
)
eval_results[key] = val
ap = cocoEval.stats[:6]
eval_results[f'{metric}_mAP_copypaste'] = (
f'{ap[0]:.3f} {ap[1]:.3f} {ap[2]:.3f} {ap[3]:.3f} '
f'{ap[4]:.3f} {ap[5]:.3f}')

return eval_results
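coco_fast.py only registers FasterCocoDataset in mmdetection's DATASETS registry; to exercise it with eval_metric.py, a config still has to import the module and switch the test dataset type. The override below is a hedged sketch of one way to do that, not part of this commit: the base config name, the custom_imports module path, and the data paths are all assumptions for illustration.

```python
# Hypothetical mmdetection config override (illustration only, not part of this commit).
# It imports coco_fast.py so the FasterCocoDataset registration runs, then points the
# test split at it so evaluation goes through COCOeval_faster.
_base_ = ['configs/yolo/yolov3_d53_mstrain-608_273e_coco.py']  # assumed base config

custom_imports = dict(
    imports=['examples.comparison.coco_fast'],  # assumed import path to this file
    allow_failed_imports=False)

data = dict(
    test=dict(
        type='FasterCocoDataset',
        ann_file='data/coco/annotations/instances_val2017.json',
        img_prefix='data/coco/val2017/'))
```

eval_metric.py can then be pointed at such a config and a saved results file in the usual way, so the bbox/segm metrics reported above come out of COCOeval_faster.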