Skip to content

Commit

Permalink
Implement Fixed AP with Boundary IoU (#25)
Browse files Browse the repository at this point in the history
* Implement Fixed AP
* Implement mode="challenge2021" in LVISEval
  • Loading branch information
achalddave authored Jun 23, 2021
1 parent 98b5f79 commit 84ac272
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 16 deletions.
65 changes: 55 additions & 10 deletions lvis/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,29 @@


class LVISEval:
def __init__(self, lvis_gt, lvis_dt, iou_type="segm", dilation_ratio=0.02):
def __init__(self, lvis_gt, lvis_dt, iou_type="segm", mode="default", dilation_ratio=0.02):
"""Constructor for LVISEval.
Args:
lvis_gt (LVIS class instance, or str containing path of annotation file)
lvis_dt (LVISResult class instance, or str containing path of result file,
or list of dict)
iou_type (str): segm or bbox evaluation
iou_type (str): segm, bbox, or boundary evaluation. Ignored if `mode` is set to
'challenge2021'.
dilation_ratio (float): ratio to calculate dilation = dilation_ratio * image_diagonal
mode (str): Either 'default' or 'challenge2021'. Specifying 'challenge2021'
uses iou_type=boundary and limits detections to 10,000 per class
(instead of 300 per image).
"""
self.logger = logging.getLogger(__name__)

if iou_type not in ["bbox", "segm", "boundary"]:
raise ValueError("iou_type: {} is not supported.".format(iou_type))

if mode == "challenge2021":
iou_type = "boundary"
elif mode != "default":
raise ValueError("Unexpected mode: {}".format(mode))

self.use_boundary_iou = iou_type == "boundary"

if isinstance(lvis_gt, LVIS):
Expand All @@ -37,8 +46,25 @@ def __init__(self, lvis_gt, lvis_dt, iou_type="segm", dilation_ratio=0.02):

if isinstance(lvis_dt, LVISResults):
self.lvis_dt = lvis_dt
if mode == "challenge2021":
assert self.lvis_dt.max_dets_per_im == -1, (
"mode='challenge2021' specified with LVISResults object "
"containing incorrect max_dets_per_im (must be -1)."
)
assert self.lvis_dt.max_dets_per_cat == 10000, (
"mode='challenge2021' specified with LVISResults object "
"containing incorrect max_dets_per_cat (must be 10000)."
)
elif isinstance(lvis_dt, (str, list)):
self.lvis_dt = LVISResults(self.lvis_gt, lvis_dt)
if mode == "default":
self.lvis_dt = LVISResults(self.lvis_gt, lvis_dt)
elif mode == "challenge2021":
self.lvis_dt = LVISResults(
self.lvis_gt,
lvis_dt,
max_dets_per_cat=10000,
max_dets_per_im=-1
)
else:
raise TypeError("Unsupported type {} of lvis_dt.".format(lvis_dt))

Expand All @@ -65,6 +91,9 @@ def __init__(self, lvis_gt, lvis_dt, iou_type="segm", dilation_ratio=0.02):
self.params = Params(iou_type=iou_type) # parameters
self.results = OrderedDict()
self.ious = {} # ious between all gts and dts
if mode == "challenge2021":
self.params.max_dets = -1
self.params.max_dets_per_cat = 10000

self.params.img_ids = sorted(self.lvis_gt.get_img_ids())
self.params.cat_ids = sorted(self.lvis_gt.get_cat_ids())
Expand Down Expand Up @@ -495,7 +524,8 @@ def summarize(self):
if not self.eval:
raise RuntimeError("Please run accumulate() first.")

max_dets = self.params.max_dets
max_dets_im = self.params.max_dets
max_dets_cat = self.params.max_dets_per_cat

self.results["AP"] = self._summarize('ap')
self.results["AP50"] = self._summarize('ap', iou_thr=0.50)
Expand All @@ -507,11 +537,18 @@ def summarize(self):
self.results["APc"] = self._summarize('ap', freq_group_idx=1)
self.results["APf"] = self._summarize('ap', freq_group_idx=2)

key = "AR@{}".format(max_dets)
# If max dets/cat is specified, update key to include both dets/im and dets/cat.
if max_dets_cat < 0:
key_suffix = max_dets_im
elif max_dets_im < 0:
key_suffix = "{}/cat".format(max_dets_cat)
else: # Both max dets/im and max dets/cat specified
key_suffix = "{}/im,{}/cat".format(max_dets_im, max_dets_cat)
key = "AR@{}".format(key_suffix)
self.results[key] = self._summarize('ar')

for area_rng in ["small", "medium", "large"]:
key = "AR{}@{}".format(area_rng[0], max_dets)
key = "AR{}@{}".format(area_rng[0], key_suffix)
self.results[key] = self._summarize('ar', area_rng=area_rng)

def run(self):
Expand All @@ -521,10 +558,17 @@ def run(self):
self.summarize()

def print_results(self):
template = " {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} catIds={:>3s}] = {:0.3f}"
template = " {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={} catIds={:>3s}] = {:0.3f}"

for key, value in self.results.items():
max_dets = self.params.max_dets
max_dets_im = self.params.max_dets
max_dets_cat = self.params.max_dets_per_cat
if max_dets_cat < 0:
max_dets_str = "{:>3d}".format(max_dets_im)
elif max_dets_im < 0:
max_dets_str = "{}/cat".format(max_dets_cat)
else:
max_dets_str = "{}/im,{}/cat".format(max_dets_im, max_dets_cat)
if "AP" in key:
title = "Average Precision"
_type = "(AP)"
Expand All @@ -550,7 +594,7 @@ def print_results(self):
else:
area_rng = "all"

print(template.format(title, _type, iou, area_rng, max_dets, cat_group_name, value))
print(template.format(title, _type, iou, area_rng, max_dets_str, cat_group_name, value))

def get_results(self):
if not self.results:
Expand All @@ -571,7 +615,8 @@ def __init__(self, iou_type):
self.rec_thrs = np.linspace(
0.0, 1.00, int(np.round((1.00 - 0.0) / 0.01)) + 1, endpoint=True
)
self.max_dets = 300
self.max_dets = 300 # Max detections per image
self.max_dets_per_cat = -1 # Max detections per category
self.area_rng = [
[0 ** 2, 1e5 ** 2],
[0 ** 2, 32 ** 2],
Expand Down
51 changes: 45 additions & 6 deletions lvis/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,28 @@


class LVISResults(LVIS):
def __init__(self, lvis_gt, results, max_dets=300, precompute_boundary=False, dilation_ratio=0.02):
def __init__(
self,
lvis_gt,
results,
max_dets_per_cat=-1,
max_dets_per_im=300,
precompute_boundary=False,
dilation_ratio=0.02,
):
"""Constructor for LVIS results.
Args:
lvis_gt (LVIS class instance, or str containing path of
annotation file)
results (str containing path of result file or a list of dicts)
max_dets (int): max number of detections per image. The official
value of max_dets for LVIS is 300.
max_dets_per_cat (int): max number of detections per category. The
official value for the current version of the LVIS API is
infinite (i.e., -1). The official value for the 2021 LVIS
challenge is 10,000.
max_dets_per_im (int): max number of detections per image. The
official value for the current version of the LVIS API is 300.
The official value for the 2021 LVIS challenge is infinite
(i.e., -1).
precompute_boundary (bool): whether to precompute mask boundary before evaluation
dilation_ratio (float): ratio to calculate dilation = dilation_ratio * image_diagonal
"""
Expand All @@ -25,7 +39,7 @@ def __init__(self, lvis_gt, results, max_dets=300, precompute_boundary=False, di
self.dataset = self._load_json(lvis_gt)
else:
raise TypeError("Unsupported type {} of lvis_gt.".format(lvis_gt))

self.precompute_boundary = precompute_boundary
self.dilation_ratio = dilation_ratio

Expand All @@ -42,8 +56,12 @@ def __init__(self, lvis_gt, results, max_dets=300, precompute_boundary=False, di

assert isinstance(result_anns, list), "results is not a list."

if max_dets >= 0:
result_anns = self.limit_dets_per_image(result_anns, max_dets)
if max_dets_per_im >= 0:
result_anns = self.limit_dets_per_image(result_anns, max_dets_per_im)
self.max_dets_per_im = max_dets_per_im
self.max_dets_per_cat = max_dets_per_cat
if max_dets_per_cat >= 0:
result_anns = self.limit_dets_per_cat(result_anns, max_dets_per_cat)

if "bbox" in result_anns[0]:
for id, ann in enumerate(result_anns):
Expand Down Expand Up @@ -76,6 +94,27 @@ def __init__(self, lvis_gt, results, max_dets=300, precompute_boundary=False, di
set(img_ids_in_result) & set(self.get_img_ids())
), "Results do not correspond to current LVIS set."

def limit_dets_per_cat(self, anns, max_dets):
    """Keep only the top-scoring detections of every category.

    Args:
        anns (list of dict): detections; each one is read through its
            "category_id" and "score" keys.
        max_dets (int): maximum number of detections kept per category.

    Returns:
        list of dict: for each category (in order of first appearance in
        `anns`), its detections sorted by descending "score" and truncated
        to `max_dets` entries.

    A single warning is logged when at least one category supplied fewer
    than `max_dets` detections.
    """
    grouped = defaultdict(list)
    for det in anns:
        grouped[det["category_id"]].append(det)

    kept = []
    short_cats = set()
    for cat_id in grouped:
        dets = grouped[cat_id]
        if len(dets) < max_dets:
            # Remember categories that did not use their full budget.
            short_cats.add(cat_id)
        ranked = sorted(dets, key=lambda d: d["score"], reverse=True)
        kept += ranked[:max_dets]

    if not short_cats:
        return kept

    self.logger.warning(
        f"{len(short_cats)} categories had less than {max_dets} "
        f"detections!\n"
        f"Outputting {max_dets} detections for each category will improve AP "
        f"further."
    )
    return kept

def limit_dets_per_image(self, anns, max_dets):
img_ann = defaultdict(list)
for ann in anns:
Expand Down
19 changes: 19 additions & 0 deletions test_challenge21.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import logging
from lvis import LVIS, LVISResults, LVISEval

# Annotation and result files covering 100 randomly sampled images.
ANNOTATION_PATH = "./data/lvis_val_100.json"
RESULT_PATH = "./data/lvis_results_100.json"

ANN_TYPE = 'bbox'

# Load the ground truth, then load detections capped per category instead of
# per image (a negative max_dets_per_im disables the per-image limit).
lvis_gt = LVIS(ANNOTATION_PATH)
lvis_dt = LVISResults(
    lvis_gt,
    RESULT_PATH,
    max_dets_per_im=-1,
    max_dets_per_cat=2,
)

# Configure the evaluator's params with the same limits used for the results.
lvis_eval = LVISEval(lvis_gt, lvis_dt, iou_type=ANN_TYPE)
lvis_eval.params.max_dets_per_cat = 2
lvis_eval.params.max_dets = -1

lvis_eval.run()
lvis_eval.print_results()

0 comments on commit 84ac272

Please sign in to comment.