From 84ac272f9167820173ddecb30585ada97d378efe Mon Sep 17 00:00:00 2001 From: Achal Dave Date: Wed, 23 Jun 2021 18:23:19 -0400 Subject: [PATCH] Implement Fixed AP with Boundary IoU (#25) * Implement Fixed AP * Implement mode="challenge2021" in LVISEval --- lvis/eval.py | 65 ++++++++++++++++++++++++++++++++++++++------- lvis/results.py | 51 ++++++++++++++++++++++++++++++----- test_challenge21.py | 19 +++++++++++++ 3 files changed, 119 insertions(+), 16 deletions(-) create mode 100644 test_challenge21.py diff --git a/lvis/eval.py b/lvis/eval.py index a50d5f0..6470816 100644 --- a/lvis/eval.py +++ b/lvis/eval.py @@ -12,20 +12,29 @@ class LVISEval: - def __init__(self, lvis_gt, lvis_dt, iou_type="segm", dilation_ratio=0.02): + def __init__(self, lvis_gt, lvis_dt, iou_type="segm", mode="default", dilation_ratio=0.02): """Constructor for LVISEval. Args: lvis_gt (LVIS class instance, or str containing path of annotation file) lvis_dt (LVISResult class instance, or str containing path of result file, or list of dict) - iou_type (str): segm or bbox evaluation + iou_type (str): segm, bbox, or boundary evaluation. Ignored if `mode` is set to + 'challenge2021'. dilation_ratio (float): ratio to calculate dilation = dilation_ratio * image_diagonal + mode (str): Either 'default' or 'challenge2021'. Specifying 'challenge2021' + uses iou_type=boundary and limits detections to 10,000 per class + (instead of 300 per image). """ self.logger = logging.getLogger(__name__) if iou_type not in ["bbox", "segm", "boundary"]: raise ValueError("iou_type: {} is not supported.".format(iou_type)) + if mode == "challenge2021": + iou_type = "boundary" + elif mode != "default": + raise ValueError("Unexpected mode: {}".format(mode)) + self.use_boundary_iou = iou_type == "boundary" if isinstance(lvis_gt, LVIS): @@ -37,8 +46,25 @@ def __init__(self, lvis_gt, lvis_dt, iou_type="segm", dilation_ratio=0.02): if isinstance(lvis_dt, LVISResults): self.lvis_dt = lvis_dt + if mode == "challenge2021": + assert self.lvis_dt.max_dets_per_im == -1, ( + "mode='challenge2021' specified with LVISResults object " + "containing incorrect max_dets_per_im (must be -1)." + ) + assert self.lvis_dt.max_dets_per_cat == 10000, ( + "mode='challenge2021' specified with LVISResults object " + "containing incorrect max_dets_per_cat (must be 10000)." + ) elif isinstance(lvis_dt, (str, list)): - self.lvis_dt = LVISResults(self.lvis_gt, lvis_dt) + if mode == "default": + self.lvis_dt = LVISResults(self.lvis_gt, lvis_dt) + elif mode == "challenge2021": + self.lvis_dt = LVISResults( + self.lvis_gt, + lvis_dt, + max_dets_per_cat=10000, + max_dets_per_im=-1 + ) else: raise TypeError("Unsupported type {} of lvis_dt.".format(lvis_dt)) @@ -65,6 +91,9 @@ def __init__(self, lvis_gt, lvis_dt, iou_type="segm", dilation_ratio=0.02): self.params = Params(iou_type=iou_type) # parameters self.results = OrderedDict() self.ious = {} # ious between all gts and dts + if mode == "challenge2021": + self.params.max_dets = -1 + self.params.max_dets_per_cat = 10000 self.params.img_ids = sorted(self.lvis_gt.get_img_ids()) self.params.cat_ids = sorted(self.lvis_gt.get_cat_ids()) @@ -495,7 +524,8 @@ def summarize(self): if not self.eval: raise RuntimeError("Please run accumulate() first.") - max_dets = self.params.max_dets + max_dets_im = self.params.max_dets + max_dets_cat = self.params.max_dets_per_cat self.results["AP"] = self._summarize('ap') self.results["AP50"] = self._summarize('ap', iou_thr=0.50) @@ -507,11 +537,18 @@ def summarize(self): self.results["APc"] = self._summarize('ap', freq_group_idx=1) self.results["APf"] = self._summarize('ap', freq_group_idx=2) - key = "AR@{}".format(max_dets) + # If max dets/cat is specified, update key to include both dets/im and dets/cat. + if max_dets_cat < 0: + key_suffix = max_dets_im + elif max_dets_im < 0: + key_suffix = "{}/cat".format(max_dets_cat) + else: # Both max dets/im and max dets/cat specified + key_suffix = "{}/im,{}/cat".format(max_dets_im, max_dets_cat) + key = "AR@{}".format(key_suffix) self.results[key] = self._summarize('ar') for area_rng in ["small", "medium", "large"]: - key = "AR{}@{}".format(area_rng[0], max_dets) + key = "AR{}@{}".format(area_rng[0], key_suffix) self.results[key] = self._summarize('ar', area_rng=area_rng) def run(self): @@ -521,10 +558,17 @@ def run(self): self.summarize() def print_results(self): - template = " {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} catIds={:>3s}] = {:0.3f}" + template = " {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={} catIds={:>3s}] = {:0.3f}" for key, value in self.results.items(): - max_dets = self.params.max_dets + max_dets_im = self.params.max_dets + max_dets_cat = self.params.max_dets_per_cat + if max_dets_cat < 0: + max_dets_str = "{:>3d}".format(max_dets_im) + elif max_dets_im < 0: + max_dets_str = "{}/cat".format(max_dets_cat) + else: + max_dets_str = "{}/im,{}/cat".format(max_dets_im, max_dets_cat) if "AP" in key: title = "Average Precision" _type = "(AP)" @@ -550,7 +594,7 @@ def print_results(self): else: area_rng = "all" - print(template.format(title, _type, iou, area_rng, max_dets, cat_group_name, value)) + print(template.format(title, _type, iou, area_rng, max_dets_str, cat_group_name, value)) def get_results(self): if not self.results: @@ -571,7 +615,8 @@ def __init__(self, iou_type): self.rec_thrs = np.linspace( 0.0, 1.00, int(np.round((1.00 - 0.0) / 0.01)) + 1, endpoint=True ) - self.max_dets = 300 + self.max_dets = 300 # Max detections per image + self.max_dets_per_cat = -1 # Max detections per category self.area_rng = [ [0 ** 2, 1e5 ** 2], [0 ** 2, 32 ** 2], diff --git a/lvis/results.py b/lvis/results.py index 38ff1f5..ef21914 100644 --- a/lvis/results.py +++ b/lvis/results.py @@ -7,14 +7,28 @@ class LVISResults(LVIS): - def __init__(self, lvis_gt, results, max_dets=300, precompute_boundary=False, dilation_ratio=0.02): + def __init__( + self, + lvis_gt, + results, + max_dets_per_cat=-1, + max_dets_per_im=300, + precompute_boundary=False, + dilation_ratio=0.02, + ): """Constructor for LVIS results. Args: lvis_gt (LVIS class instance, or str containing path of annotation file) results (str containing path of result file or a list of dicts) - max_dets (int): max number of detections per image. The official - value of max_dets for LVIS is 300. + max_dets_per_cat (int): max number of detections per category. The + official value for the current version of the LVIS API is + infinite (i.e., -1). The official value for the 2021 LVIS + challenge is 10,000. + max_dets_per_im (int): max number of detections per image. The + official value for the current version of the LVIS API is 300. + The official value for the 2021 LVIS challenge is infinite + (i.e., -1). precompute_boundary (bool): whether to precompute mask boundary before evaluation dilation_ratio (float): ratio to calculate dilation = dilation_ratio * image_diagonal """ @@ -25,7 +39,7 @@ def __init__(self, lvis_gt, results, max_dets=300, precompute_boundary=False, di self.dataset = self._load_json(lvis_gt) else: raise TypeError("Unsupported type {} of lvis_gt.".format(lvis_gt)) - + self.precompute_boundary = precompute_boundary self.dilation_ratio = dilation_ratio @@ -42,8 +56,12 @@ def __init__(self, lvis_gt, results, max_dets=300, precompute_boundary=False, di assert isinstance(result_anns, list), "results is not a list." - if max_dets >= 0: - result_anns = self.limit_dets_per_image(result_anns, max_dets) + if max_dets_per_im >= 0: + result_anns = self.limit_dets_per_image(result_anns, max_dets_per_im) + self.max_dets_per_im = max_dets_per_im + self.max_dets_per_cat = max_dets_per_cat + if max_dets_per_cat >= 0: + result_anns = self.limit_dets_per_cat(result_anns, max_dets_per_cat) if "bbox" in result_anns[0]: for id, ann in enumerate(result_anns): @@ -76,6 +94,27 @@ def __init__(self, lvis_gt, results, max_dets=300, precompute_boundary=False, di set(img_ids_in_result) & set(self.get_img_ids()) ), "Results do not correspond to current LVIS set." + def limit_dets_per_cat(self, anns, max_dets): + by_cat = defaultdict(list) + for ann in anns: + by_cat[ann["category_id"]].append(ann) + results = [] + fewer_dets_cats = set() + for cat, cat_anns in by_cat.items(): + if len(cat_anns) < max_dets: + fewer_dets_cats.add(cat) + results.extend( + sorted(cat_anns, key=lambda x: x["score"], reverse=True)[:max_dets] + ) + if fewer_dets_cats: + self.logger.warning( + f"{len(fewer_dets_cats)} categories had less than {max_dets} " + f"detections!\n" + f"Outputting {max_dets} detections for each category will improve AP " + f"further." + ) + return results + def limit_dets_per_image(self, anns, max_dets): img_ann = defaultdict(list) for ann in anns: diff --git a/test_challenge21.py b/test_challenge21.py new file mode 100644 index 0000000..43558de --- /dev/null +++ b/test_challenge21.py @@ -0,0 +1,19 @@ +import logging +from lvis import LVIS, LVISResults, LVISEval + +# result and val files for 100 randomly sampled images. +ANNOTATION_PATH = "./data/lvis_val_100.json" +RESULT_PATH = "./data/lvis_results_100.json" + +ANN_TYPE = 'bbox' + +lvis_gt = LVIS(ANNOTATION_PATH) +lvis_dt = LVISResults(lvis_gt, + RESULT_PATH, + max_dets_per_cat=2, + max_dets_per_im=-1) +lvis_eval = LVISEval(lvis_gt, lvis_dt, ANN_TYPE) +lvis_eval.params.max_dets = -1 +lvis_eval.params.max_dets_per_cat = 2 +lvis_eval.run() +lvis_eval.print_results()