From 34a1a3f2e863669e645f438bd83d9ad93ed4e810 Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Wed, 17 May 2023 21:40:12 +0330 Subject: [PATCH 01/41] Keep only cocomap-related changes i.e. ObjectDetectionMap and its dependencies --- docs/source/metrics.rst | 2 + ignite/distributed/utils.py | 66 +- ignite/metrics/__init__.py | 4 + ignite/metrics/mean_average_precision.py | 664 +++++++++++++ ignite/metrics/vision/__init__.py | 3 + ignite/metrics/vision/object_detection_map.py | 208 +++++ requirements-dev.txt | 1 + tests/ignite/distributed/utils/__init__.py | 28 +- .../metrics/test_mean_average_precision.py | 500 ++++++++++ tests/ignite/metrics/vision/__init__.py | 0 .../vision/test_object_detection_map.py | 884 ++++++++++++++++++ 11 files changed, 2347 insertions(+), 13 deletions(-) create mode 100644 ignite/metrics/mean_average_precision.py create mode 100644 ignite/metrics/vision/__init__.py create mode 100644 ignite/metrics/vision/object_detection_map.py create mode 100644 tests/ignite/metrics/test_mean_average_precision.py create mode 100644 tests/ignite/metrics/vision/__init__.py create mode 100644 tests/ignite/metrics/vision/test_object_detection_map.py diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index afc477f457e..4fe17aee67c 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -329,11 +329,13 @@ Complete list of metrics Frequency Loss MeanAbsoluteError + MeanAveragePrecision MeanPairwiseDistance MeanSquaredError metric.Metric metrics_lambda.MetricsLambda MultiLabelConfusionMatrix + ObjectDetectionMAP precision.Precision PSNR recall.Recall diff --git a/ignite/distributed/utils.py b/ignite/distributed/utils.py index e597fa0b436..aaa8887ac7c 100644 --- a/ignite/distributed/utils.py +++ b/ignite/distributed/utils.py @@ -1,7 +1,8 @@ +import itertools import socket from contextlib import contextmanager from functools import wraps -from typing import Any, Callable, List, Mapping, Optional, Tuple, Union +from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Union import torch @@ -350,22 +351,65 @@ def all_reduce( return _model.all_reduce(tensor, op, group=group) +def _all_gather_tensors_with_shapes( + tensor: torch.Tensor, shapes: Sequence[Sequence[int]], group: Optional[Union[Any, List[int]]] = None +) -> List[torch.Tensor]: + if _need_to_sync and isinstance(_model, _SerialModel): + sync(temporary=True) + + if isinstance(group, list) and all(isinstance(item, int) for item in group): + group = _model.new_group(group) + + if isinstance(_model, _SerialModel) or (group is not None and _model.get_rank() not in group): + return [tensor] + + max_shape = torch.tensor(shapes).amax(dim=1) + padding_sizes = (max_shape - torch.tensor(tensor.shape)).tolist() + padded_tensor = torch.nn.functional.pad( + tensor, tuple(itertools.chain.from_iterable(map(lambda dim_size: (0, dim_size), reversed(padding_sizes)))) + ) + all_padded_tensors: torch.Tensor = _model.all_gather(padded_tensor, group=group) # .split(max_shape[0], dim=0) + return [ + all_padded_tensors[ + [ + slice(rank * max_shape[0] if dim == 0 else 0, rank * max_shape[0] + dim_size if dim == 0 else dim_size) + for dim, dim_size in enumerate(shape) + ] + ] + for rank, shape in enumerate(shapes) + if group is None or rank in group + ] + + def all_gather( - tensor: Union[torch.Tensor, float, str], group: Optional[Union[Any, List[int]]] = None -) -> Union[torch.Tensor, float, List[float], List[str]]: + tensor: Union[torch.Tensor, float, str], + group: Optional[Union[Any, List[int]]] = None, + 
tensor_different_shape: bool = False, +) -> Union[torch.Tensor, float, List[float], List[str], List[torch.Tensor]]: """Helper method to perform all gather operation. Args: - tensor: tensor or number or str to collect across participating processes. + tensor: tensor or number or str to collect across participating processes. If tensor, it should have + the same number of dimensions across processes. group: list of integer or the process group for each backend. If None, the default process group will be used. + tensor_different_shape: If True, it accounts for difference in input shape across processes. In this case, it + induces more collective operations. If False, `tensor` should have the same shape across processes. + Ignored when `tensor` is not a tensor. Default False. + Returns: - torch.Tensor of shape ``(world_size * tensor.shape[0], tensor.shape[1], ...)`` if input is a tensor or - torch.Tensor of shape ``(world_size, )`` if input is a number or - List of strings if input is a string + If input is a tensor, returns a torch.Tensor of shape ``(world_size * tensor.shape[0], tensor.shape[1], ...)`` + if ``tensor_different_shape = False``, otherwise a list of tensors with length ``world_size``(if ``group`` + is `None`) or `len(group)`. If current process does not belong to `group`, a list with `tensor` as its only + item is retured. + If input is a number, a torch.Tensor of shape ``(world_size, )`` is returned and finally a list of strings + is returned if input is a string. .. versionchanged:: 0.4.11 added ``group`` + + .. versionchanged:: 0.5.1 + added ``tensor_different_shape`` """ if _need_to_sync and isinstance(_model, _SerialModel): sync(temporary=True) @@ -373,6 +417,14 @@ def all_gather( if isinstance(group, list) and all(isinstance(item, int) for item in group): group = _model.new_group(group) + if isinstance(tensor, torch.Tensor) and tensor_different_shape: + if isinstance(_model, _SerialModel) or (group is not None and _model.get_rank() not in group): + return [tensor] + all_shapes: torch.Tensor = _model.all_gather(torch.tensor(tensor.shape), group=group).view( + -1, len(tensor.shape) + ) + return _all_gather_tensors_with_shapes(tensor, all_shapes.tolist(), group=group) + return _model.all_gather(tensor, group=group) diff --git a/ignite/metrics/__init__.py b/ignite/metrics/__init__.py index d001436a3ad..d5d2bd56078 100644 --- a/ignite/metrics/__init__.py +++ b/ignite/metrics/__init__.py @@ -9,6 +9,7 @@ from ignite.metrics.gan.inception_score import InceptionScore from ignite.metrics.loss import Loss from ignite.metrics.mean_absolute_error import MeanAbsoluteError +from ignite.metrics.mean_average_precision import MeanAveragePrecision from ignite.metrics.mean_pairwise_distance import MeanPairwiseDistance from ignite.metrics.mean_squared_error import MeanSquaredError from ignite.metrics.metric import BatchFiltered, BatchWise, EpochWise, Metric, MetricUsage @@ -23,6 +24,7 @@ from ignite.metrics.running_average import RunningAverage from ignite.metrics.ssim import SSIM from ignite.metrics.top_k_categorical_accuracy import TopKCategoricalAccuracy +from ignite.metrics.vision.object_detection_map import ObjectDetectionMAP __all__ = [ "Metric", @@ -58,4 +60,6 @@ "Rouge", "RougeN", "RougeL", + "MeanAveragePrecision", + "ObjectDetectionMAP", ] diff --git a/ignite/metrics/mean_average_precision.py b/ignite/metrics/mean_average_precision.py new file mode 100644 index 00000000000..fe7c76c9f98 --- /dev/null +++ b/ignite/metrics/mean_average_precision.py @@ -0,0 +1,664 @@ +import itertools 
+import warnings
+from collections import defaultdict
+from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Union
+
+import torch
+from typing_extensions import Literal
+
+import ignite.distributed as idist
+from ignite.distributed.utils import _all_gather_tensors_with_shapes
+from ignite.metrics.metric import reinit__is_reduced
+from ignite.metrics.recall import _BasePrecisionRecall
+from ignite.utils import to_onehot
+
+
+class MeanAveragePrecision(_BasePrecisionRecall):
+    _tp: Dict[int, List[torch.Tensor]]
+    _fp: Dict[int, List[torch.Tensor]]
+    _scores: Union[Dict[int, List[torch.Tensor]], List[torch.Tensor]]
+    _P: Union[Dict[int, int], List[torch.Tensor]]
+
+    def __init__(
+        self,
+        rec_thresholds: Optional[Union[Sequence[float], torch.Tensor]] = None,
+        average_operand: Optional[Literal["precision", "max-precision"]] = "precision",
+        class_mean: Optional[Literal["micro", "macro", "weighted", "with_other_dims"]] = "macro",
+        classification_is_multilabel: bool = False,
+        allow_multiple_recalls_at_single_threshold: bool = False,
+        output_transform: Callable = lambda x: x,
+        device: Union[str, torch.device] = torch.device("cpu"),
+    ) -> None:
+        r"""Calculate the mean average precision metric, i.e. the mean of the averaged-over-recall precision,
+        for detection and classification tasks.
+
+        Mean average precision attempts to give a measure of detector or classifier precision at various
+        sensitivity levels, a.k.a. recall thresholds. This is done by summing precisions at different recall
+        thresholds weighted by the change in recall, as if the area under the precision-recall curve were being
+        computed. Mean average precision is then computed by taking the mean of this average precision over
+        different classes and possibly some additional dimensions in the detection task.
+
+        For detection tasks, the user must subclass this metric and implement its :meth:`do_matching`
+        method to provide the metric with the desired matching logic. This method is then called internally by the
+        :meth:`update` method on prediction-target pairs. For classification, binary, multiclass and
+        multilabel data are all supported. In the latter case, ``classification_is_multilabel`` should be set to
+        True.
+
+        The `mean` in mean average precision refers to the mean of the average precision across classes.
+        ``class_mean`` determines how to take this mean. In detection tasks, it's also possible to take the mean
+        of the average precision over other dimensions, e.g. IoU thresholds in an object detection task. To this
+        end, the average precision corresponding to each IoU threshold value should be measured in
+        :meth:`do_matching`. Please refer to :meth:`do_matching` for more info on this.
+
+        Args:
+            rec_thresholds: recall thresholds (sensitivity levels) to be considered for computing mean average
+                precision. It could be a 1-dim tensor or a sequence of floats. Its values should be between 0 and 1
+                and don't need to be sorted. If missing, thresholds are determined automatically from the data.
+            average_operand: one of the values "precision" or "max-precision". In the former case, the precision at
+                a recall threshold is used for that threshold:
+
+                .. math::
+                    \text{Average Precision} = \sum_{k=1}^{\#rec\_thresholds} (r_k - r_{k-1}) P_k
+
+                :math:`r` stands for recall thresholds and :math:`P` for precision values. :math:`r_0` is set to zero.
+
+                In the latter case, the maximum precision across thresholds greater than or equal to a recall
+                threshold is considered as the summation operand; in other words, the precision peak across lower
+                or equal sensitivity levels is used for a recall threshold:
+
+                .. math::
+                    \text{Average Precision} = \sum_{k=1}^{\#rec\_thresholds} (r_k - r_{k-1}) \max(P_{k:})
+
+                Default is "precision".
+            class_mean: how to compute the mean of the average precision across classes or how to incorporate the
+                class dimension into computing precision. It's ignored in binary classification. Available options
+                are
+
+                None
+                  A 1-dimensional tensor of the average precision per class (averaged across additional mean
+                  dimensions) is returned. If there's no ground truth sample for a class, ``0`` is returned for it.
+
+                'micro'
+                  Precision is computed by counting the stats of all classes/labels together. This option
+                  incorporates the class into the precision measurement itself.
+
+                  .. math::
+                      \text{Micro P} = \frac{\sum_{c=1}^C TP_c}{\sum_{c=1}^C TP_c+FP_c}
+
+                  where :math:`C` is the number of classes/labels. :math:`c` in :math:`TP_c`
+                  and :math:`FP_c` means that the terms are computed for class/label :math:`c` (in a one-vs-rest
+                  sense in the multiclass case).
+
+                  For multiclass inputs, this is equivalent to mean average accuracy.
+
+                'weighted'
+                  Like macro but accounts for class/label imbalance. For multiclass input,
+                  it computes AP for each class, then returns their mean weighted by the
+                  support of the classes (number of actual samples in each class). For multilabel input,
+                  it computes AP for each label, then returns their mean weighted by the support
+                  of the labels (number of actual positive samples in each label).
+
+                'macro'
+                  Computes macro precision, which is the unweighted mean of the AP computed across classes/labels.
+                  Default.
+
+                'with_other_dims'
+                  The mean over the class dimension is taken together with the additional mean dimensions all at
+                  once, unlike 'macro' and 'weighted' in which the mean over the additional dimensions is taken
+                  beforehand. Only available in detection.
+
+                  Note:
+                      Classes with no ground truth are not included in the mean in detection.
+
+            classification_is_multilabel: Used in the classification task and determines whether the data
+                is multilabel or not. Default False.
+            allow_multiple_recalls_at_single_threshold: When there are predictions with the same score, it's faster
+                to treat those predictions as if they were associated with different thresholds while measuring
+                recall values, but it's not logically correct since those predictions are associated with a single
+                threshold, thus outputting a single recall value. This option is added mainly due to some downstream
+                mAP metrics which allow such a thing in their computation, e.g. pycocotools' mAP. Default False.
+            output_transform: a callable that is used to transform the
+                :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
+                form expected by the metric. This can be useful if, for example, you have a multi-output model and
+                you want to compute the metric with respect to one of the outputs. This metric requires the output
+                as ``(y_pred, y)``.
+            device: specifies which device updates are accumulated on. Setting the
+                metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
+                non-blocking. By default, CPU.
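+
+        Examples:
+            A minimal classification sketch; the shapes and tensors below are illustrative placeholders, not part
+            of the API contract:
+
+            .. code-block:: python
+
+                import torch
+                from ignite.metrics import MeanAveragePrecision
+
+                metric = MeanAveragePrecision(classification_is_multilabel=True)
+                y_pred = torch.rand(8, 3)            # per-label probabilities for 8 samples and 3 labels
+                y = torch.randint(0, 2, (8, 3))      # binary targets of the same shape
+                metric.update((y_pred, y))
+                result = metric.compute()            # macro mean of per-label AP by default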
+ """ + if rec_thresholds is not None: + self.rec_thresholds: Optional[torch.Tensor] = self._setup_thresholds(rec_thresholds, "rec_thresholds") + else: + self.rec_thresholds = None + + if average_operand not in ("precision", "max-precision"): + raise ValueError(f"Wrong `average_operand` parameter, given {average_operand}") + self.average_operand = average_operand + + if class_mean is not None and class_mean not in ("micro", "macro", "weighted", "with_other_dims"): + raise ValueError(f"Wrong `class_mean` parameter, given {class_mean}") + self.class_mean = class_mean + + self.allow_multiple_recalls_at_single_threshold = allow_multiple_recalls_at_single_threshold + + super(_BasePrecisionRecall, self).__init__( + output_transform=output_transform, is_multilabel=classification_is_multilabel, device=device + ) + + if self._task == "classification" and self.class_mean == "with_other_dims": + raise ValueError("class_mean 'with_other_dims' is not compatible with classification.") + + def _setup_thresholds(self, thresholds: Union[Sequence[float], torch.Tensor], threshold_type: str) -> torch.Tensor: + if isinstance(thresholds, Sequence): + thresholds = torch.tensor(thresholds) + + if isinstance(thresholds, torch.Tensor): + if thresholds.ndim != 1: + raise ValueError( + f"{threshold_type} should be a one-dimensional tensor or a sequence of floats" + f", given a {thresholds.ndim}-dimensional tensor." + ) + thresholds = thresholds.sort().values + else: + raise TypeError(f"{threshold_type} should be a sequence of floats or a tensor, given {type(thresholds)}.") + + if min(thresholds) < 0 or max(thresholds) > 1: + raise ValueError(f"{threshold_type} values should be between 0 and 1, given {thresholds}") + + return cast(torch.Tensor, thresholds) + + @reinit__is_reduced + def reset(self) -> None: + """ + Reset method of the metric + """ + super(_BasePrecisionRecall, self).reset() + if self.do_matching.__func__ == MeanAveragePrecision.do_matching: # type: ignore[attr-defined] + self._task: Literal["classification", "detection"] = "classification" + else: + self._task = "detection" + self._tp = defaultdict(lambda: []) + self._fp = defaultdict(lambda: []) + if self._task == "detection": + self._scores = defaultdict(lambda: []) + self._P = defaultdict(lambda: 0) + self._num_classes = 0 + else: + self._scores = [] + self._P = [] + + def _check_binary_multilabel_cases(self, output: Sequence[torch.Tensor]) -> None: + # Ignore the check in `_BaseClassification` since `y_pred` consists of probabilities here. 
+ _, y = output + if not torch.equal(y, y**2): + raise ValueError("For binary cases, y must be comprised of 0's and 1's.") + + def _check_type(self, output: Sequence[torch.Tensor]) -> None: + super(_BasePrecisionRecall, self)._check_type(output) + y_pred, y = output + if y_pred.dtype in (torch.int, torch.long): + raise TypeError(f"`y_pred` should be a float tensor, given {y_pred.dtype}") + if self._type == "multiclass" and y.dtype != torch.long: + warnings.warn("`y` should be of dtype long when entry type is multiclass", RuntimeWarning) + + def _check_matching_output_shape( + self, tps: Dict[int, torch.Tensor], fps: Dict[int, torch.Tensor], scores: Dict[int, torch.Tensor] + ) -> None: + if not (tps.keys() == fps.keys() == scores.keys()): + raise ValueError( + "Returned TP, FP and scores dictionaries from do_matching should have" + f" the same keys (classes), given {tps.keys()}, {fps.keys()} and {scores.keys()}" + ) + try: + cls = list(tps.keys()).pop() + except IndexError: # No prediction + pass + else: + if tps[cls].dtype not in (torch.bool, torch.uint8): + raise TypeError(f"Tensors in TP and FP dictionaries should be boolean or uint8, given {tps[cls].dtype}") + + if tps[cls].size(-1) != fps[cls].size(-1) != scores[cls].size(0): + raise ValueError( + "Sample dimension of tensors in TP, FP and scores should have equal size per class," + f"given {tps[cls].size(-1)}, {fps[cls].size(-1)} and {scores[cls].size(-1)} for class {cls}" + " respectively." + ) + for self_tp_or_fp, new_tp_or_fp, name in [(self._tp, tps, "TP"), (self._fp, fps, "FP")]: + new_tp_or_fp.keys() + try: + cls = (self_tp_or_fp.keys() & new_tp_or_fp.keys()).pop() + except KeyError: + pass + else: + if self_tp_or_fp[cls][-1].shape[:-1] != new_tp_or_fp[cls].shape[:-1]: + raise ValueError( + f"Tensors in returned {name} from do_matching should not change in shape " + "except possibly in the last dimension which is the dimension of samples. Given " + f"{self_tp_or_fp[cls][-1].shape} and {new_tp_or_fp[cls].shape}" + ) + + def _classification_prepare_output( + self, y_pred: torch.Tensor, y: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Prepares and returns scores and P tensor. Input and output shapes of the method is as follows. + + ========== =========== ============ + ``y_pred`` + ----------------------------------- + Data type Input shape Output shape + ========== =========== ============ + Binary (N, ...) (1, N * ...) + Multilabel (N, C, ...) (C, N * ...) + Multiclass (N, C, ...) (C, N * ...) + ========== =========== ============ + + ========== =========== ============ + ``y`` + ----------------------------------- + Data type Input shape Output shape + ========== =========== ============ + Binary (N, ...) (1, N * ...) + Multilabel (N, C, ...) (C, N * ...) + Multiclass (N, ...) (N * ...) + ========== =========== ============ + """ + + if self._type == "multilabel": + num_classes = y_pred.size(1) + scores = torch.transpose(y_pred, 1, 0).reshape(num_classes, -1) + P = torch.transpose(y, 1, 0).reshape(num_classes, -1) + elif self._type == "binary": + P = y.view(1, -1) + scores = y_pred.view(1, -1) + else: # Multiclass + num_classes = y_pred.size(1) + if y.max() + 1 > num_classes: + raise ValueError( + f"y_pred contains fewer classes than y. Number of classes in prediction is {num_classes}" + f" and an element in y has invalid class = {y.max().item() + 1}." 
+ ) + P = y.view(-1) + scores = torch.transpose(y_pred, 1, 0).reshape(num_classes, -1) + + return scores, P + + def do_matching( + self, pred: Any, target: Any + ) -> Tuple[Dict[int, torch.Tensor], Dict[int, torch.Tensor], Dict[int, int], Dict[int, torch.Tensor]]: + r""" + Matching logic holder of the metric for detection tasks. + + User must implement this method by subclassing the metric. There is no constraint on type and shape of + ``pred`` and ``target``, but the method should return a quadrople of dictionaries containing TP, FP, + P (actual positive) counts and scores for each class respectively. Please note that class numbers start from + zero. + + Values in TP and FP are (m+1)-dimensional tensors of type ``bool`` or ``uint8`` and shape + (D\ :sub:`1`, D\ :sub:`2`, ..., D\ :sub:`m`, n\ :sub:`cls`) in which D\ :sub:`i`\ 's are possible additional + dimensions (excluding the class dimension) mean of the average precision is taken over. n\ :sub:`cls` is the + number of predictions for class `cls` which is the same for TP and FP. + + Note: + TP and FP values are stored as uint8 tensors internally to avoid bool-to-uint8 copies before collective + operations, as PyTorch colective operations `do not `_ + support boolean tensors, at least on Gloo backend. + + + P counts contains the number of ground truth samples for each class. Finally, the values in scores are 1-dim + tensors of shape (n\ :sub:`cls`,) containing score or confidence of the predictions (doesn't need to be in + [0,1]). If there is no prediction or ground truth for a class, it could be absent from (TP, FP, scores) and P + dictionaries respectively. + + Args: + pred: First member of :meth:`update`'s input is given as this argument. There's no constraint on its type + and shape. + target: Second member of :meth:`update`'s input is given as this argument. There's no constraint on its type + and shape. + + Returns: + `(TP, FP, P, scores)` A quadrople of true positives, false positives, number of actual positives and scores. + """ + raise NotImplementedError( + "Please subclass MeanAveragePrecision and implement `do_matching` method" " to use the metric in detection." + ) + + @reinit__is_reduced + def update(self, output: Union[Tuple[Any, Any], Tuple[torch.Tensor, torch.Tensor]]) -> None: + """Metric update function using prediction and target. + + Args: + output: a binary tuple. It should consist of prediction and target tensors in the classification case but + for detection it is the same as the implemented-by-user :meth:`do_matching`. + + For classification, this metric follows the same rules on ``output`` members shape as the + :meth:`Precision.update ` except for ``y_pred`` of binary and multilabel + data which should be comprised of positive class probabilities here. 
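+
+                As an illustrative sketch (the shapes here are placeholders, not requirements), a multiclass
+                classification update could look like
+                ``metric.update((torch.rand(16, 4), torch.randint(0, 4, (16,))))``, while for detection both
+                members of ``output`` are forwarded as-is to :meth:`do_matching`.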
+ """ + + if self._task == "classification": + self._check_shape(output) + prediction, target = output[0].detach(), output[1].detach() + self._check_type((prediction, target)) + scores, P = self._classification_prepare_output(prediction, target) + cast(List[torch.Tensor], self._scores).append(scores.to(self._device)) + cast(List[torch.Tensor], self._P).append( + P.to(self._device, dtype=torch.uint8 if self._type != "multiclass" else torch.long) + ) + else: + tps, fps, ps, scores_dict = self.do_matching(output[0], output[1]) + self._check_matching_output_shape(tps, fps, scores_dict) + for cls in tps: + self._tp[cls].append(tps[cls].to(device=self._device, dtype=torch.uint8)) + self._fp[cls].append(fps[cls].to(device=self._device, dtype=torch.uint8)) + cast(Dict[int, List[torch.Tensor]], self._scores)[cls].append(scores_dict[cls].to(self._device)) + for cls in ps: + cast(Dict[int, int], self._P)[cls] += ps[cls] + classes = tps.keys() | ps.keys() + if classes: + self._num_classes = max(max(classes) + 1, self._num_classes) + + def _measure_recall_and_precision( + self, TP: torch.Tensor, FP: Union[torch.Tensor, None], scores: torch.Tensor, P: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + r"""Measuring recall & precision which is the common operation among different settings of the metric. + + Shape of function inputs and return values follow the table below. C is the number of classes, 1 for binary + data. N\ :sub:`pred` is the number of detections or predictions which is the same as the number of samples in + classification task. ``...`` stands for the additional dimensions in the detection task. Finally, + \#unique scores represents number of unique scores in ``scores`` which is actually the number of thresholds. + + This method is called on a per class basis in the detection task and if + ``allow_multiple_recalls_at_single_threshold=False``. 
+ + =========================== ================================== =================================== + Detection task + -------------------------------------------------------------------------------------------------- + **Object**/ **Condition** ``allow_multiple_recalls...=True`` ``allow_multiple_recalls...=False`` + =========================== ================================== =================================== + TP and FP (..., N\ :sub:`pred`) (..., N\ :sub:`pred`) + scores (N\ :sub:`pred`,) (N\ :sub:`pred`,) + P () (A single float) () (A single float) + recall (..., N\ :sub:`pred`) (..., \#unique scores) + precision (..., N\ :sub:`pred`) (..., \#unique scores) + =========================== ===================== + + =========================== ================================== =================================== + Classification task + -------------------------------------------------------------------------------------------------- + **Object**/ **Condition** ``allow_multiple_recalls...=True`` ``allow_multiple_recalls...=False`` + =========================== ================================== =================================== + TP (C, N\ :sub:`pred`) (N\ :sub:`pred`,) + FP (C, N\ :sub:`pred`) None (FP is computed here to be + faster) + scores (C, N\ :sub:`pred`) (N\ :sub:`pred`,) + P (C,) () (A single float) + recall (C, N\ :sub:`pred`) (\#unique scores,) + precision (C, N\ :sub:`pred`) (\#unique scores,) + =========================== ================================== =================================== + + Returns: + `(recall, precision)` + """ + indices = torch.argsort(scores, dim=-1, stable=True, descending=True) + tp = TP.take_along_dim(indices, dim=-1) if self._task == "classification" else TP[..., indices] + tp_summation = tp.cumsum(dim=-1).double() + if self._task == "detection" or self.allow_multiple_recalls_at_single_threshold: + fp = ( + cast(torch.Tensor, FP).take_along_dim(indices, dim=-1) + if self._task == "classification" + else cast(torch.Tensor, FP)[..., indices] + ) + fp_summation = fp.cumsum(dim=-1).double() + if not self.allow_multiple_recalls_at_single_threshold: + # Adopted from Scikit-learn's implementation + unique_scores_indices = torch.nonzero( + scores.take_along_dim(indices).diff(append=(scores.max() + 1).unsqueeze(dim=0)), as_tuple=True + )[0] + tp_summation = tp_summation[..., unique_scores_indices] + if self._task == "classification": + fp_summation = (unique_scores_indices + 1) - tp_summation + else: + fp_summation = fp_summation[..., unique_scores_indices] + + if self._task == "classification" and self.allow_multiple_recalls_at_single_threshold: + recall = torch.where(P == 0, 1, tp_summation.T / P).T + elif self._task == "classification" and P == 0: + recall = torch.ones_like(tp_summation, device=self._device, dtype=torch.bool) + else: + recall = tp_summation / P + # precision = tp_summation / (fp_summation + tp_summation + torch.finfo(torch.double).eps) + # or + predicted_positive = tp_summation + fp_summation + precision = tp_summation / torch.where(predicted_positive == 0, 1, predicted_positive) + return recall, precision + + def _measure_average_precision(self, recall: torch.Tensor, precision: torch.Tensor) -> torch.Tensor: + """Measuring average precision which is the common operation among different settings of the metric. + + Args: + recall: n-dimensional tensor whose last dimension is the dimension of the samples. Should be ordered in + ascending order in its last dimension. + precision: like ``recall`` in the shape. 
+ + Returns: + average_precision: (n-1)-dimensional tensor containing the average precision for mean dimensions. + """ + precision_integrand = ( + precision.flip(-1).cummax(dim=-1).values.flip(-1) if self.average_operand == "max-precision" else precision + ) + if self.rec_thresholds is not None: + rec_thresholds = self.rec_thresholds.repeat((*recall.shape[:-1], 1)) + rec_thresh_indices = torch.searchsorted(recall, rec_thresholds) + precision_integrand = precision_integrand.take_along_dim( + rec_thresh_indices.where(rec_thresh_indices != recall.size(-1), 0), dim=-1 + ).where(rec_thresh_indices != recall.size(-1), 0) + recall = rec_thresholds + recall_differential = recall.diff( + dim=-1, prepend=torch.zeros((*recall.shape[:-1], 1), device=self._device, dtype=torch.double) + ) + return torch.sum(recall_differential * precision_integrand, dim=-1) + + def compute(self) -> Union[torch.Tensor, float]: + """ + Compute method of the metric + """ + num_classes = int(idist.all_reduce(self._num_classes or 0, "MAX")) + if not num_classes: + return 0.0 + + if self._task == "detection": + P = cast( + torch.Tensor, + idist.all_reduce(torch.tensor(list(map(self._P.__getitem__, range(num_classes))), device=self._device)), + ) + num_preds = torch.tensor( + [sum([tp.shape[-1] for tp in self._tp[cls]]) if self._tp[cls] else 0 for cls in range(num_classes)], + device=self._device, + ) + num_preds_per_class_across_ranks = torch.stack( + cast(torch.Tensor, idist.all_gather(num_preds)).split(split_size=num_classes) + ) + if num_preds_per_class_across_ranks.sum() == 0: + return ( + 0.0 + if self.class_mean is not None + else torch.zeros((num_classes,), dtype=torch.double, device=self._device) + ) + a_nonempty_rank, its_class_with_pred = list(zip(*torch.where(num_preds_per_class_across_ranks != 0))).pop(0) + a_nonempty_rank = a_nonempty_rank.item() + its_class_with_pred = its_class_with_pred.item() + mean_dimensions_shape = cast( + torch.Tensor, + idist.broadcast( + torch.tensor(self._tp[its_class_with_pred][-1].shape[:-1], device=self._device) + if idist.get_rank() == a_nonempty_rank + else None, + a_nonempty_rank, + safe_mode=True, + ), + ).tolist() + + if self.class_mean != "micro": + shapes_across_ranks = { + cls: [ + (*mean_dimensions_shape, num_pred_in_rank) + for num_pred_in_rank in num_preds_per_class_across_ranks[:, cls] + ] + for cls in range(num_classes) + } + TP = { + cls: torch.cat( + _all_gather_tensors_with_shapes( + torch.cat(self._tp[cls], dim=-1) + if self._tp[cls] + else torch.empty((*mean_dimensions_shape, 0), dtype=torch.uint8, device=self._device), + shapes_across_ranks[cls], + ), + dim=-1, + ) + for cls in range(num_classes) + } + FP = { + cls: torch.cat( + _all_gather_tensors_with_shapes( + torch.cat(self._fp[cls], dim=-1) + if self._fp[cls] + else torch.empty((*mean_dimensions_shape, 0), dtype=torch.uint8, device=self._device), + shapes_across_ranks[cls], + ), + dim=-1, + ) + for cls in range(num_classes) + } + scores = { + cls: torch.cat( + _all_gather_tensors_with_shapes( + torch.cat(cast(List[torch.Tensor], self._scores[cls])) + if self._scores[cls] + else torch.tensor([], dtype=torch.double, device=self._device), + num_preds_per_class_across_ranks[:, [cls]].tolist(), + ) + ) + for cls in range(num_classes) + } + + average_precisions = -torch.ones( + (num_classes, *(mean_dimensions_shape if self.class_mean == "with_other_dims" else ())), + device=self._device, + dtype=torch.double, + ) + for cls in range(num_classes): + if P[cls] == 0: + continue + if TP[cls].size(-1) == 0: + 
average_precisions[cls] = 0 + continue + recall, precision = self._measure_recall_and_precision(TP[cls], FP[cls], scores[cls], P[cls]) + average_precision_for_cls_across_other_dims = self._measure_average_precision(recall, precision) + if self.class_mean != "with_other_dims": + average_precisions[cls] = average_precision_for_cls_across_other_dims.mean() + else: + average_precisions[cls] = average_precision_for_cls_across_other_dims + if self.class_mean is None: + average_precisions[average_precisions == -1] = 0 + return average_precisions + elif self.class_mean == "weighted": + return torch.dot(P.double(), average_precisions) / P.sum() + else: + return average_precisions[average_precisions > -1].mean() + else: + num_preds_across_ranks = num_preds_per_class_across_ranks.sum(dim=1) + shapes_across_ranks_in_micro = [ + (*mean_dimensions_shape, num_preds_in_rank.item()) for num_preds_in_rank in num_preds_across_ranks + ] + TP_micro = torch.cat( + _all_gather_tensors_with_shapes( + torch.cat(list(itertools.chain(*map(self._tp.__getitem__, range(num_classes)))), dim=-1).to( + torch.uint8 + ) + if num_preds_across_ranks[idist.get_rank()] + else torch.empty((*mean_dimensions_shape, 0), dtype=torch.uint8, device=self._device), + shapes_across_ranks_in_micro, + ), + dim=-1, + ).bool() + FP_micro = torch.cat( + _all_gather_tensors_with_shapes( + torch.cat(list(itertools.chain(*map(self._fp.__getitem__, range(num_classes)))), dim=-1).to( + torch.uint8 + ) + if num_preds_across_ranks[idist.get_rank()] + else torch.empty((*mean_dimensions_shape, 0), dtype=torch.uint8, device=self._device), + shapes_across_ranks_in_micro, + ), + dim=-1, + ).bool() + scores_micro = torch.cat( + _all_gather_tensors_with_shapes( + torch.cat( + list( + itertools.chain( + *map( + cast(Dict[int, List[torch.Tensor]], self._scores).__getitem__, + range(num_classes), + ) + ) + ) + ) + if num_preds_across_ranks[idist.get_rank()] + else torch.tensor([], dtype=torch.double, device=self._device), + num_preds_across_ranks.unsqueeze(dim=-1).tolist(), + ) + ) + P = P.sum() + recall, precision = self._measure_recall_and_precision(TP_micro, FP_micro, scores_micro, P) + return self._measure_average_precision(recall, precision).mean() + else: + rank_P = ( + torch.cat(cast(List[torch.Tensor], self._P), dim=-1) + if self._P + else ( + torch.empty((num_classes, 0), dtype=torch.uint8, device=self._device) + if self._type == "multilabel" + else torch.tensor( + [], dtype=torch.long if self._type == "multiclass" else torch.uint8, device=self._device + ) + ) + ) + P = torch.cat(cast(List[torch.Tensor], idist.all_gather(rank_P, tensor_different_shape=True)), dim=-1) + scores_classification = torch.cat( + cast( + List[torch.Tensor], + idist.all_gather( + torch.cat(cast(List[torch.Tensor], self._scores), dim=-1) + if self._scores + else ( + torch.tensor([], device=self._device) + if self._type == "binary" + else torch.empty((num_classes, 0), dtype=torch.double, device=self._device) + ), + tensor_different_shape=True, + ), + ), + dim=-1, + ) + if self._type == "multiclass": + P = to_onehot(P, num_classes=self._num_classes).T + if self.class_mean == "micro": + P = P.reshape(1, -1) + scores_classification = scores_classification.view(1, -1) + P_count = P.sum(dim=-1) + if self.allow_multiple_recalls_at_single_threshold: + recall, precision = self._measure_recall_and_precision(P, 1 - P, scores_classification, P_count) + average_precisions = self._measure_average_precision(recall, precision) + else: + average_precisions = torch.zeros_like(P_count, 
device=self._device, dtype=torch.double) + for cls in range(len(P_count)): + recall, precision = self._measure_recall_and_precision( + P[cls], None, scores_classification[cls], P_count[cls] + ) + average_precisions[cls] = self._measure_average_precision(recall, precision) + if self._type == "binary": + return average_precisions.item() + if self.class_mean is None: + return average_precisions + elif self.class_mean == "weighted": + return torch.sum(P_count * average_precisions) / P_count.sum() + else: + return average_precisions.mean() diff --git a/ignite/metrics/vision/__init__.py b/ignite/metrics/vision/__init__.py new file mode 100644 index 00000000000..f351d5b339f --- /dev/null +++ b/ignite/metrics/vision/__init__.py @@ -0,0 +1,3 @@ +from ignite.metrics.vision.object_detection_map import ObjectDetectionMAP + +__all__ = ["ObjectDetectionMAP"] diff --git a/ignite/metrics/vision/object_detection_map.py b/ignite/metrics/vision/object_detection_map.py new file mode 100644 index 00000000000..35bc1ba5ee1 --- /dev/null +++ b/ignite/metrics/vision/object_detection_map.py @@ -0,0 +1,208 @@ +from typing import Callable, cast, Dict, Literal, Optional, Sequence, Tuple, Union + +import torch + +from ignite.metrics.mean_average_precision import MeanAveragePrecision + + +class ObjectDetectionMAP(MeanAveragePrecision): + def __init__( + self, + iou_thresholds: Optional[Union[Sequence[float], torch.Tensor]] = None, + flavor: Optional[Literal["COCO",]] = "COCO", + rec_thresholds: Optional[Union[Sequence[float], torch.Tensor]] = None, + output_transform: Callable = lambda x: x, + device: Union[str, torch.device] = torch.device("cpu"), + ) -> None: + r"""Calculate the mean average precision for evaluating an object detector. + + The input to metric's ``update`` method should be a binary tuple of str-to-tensor dictionaries, (y_pred, y), + which their items are as follows. N\ :sub:`det` and N\ :sub:`gt` are number of detections and ground truths + respectively. + + ======= ================== ================================================= + **y_pred items** + ------------------------------------------------------------------------------ + Key Value shape Description + ======= ================== ================================================= + 'bbox' (N\ :sub:`det`, 4) Bounding boxes of form (x1, y1, x2, y2) + containing top left and bottom right coordinates. + 'score' (N\ :sub:`det`,) Confidence score of detections. + 'label' (N\ :sub:`det`,) Predicted category number of detections. + ======= ================== ================================================= + + ========= ================== ================================================= + **y items** + ------------------------------------------------------------------------------ + Key Value shape Description + ========= ================== ================================================= + 'bbox' (N\ :sub:`gt`, 4) Bounding boxes of form (x1, y1, x2, y2) + containing top left and bottom right coordinates. + 'label' (N\ :sub:`gt`,) Category number of ground truths. + 'iscrowd' (N\ :sub:`gt`,) Whether ground truth boxes are crowd ones or not. + ========= ================== ================================================= + + Args: + iou_thresholds: sequence of IoU thresholds to be considered for computing mean average precision. + Values should be between 0 and 1. If not given, it's determined by ``flavor`` argument. + flavor: string values so that metric computation recipe correspond to its respective flavor. 
For now, only + available option is 'COCO'. Default 'COCO'. + rec_thresholds: sequence of recall thresholds to be considered for computing mean average precision. + Values should be between 0 and 1. If not given, it's determined by ``flavor`` argument. + output_transform: a callable that is used to transform the + :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the + form expected by the metric. This can be useful if, for example, you have a multi-output model and + you want to compute the metric with respect to one of the outputs. + By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``. + device: specifies which device updates are accumulated on. Setting the + metric's device to be the same as your ``update`` arguments ensures the ``update`` method is + non-blocking. By default, CPU. + """ + try: + from torchvision.ops.boxes import _box_inter_union, box_area + + def box_iou(pred_boxes: torch.Tensor, gt_boxes: torch.Tensor, iscrowd: torch.BoolTensor) -> torch.Tensor: + inter, union = _box_inter_union(pred_boxes, gt_boxes) + union[:, iscrowd] = box_area(pred_boxes).reshape(-1, 1) + iou = inter / union + iou[iou.isnan()] = 0 + return iou + + self.box_iou = box_iou + except ImportError: + raise ModuleNotFoundError("This metric requires torchvision to be installed.") + + if flavor != "COCO": + raise ValueError(f"Currently, the only available flavor for ObjectDetectionMAP is 'COCO', given {flavor}") + self.flavor = flavor + + if iou_thresholds is None: + iou_thresholds = torch.linspace(0.5, 0.95, 10, dtype=torch.double) + + self.iou_thresholds = self._setup_thresholds(iou_thresholds, "iou_thresholds") + + if rec_thresholds is None: + rec_thresholds = torch.linspace(0, 1, 101, device=device, dtype=torch.double) + + super(ObjectDetectionMAP, self).__init__( + rec_thresholds=rec_thresholds, + average_operand="max-precision" if flavor == "COCO" else "precision", + class_mean="with_other_dims", + allow_multiple_recalls_at_single_threshold=flavor == "COCO", + output_transform=output_transform, + device=device, + ) + + def _measure_average_precision(self, recall: torch.Tensor, precision: torch.Tensor) -> torch.Tensor: + """Measuring average precision. + This method is overriden since :math:`1/#recall_thresholds` is used instead of :math:`r_k - r_{k-1}` + as the recall differential in COCO flavor. + + Args: + recall: n-dimensional tensor whose last dimension is the dimension of the samples. Should be ordered in + ascending order in its last dimension. + precision: like ``recall`` in the shape. + + Returns: + average_precision: (n-1)-dimensional tensor containing the average precision for mean dimensions. 
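+
+        Spelled out, a sketch of the COCO-style summation implemented here, with the precision integrand taken as
+        the running maximum of precision over recalls greater than or equal to each threshold, is
+
+        .. math::
+            AP = \sum_{k=1}^{\#rec\_thresholds} \max_{r \geq r_k} P(r) \; / \; \#rec\_thresholds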
+ """ + if self.flavor != "COCO": + return super()._measure_average_precision(recall, precision) + + precision_integrand = ( + precision.flip(-1).cummax(dim=-1).values.flip(-1) if self.average_operand == "max-precision" else precision + ) + rec_thresholds = cast(torch.Tensor, self.rec_thresholds).repeat((*recall.shape[:-1], 1)) + rec_thresh_indices = torch.searchsorted(recall, rec_thresholds) + precision_integrand = precision_integrand.take_along_dim( + rec_thresh_indices.where(rec_thresh_indices != recall.size(-1), 0), dim=-1 + ).where(rec_thresh_indices != recall.size(-1), 0) + return torch.sum(precision_integrand, dim=-1) / len(cast(torch.Tensor, self.rec_thresholds)) + + def compute(self) -> Union[torch.Tensor, float]: + if not sum(cast(Dict[int, int], self._P).values()) and self.flavor == "COCO": + return -1 + return super().compute() + + def do_matching( + self, pred: Dict[str, torch.Tensor], target: Dict[str, torch.Tensor] + ) -> Tuple[Dict[int, torch.Tensor], Dict[int, torch.Tensor], Dict[int, int], Dict[int, torch.Tensor]]: + """ + Matching logic of object detection mAP. + """ + labels = target["labels"].detach() + pred_labels = pred["labels"].detach() + pred_scores = pred["scores"].detach() + categories = list(set(labels.int().tolist() + pred_labels.int().tolist())) + + pred_boxes = pred["bbox"] + gt_boxes = target["bbox"] + + is_crowd = target["iscrowd"] + + tp: Dict[int, torch.Tensor] = {} + fp: Dict[int, torch.Tensor] = {} + P: Dict[int, int] = {} + scores: Dict[int, torch.Tensor] = {} + + for category in categories: + class_index_gt = labels == category + num_category_gt = class_index_gt.sum() + category_is_crowd = is_crowd[class_index_gt] + if num_category_gt: + P[category] = num_category_gt - category_is_crowd.sum() + + class_index_dt = pred_labels == category + if not class_index_dt.any(): + continue + + scores[category] = pred_scores[class_index_dt] + + category_tp = torch.zeros( + (len(self.iou_thresholds), class_index_dt.sum().item()), dtype=torch.uint8, device=self._device + ) + category_fp = torch.zeros( + (len(self.iou_thresholds), class_index_dt.sum().item()), dtype=torch.uint8, device=self._device + ) + if num_category_gt: + class_iou = self.box_iou( + pred_boxes[class_index_dt], + gt_boxes[class_index_gt], + cast(torch.BoolTensor, category_is_crowd.bool()), + ) + class_maximum_iou = class_iou.max() + category_pred_idx_sorted_by_decreasing_score = torch.argsort( + pred_scores[class_index_dt], stable=True, descending=True + ).tolist() + for thres_idx, iou_thres in enumerate(self.iou_thresholds): + if iou_thres <= class_maximum_iou: + matched_gt_indices = set() + for pred_idx in category_pred_idx_sorted_by_decreasing_score: + match_iou, match_idx = min(iou_thres, 1 - 1e-10), -1 + for gt_idx in range(num_category_gt): + if (class_iou[pred_idx][gt_idx] < iou_thres) or ( + gt_idx in matched_gt_indices and torch.logical_not(category_is_crowd[gt_idx]) + ): + continue + if match_idx == -1 or ( + class_iou[pred_idx][gt_idx] >= match_iou + and torch.logical_or( + torch.logical_not(category_is_crowd[gt_idx]), category_is_crowd[match_idx] + ) + ): + match_iou = class_iou[pred_idx][gt_idx] + match_idx = gt_idx + if match_idx != -1: + matched_gt_indices.add(match_idx) + category_tp[thres_idx][pred_idx] = torch.logical_not(category_is_crowd[match_idx]) + else: + category_fp[thres_idx][pred_idx] = 1 + else: + category_fp[thres_idx] = 1 + else: + category_fp[:, :] = 1 + + tp[category] = category_tp + fp[category] = category_fp + + return tp, fp, P, scores diff --git 
a/requirements-dev.txt b/requirements-dev.txt index 182a4057bc1..6b92f894f65 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -32,3 +32,4 @@ nltk pandas gymnasium mkl +pycocotools diff --git a/tests/ignite/distributed/utils/__init__.py b/tests/ignite/distributed/utils/__init__.py index 798cdd31183..d2a9f177d9f 100644 --- a/tests/ignite/distributed/utils/__init__.py +++ b/tests/ignite/distributed/utils/__init__.py @@ -155,17 +155,19 @@ def _test_distrib_all_reduce_group(device): def _test_distrib_all_gather(device): + rank = idist.get_rank() + res = torch.tensor(idist.all_gather(10), device=device) true_res = torch.tensor([10] * idist.get_world_size(), device=device) assert (res == true_res).all() - t = torch.tensor(idist.get_rank(), device=device) + t = torch.tensor(rank, device=device) res = idist.all_gather(t) true_res = torch.tensor([i for i in range(idist.get_world_size())], device=device) assert (res == true_res).all() x = "test-test" - if idist.get_rank() == 0: + if rank == 0: x = "abc" res = idist.all_gather(x) true_res = ["abc"] + ["test-test"] * (idist.get_world_size() - 1) @@ -173,14 +175,14 @@ def _test_distrib_all_gather(device): base_x = "tests/ignite/distributed/utils/test_native.py" * 2000 x = base_x - if idist.get_rank() == 0: + if rank == 0: x = "abc" res = idist.all_gather(x) true_res = ["abc"] + [base_x] * (idist.get_world_size() - 1) assert res == true_res - t = torch.arange(100, device=device).reshape(4, 25) * (idist.get_rank() + 1) + t = torch.arange(100, device=device).reshape(4, 25) * (rank + 1) in_dtype = t.dtype res = idist.all_gather(t) assert res.shape == (idist.get_world_size() * 4, 25) @@ -190,6 +192,14 @@ def _test_distrib_all_gather(device): true_res[i * 4 : (i + 1) * 4, ...] = torch.arange(100, device=device).reshape(4, 25) * (i + 1) assert (res == true_res).all() + t = torch.full((rank + 1, (rank + 1) * 2, idist.get_world_size() - rank), rank) + in_dtype = t.dtype + res = idist.all_gather(t, tensor_different_shape=True) + assert res[rank].shape == (rank + 1, (rank + 1) * 2, idist.get_world_size() - rank) + assert type(res) == list and res[0].dtype == in_dtype + for i in range(idist.get_world_size()): + assert (res[i] == torch.full((i + 1, (i + 1) * 2, idist.get_world_size() - i), i)).all() + if idist.get_world_size() > 1: with pytest.raises(TypeError, match=r"Unhandled input type"): idist.all_reduce([0, 1, 2]) @@ -218,7 +228,13 @@ def _test_distrib_all_gather_group(device): res = idist.all_gather(t, group=ranks) assert torch.equal(res, torch.tensor(ranks, device=device)) - ranks = "abc" + t = torch.tensor([rank], device=device) + if bnd not in ("horovod"): + res = idist.all_gather(t, group=ranks, tensor_different_shape=True) + if rank not in ranks: + assert res == [t] + else: + assert torch.equal(res[rank], torch.tensor(ranks, device=device)) if bnd in ("nccl", "gloo", "mpi"): with pytest.raises(ValueError, match=r"Argument group should be list of int or ProcessGroup"): @@ -307,7 +323,7 @@ def _test_distrib_new_group(device): if rank in ranks: assert g1.rank() == g2.rank() elif idist.has_xla_support and bnd in ("xla-tpu"): - assert idist.new_group(ranks) == [ranks] + assert idist.new_group(ranks) == ranks elif idist.has_hvd_support and bnd in ("horovod"): from horovod.common.process_sets import ProcessSet diff --git a/tests/ignite/metrics/test_mean_average_precision.py b/tests/ignite/metrics/test_mean_average_precision.py new file mode 100644 index 00000000000..4432d517a1c --- /dev/null +++ 
b/tests/ignite/metrics/test_mean_average_precision.py @@ -0,0 +1,500 @@ +from typing import Tuple + +import numpy as np +import pytest +import torch +from sklearn.metrics import average_precision_score, precision_recall_curve +from sklearn.utils.extmath import stable_cumsum + +from ignite import distributed as idist +from ignite.engine import Engine +from ignite.metrics import MeanAveragePrecision +from ignite.utils import manual_seed, to_onehot + +manual_seed(41) + + +def test_wrong_input(): + with pytest.raises(ValueError, match="rec_thresholds should be a one-dimensional tensor or a sequence of floats"): + MeanAveragePrecision(rec_thresholds=torch.zeros((2, 2))) + + with pytest.raises(TypeError, match="rec_thresholds should be a sequence of floats or a tensor"): + MeanAveragePrecision(rec_thresholds={0, 0.2, 0.4, 0.6, 0.8}) + + with pytest.raises(ValueError, match="Wrong `average_operand` parameter"): + MeanAveragePrecision(average_operand=1) + + with pytest.raises(ValueError, match="Wrong `class_mean` parameter"): + MeanAveragePrecision(class_mean="samples") + + with pytest.raises(ValueError, match="rec_thresholds values should be between 0 and 1"): + MeanAveragePrecision(rec_thresholds=(0.0, 0.5, 1.0, 1.5)) + + with pytest.raises(ValueError, match="class_mean 'with_other_dims' is not compatible with classification"): + MeanAveragePrecision(class_mean="with_other_dims") + + +def test_wrong_classification_input(): + metric = MeanAveragePrecision() + assert metric._task == "classification" + + with pytest.raises(TypeError, match="`y_pred` should be a float tensor"): + metric.update((torch.tensor([0, 1, 0]), torch.tensor([1, 0, 1]))) + + metric = MeanAveragePrecision() + with pytest.warns(RuntimeWarning, match="`y` should be of dtype long when entry type is multiclass"): + metric.update((torch.tensor([[0.5, 0.4, 0.1]]), torch.tensor([2.0]))) + + with pytest.raises(ValueError, match="y_pred contains fewer classes than y"): + metric.update((torch.tensor([[0.5, 0.4, 0.1]]), torch.tensor([3]))) + + +class Dummy_mAP(MeanAveragePrecision): + def do_matching(self, pred: Tuple, target: Tuple): + return *pred, *target + + +def test_wrong_do_matching(): + metric = MeanAveragePrecision() + with pytest.raises(NotImplementedError, match="Please subclass MeanAveragePrecision and implement"): + metric.do_matching(None, None) + + metric = Dummy_mAP() + + with pytest.raises(ValueError, match="Returned TP, FP and scores dictionaries from do_matching should have"): + metric.update( + ( + ({1: torch.tensor([True])}, {1: torch.tensor([False])}), + ({1: 1}, {1: torch.tensor([0.8]), 2: torch.tensor([0.9])}), + ) + ) + + with pytest.raises(TypeError, match="Tensors in TP and FP dictionaries should be boolean or uint8"): + metric.update((({1: torch.tensor([1])}, {1: torch.tensor([False])}), ({1: 1}, {1: torch.tensor([0.8])}))) + + with pytest.raises( + ValueError, match="Sample dimension of tensors in TP, FP and scores should have equal size per class" + ): + metric.update( + (({1: torch.tensor([True])}, {1: torch.tensor([False, False])}), ({1: 1}, {1: torch.tensor([0.8])})) + ) + + metric.update((({1: torch.tensor([True])}, {1: torch.tensor([False])}), ({1: 1}, {1: torch.tensor([0.8])}))) + with pytest.raises(ValueError, match="Tensors in returned FP from do_matching should not change in shape except"): + metric.update( + ( + ({1: torch.tensor([False, True])}, {1: torch.tensor([[True, False], [False, False]])}), + ({1: 1}, {1: torch.tensor([0.8, 0.9])}), + ) + ) + + +def test__classification_prepare_output(): + 
metric = MeanAveragePrecision() + + metric._type = "binary" + scores, y = metric._classification_prepare_output( + torch.rand((5, 4, 3, 2)), torch.randint(0, 2, (5, 4, 3, 2)).bool() + ) + assert scores.shape == y.shape == (1, 120) + + metric._type = "multiclass" + scores, y = metric._classification_prepare_output(torch.rand((5, 4, 3, 2)), torch.randint(0, 4, (5, 3, 2))) + assert scores.shape == (4, 30) and y.shape == (30,) + + metric._type = "multilabel" + scores, y = metric._classification_prepare_output( + torch.rand((5, 4, 3, 2)), torch.randint(0, 2, (5, 4, 3, 2)).bool() + ) + assert scores.shape == y.shape == (4, 30) + + +def test_update(): + metric = MeanAveragePrecision() + assert len(metric._scores) == len(metric._P) == 0 + metric.update((torch.rand((5, 4)), torch.randint(0, 2, (5, 4)).bool())) + assert len(metric._scores) == len(metric._P) == 1 + + metric = Dummy_mAP() + assert len(metric._tp) == len(metric._fp) == len(metric._scores) == len(metric._P) == metric._num_classes == 0 + + metric.update((({1: torch.tensor([True])}, {1: torch.tensor([False])}), ({1: 1, 2: 1}, {1: torch.tensor([0.8])}))) + assert len(metric._tp[1]) == len(metric._fp[1]) == len(metric._scores[1]) == 1 + assert len(metric._P) == 2 and metric._P[2] == 1 + assert metric._num_classes == 3 + + metric.update((({}, {}), ({2: 2}, {}))) + assert metric._P[2] == 3 + + +def sklearn_precision_recall_curve_allowing_multiple_recalls_at_single_threshold(y_true, y_score): + y_true = y_true == 1 + + desc_score_indices = np.argsort(y_score, kind="mergesort")[::-1] + y_true = y_true[desc_score_indices] + weight = 1.0 + + tps = stable_cumsum(y_true * weight) + fps = stable_cumsum((1 - y_true) * weight) + ps = tps + fps + precision = np.zeros_like(tps) + np.divide(tps, ps, out=precision, where=(ps != 0)) + if tps[-1] == 0: + recall = np.ones_like(tps) + else: + recall = tps / tps[-1] + + sl = slice(None, None, -1) + return np.hstack((precision[sl], 1)), np.hstack((recall[sl], 0)), None + + +@pytest.mark.parametrize( + "allow_multiple_recalls_at_single_threshold, sklearn_pr_rec_curve", + [ + (False, precision_recall_curve), + (True, sklearn_precision_recall_curve_allowing_multiple_recalls_at_single_threshold), + ], +) +def test__measure_recall_and_precision(allow_multiple_recalls_at_single_threshold, sklearn_pr_rec_curve): + # Classification + m = MeanAveragePrecision(allow_multiple_recalls_at_single_threshold=allow_multiple_recalls_at_single_threshold) + + scores = torch.rand((50,)) + y_true = torch.randint(0, 2, (50,)).bool() + precision, recall, _ = sklearn_pr_rec_curve(y_true.numpy(), scores.numpy()) + if allow_multiple_recalls_at_single_threshold: + y_true = y_true.unsqueeze(0) + scores = scores.unsqueeze(0) + FP = ~y_true if allow_multiple_recalls_at_single_threshold else None + P = y_true.sum(dim=-1) + ignite_recall, ignite_precision = m._measure_recall_and_precision(y_true, FP, scores, P) + assert (ignite_recall.squeeze().flip(0).numpy() == recall[:-1]).all() + assert (ignite_precision.squeeze().flip(0).numpy() == precision[:-1]).all() + + # Classification, when there's no actual positive. Numpy expectedly raises warning. 
+ scores = torch.rand((50,)) + y_true = torch.zeros((50,)).bool() + precision, recall, _ = sklearn_pr_rec_curve(y_true.numpy(), scores.numpy()) + if allow_multiple_recalls_at_single_threshold: + y_true = y_true.unsqueeze(0) + scores = scores.unsqueeze(0) + FP = ~y_true if allow_multiple_recalls_at_single_threshold else None + P = torch.tensor([0]) if allow_multiple_recalls_at_single_threshold else torch.tensor(0) + ignite_recall, ignite_precision = m._measure_recall_and_precision(y_true, FP, scores, P) + assert (ignite_recall.flip(0).numpy() == recall[:-1]).all() + assert (ignite_precision.flip(0).numpy() == precision[:-1]).all() + + # Detection, in the case detector detects all gt objects but also produces some wrong predictions. + scores = torch.rand((50,)) + y_true = torch.randint(0, 2, (50,)) + m = Dummy_mAP(allow_multiple_recalls_at_single_threshold=allow_multiple_recalls_at_single_threshold) + + ignite_recall, ignite_precision = m._measure_recall_and_precision( + y_true.bool(), ~(y_true.bool()), scores, y_true.sum() + ) + sklearn_precision, sklearn_recall, _ = sklearn_pr_rec_curve(y_true.numpy(), scores.numpy()) + assert (ignite_recall.flip(0).numpy() == sklearn_recall[:-1]).all() + assert (ignite_precision.flip(0).numpy() == sklearn_precision[:-1]).all() + + # Detection like above but with two additional mean dimensions. + scores = torch.rand((50,)) + y_true = torch.zeros((6, 8, 50)) + sklearn_precisions, sklearn_recalls = [], [] + for i in range(6): + for j in range(8): + y_true[i, j, np.random.choice(50, size=15, replace=False)] = 1 + precision, recall, _ = sklearn_pr_rec_curve(y_true[i, j].numpy(), scores.numpy()) + sklearn_precisions.append(precision[:-1]) + sklearn_recalls.append(recall[:-1]) + sklearn_precisions = np.array(sklearn_precisions).reshape(6, 8, -1) + sklearn_recalls = np.array(sklearn_recalls).reshape(6, 8, -1) + ignite_recall, ignite_precision = m._measure_recall_and_precision( + y_true.bool(), ~(y_true.bool()), scores, torch.tensor(15) + ) + assert (ignite_recall.flip(-1).numpy() == sklearn_recalls).all() + assert (ignite_precision.flip(-1).numpy() == sklearn_precisions).all() + + +def test__measure_average_precision(): + m = MeanAveragePrecision() + + # Binary data + scores = np.random.rand(50) + y_true = np.random.randint(0, 2, 50) + ap = average_precision_score(y_true, scores) + precision, recall, _ = precision_recall_curve(y_true, scores) + ignite_ap = m._measure_average_precision( + torch.from_numpy(recall[:-1]).flip(-1), torch.from_numpy(precision[:-1]).flip(-1) + ) + assert np.allclose(ignite_ap.item(), ap) + + # Multilabel data + scores = np.random.rand(50, 5) + y_true = np.random.randint(0, 2, (50, 5)) + ap = average_precision_score(y_true, scores, average=None) + ignite_ap = [] + for cls in range(scores.shape[1]): + precision, recall, _ = precision_recall_curve(y_true[:, cls], scores[:, cls]) + ignite_ap.append( + m._measure_average_precision( + torch.from_numpy(recall[:-1]).flip(-1), torch.from_numpy(precision[:-1]).flip(-1) + ).item() + ) + ignite_ap = np.array(ignite_ap) + assert np.allclose(ignite_ap, ap) + + +def test_compute_classification_binary_data(): + m = MeanAveragePrecision() + scores = torch.rand((130,)) + y_true = torch.randint(0, 2, (130,)) + + m.update((scores[:50], y_true[:50])) + m.update((scores[50:], y_true[50:])) + ignite_map = m.compute() + + map = average_precision_score(y_true.numpy(), scores.numpy()) + + assert np.allclose(ignite_map, map) + + +@pytest.mark.parametrize("class_mean", [None, "macro", "micro", "weighted"]) +def 
test_compute_classification_nonbinary_data(class_mean): + scores = torch.rand((130, 5, 2, 2)) + sklearn_scores = scores.transpose(1, -1).reshape(-1, 5).numpy() + + # Multiclass + m = MeanAveragePrecision(class_mean=class_mean) + y_true = torch.randint(0, 5, (130, 2, 2)) + m.update((scores[:50], y_true[:50])) + m.update((scores[50:], y_true[50:])) + ignite_map = m.compute().numpy() + + y_true = to_onehot(y_true, 5).transpose(1, -1).reshape(-1, 5).numpy() + sklearn_map = average_precision_score(y_true, sklearn_scores, average=class_mean) + + assert np.allclose(sklearn_map, ignite_map) + + # Multilabel + m = MeanAveragePrecision(classification_is_multilabel=True, class_mean=class_mean) + y_true = torch.randint(0, 2, (130, 5, 2, 2)).bool() + m.update((scores[:50], y_true[:50])) + m.update((scores[50:], y_true[50:])) + ignite_map = m.compute().numpy() + + y_true = y_true.transpose(1, -1).reshape(-1, 5).numpy() + sklearn_map = average_precision_score(y_true, sklearn_scores, average=class_mean) + + assert np.allclose(sklearn_map, ignite_map) + + +@pytest.mark.parametrize("class_mean", ["macro", None, "micro", "weighted", "with_other_dims"]) +def test_compute_detection(class_mean): + m = Dummy_mAP(class_mean=class_mean) + + # The case in which, detector detects all gt objects but also produces some wrong predictions. Also classes + # have the same number of predictions. + + y_true = torch.randint(0, 2, (40, 5)) + scores = torch.rand((40, 5)) + + for s in [slice(20), slice(20, 40)]: + tp = {c: y_true[s, c].bool() for c in range(5)} + fp = {c: ~(y_true[s, c].bool()) for c in range(5)} + p = dict(enumerate(y_true[s].sum(dim=0).tolist())) + score = {c: scores[s, c] for c in range(5)} + m.update(((tp, fp), (p, score))) + + ignite_map = m.compute().numpy() + + sklearn_class_mean = class_mean if class_mean != "with_other_dims" else "macro" + sklearn_map = average_precision_score(y_true.numpy(), scores.numpy(), average=sklearn_class_mean) + assert np.allclose(sklearn_map, ignite_map) + + # Like above but with two additional mean dimensions. 
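+    # The leading (6, 8) dimensions emulate additional averaging dimensions (e.g. IoU thresholds
+    # in a detection setting): the reference APs are computed per (6, 8) slice below and then
+    # reduced according to `class_mean`.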
+ m.reset() + y_true = torch.zeros((5, 6, 8, 50)) + scores = torch.rand((50, 5)) + P_counts = np.random.choice(50, size=5) + sklearn_aps = [] + for c in range(5): + for i in range(6): + for j in range(8): + y_true[c, i, j, np.random.choice(50, size=P_counts[c], replace=False)] = 1 + if class_mean != "micro": + sklearn_aps.append( + average_precision_score( + y_true[c].view(6 * 8, 50).T.numpy(), scores[:, c].repeat(6 * 8, 1).T.numpy(), average=None + ) + ) + if class_mean == "micro": + sklearn_aps = average_precision_score( + torch.cat(y_true.unbind(0), dim=-1).view(6 * 8, 5 * 50).T.numpy(), + scores.T.reshape(5 * 50).repeat(6 * 8, 1).T.numpy(), + average=None, + ) + sklearn_aps = np.array(sklearn_aps) + if class_mean in (None, "micro"): + sklearn_map = sklearn_aps.mean(axis=-1) + elif class_mean == "macro": + sklearn_map = sklearn_aps.mean(axis=-1)[P_counts != 0].mean() + elif class_mean == "with_other_dims": + sklearn_map = sklearn_aps[P_counts != 0].mean() + else: + sklearn_map = np.dot(P_counts, sklearn_aps.mean(axis=-1)) / P_counts.sum() + + for s in [slice(0, 20), slice(20, 50)]: + tp = {c: y_true[c, :, :, s].bool() for c in range(5)} + fp = {c: ~(y_true[c, :, :, s].bool()) for c in range(5)} + p = dict(enumerate(y_true[:, 0, 0, s].sum(dim=-1).tolist())) + score = {c: scores[s, c] for c in range(5)} + m.update(((tp, fp), (p, score))) + ignite_map = m.compute().numpy() + assert np.allclose(ignite_map, sklearn_map) + + +@pytest.mark.parametrize("data_type", ["binary", "multiclass", "multilabel"]) +def test_distrib_integration_classification(distributed, data_type): + rank = idist.get_rank() + world_size = idist.get_world_size() + device = idist.device() + + def _test(metric_device): + def update(_, i): + return ( + y_preds[(2 * rank + i) * 10 : (2 * rank + i + 1) * 10], + y_true[(2 * rank + i) * 10 : (2 * rank + i + 1) * 10], + ) + + engine = Engine(update) + mAP = MeanAveragePrecision(classification_is_multilabel=data_type == "multilabel", device=metric_device) + mAP.attach(engine, "mAP") + + y_true_size = (10 * 2 * world_size, 3, 2) if data_type != "multilabel" else (10 * 2 * world_size, 4, 3, 2) + y_true = torch.randint(0, 4 if data_type == "multiclass" else 2, size=y_true_size).to(device) + y_preds_size = (10 * 2 * world_size, 4, 3, 2) if data_type != "binary" else (10 * 2 * world_size, 3, 2) + y_preds = torch.rand(y_preds_size).to(device) + + engine.run(range(2), max_epochs=1) + assert "mAP" in engine.state.metrics + + if data_type == "multiclass": + y_true = to_onehot(y_true, 4) + + if data_type == "binary": + y_true = y_true.view(-1) + y_preds = y_preds.view(-1) + else: + y_true = y_true.transpose(1, -1).reshape(-1, 4) + y_preds = y_preds.transpose(1, -1).reshape(-1, 4) + + sklearn_mAP = average_precision_score(y_true.numpy(), y_preds.numpy()) + assert np.allclose(sklearn_mAP, engine.state.metrics["mAP"]) + + metric_devices = [torch.device("cpu")] + if device.type != "xla": + metric_devices.append(idist.device()) + for metric_device in metric_devices: + _test(metric_device) + + +@pytest.mark.parametrize("class_mean", [None, "macro", "micro", "weighted", "with_other_dims"]) +def test_distrib_integration_detection(distributed, class_mean): + rank = idist.get_rank() + device = idist.device() + world_size = idist.get_world_size() + + def _test(metric_device): + def update(_, i): + y_true_batch = y_true[..., (2 * rank + i) * 10 : (2 * rank + i + 1) * 10] + scores_batch = scores[..., (2 * rank + i) * 10 : (2 * rank + i + 1) * 10] + return ( + ({c: y_true_batch[c].bool() for c in 
range(4)}, {c: ~(y_true_batch[c].bool()) for c in range(4)}), + ( + dict( + enumerate( + (y_true_batch[:, 0, 0] if y_true_batch.ndim == 4 else y_true_batch).sum(dim=-1).tolist() + ) + ), + {c: scores_batch[c] for c in range(4)}, + ), + ) + + engine = Engine(update) + # The case in which, detector detects all gt objects but also produces some wrong predictions. Also classes + # have the same number of predictions. + mAP = Dummy_mAP(device=metric_device, class_mean=class_mean) + mAP.attach(engine, "mAP") + + y_true = torch.randint(0, 2, size=(4, 10 * 2 * world_size)).to(device) + scores = torch.rand((4, 10 * 2 * world_size)).to(device) + engine.run(range(2), max_epochs=1) + assert "mAP" in engine.state.metrics + sklearn_class_mean = class_mean if class_mean != "with_other_dims" else "macro" + sklearn_map = average_precision_score(y_true.T.numpy(), scores.T.numpy(), average=sklearn_class_mean) + assert np.allclose(sklearn_map, engine.state.metrics["mAP"]) + + # Like above but with two additional mean dimensions. + y_true = torch.zeros((4, 6, 8, 10 * 2 * world_size)) + + P_counts = np.random.choice(10 * 2 * world_size, size=4) + sklearn_aps = [] + for c in range(4): + for i in range(6): + for j in range(8): + y_true[c, i, j, np.random.choice(10 * 2 * world_size, size=P_counts[c], replace=False)] = 1 + if class_mean != "micro": + sklearn_aps.append( + average_precision_score( + y_true[c].view(6 * 8, 10 * 2 * world_size).T.numpy(), + scores[c].repeat(6 * 8, 1).T.numpy(), + average=None, + ) + ) + if class_mean == "micro": + sklearn_aps = average_precision_score( + torch.cat(y_true.unbind(0), dim=-1).view(6 * 8, 4 * 10 * 2 * world_size).T.numpy(), + scores.reshape(4 * 10 * 2 * world_size).repeat(6 * 8, 1).T.numpy(), + average=None, + ) + sklearn_aps = np.array(sklearn_aps) + if class_mean in (None, "micro"): + sklearn_map = sklearn_aps.mean(axis=-1) + elif class_mean == "macro": + sklearn_map = sklearn_aps.mean(axis=-1)[P_counts != 0].mean() + elif class_mean == "with_other_dims": + sklearn_map = sklearn_aps[P_counts != 0].mean() + else: + sklearn_map = np.dot(P_counts, sklearn_aps.mean(axis=-1)) / P_counts.sum() + + engine.run(range(2), max_epochs=1) + + assert np.allclose(sklearn_map, engine.state.metrics["mAP"]) + + metric_devices = [torch.device("cpu")] + if device.type != "xla": + metric_devices.append(idist.device()) + for metric_device in metric_devices: + _test(metric_device) + + +# class MatchFirstDetectionFirst_mAP(MeanAveragePrecision): +# def do_matching(self, pred: Tuple[Sequence[int], Sequence[float]] , target: Sequence[int]): +# P = dict(Counter(target)) +# tp = defaultdict(lambda: []) +# scores = defaultdict(lambda: []) + +# target = torch.tensor(target) +# matched = torch.zeros((len(target),)).bool() +# for label, score in zip(*pred): +# try: +# matched[torch.logical_and(target == label, ~matched).tolist().index(True)] = True +# tp[label].append(True) +# except ValueError: +# tp[label].append(False) +# scores[label].append(score) + +# tp = {label: torch.tensor(_tp) for label, _tp in tp.items()} +# fp = {label: ~_tp for label, _tp in tp.items()} +# scores = {label: torch.tensor(_scores) for label, _scores in scores.items()} +# return tp, fp, P, scores diff --git a/tests/ignite/metrics/vision/__init__.py b/tests/ignite/metrics/vision/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/ignite/metrics/vision/test_object_detection_map.py b/tests/ignite/metrics/vision/test_object_detection_map.py new file mode 100644 index 00000000000..ae0b75da9a5 --- 
/dev/null +++ b/tests/ignite/metrics/vision/test_object_detection_map.py @@ -0,0 +1,884 @@ +import sys +from collections import namedtuple +from math import ceil +from typing import Dict, List, Tuple +from unittest.mock import patch + +import numpy as np + +np.float = float + +import pytest +import torch +from pycocotools.coco import COCO +from pycocotools.cocoeval import COCOeval +from torch.distributions.geometric import Geometric + +import ignite.distributed as idist +from ignite.engine import Engine +from ignite.metrics import ObjectDetectionMAP +from ignite.utils import manual_seed + +torch.set_printoptions(linewidth=200) +manual_seed(12) +np.set_printoptions(linewidth=200) + + +def coco_val2017_sample() -> Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, torch.Tensor]]]: + """ + Predictions are done using torchvision's `fasterrcnn_resnet50_fpn_v2` + with the following snippet. Note that beforehand, COCO images and annotations + were downloaded and unzipped into "val2017" and "annotations" folders respectively. + + .. code-block:: python + + import torch + from torchvision.models import detection as dtv + from torchvision.datasets import CocoDetection + + coco = CocoDetection( + "val2017", + "annotations/instances_val2017.json", + transform=dtv.FasterRCNN_ResNet50_FPN_V2_Weights.COCO_V1.transforms() + ) + model = dtv.fasterrcnn_resnet50_fpn_v2( + weights=dtv.FasterRCNN_ResNet50_FPN_V2_Weights.COCO_V1 + ) + model.eval() + + sample = torch.randint(len(coco), (10,)).tolist() + with torch.no_grad(): + pred = model([coco[s][0] for s in sample]) + + pred = [torch.cat(( + p['boxes'].int(), + p['scores'].reshape(-1, 1), + p['labels'].reshape(-1,1) + ), dim=1) for p in pred] + """ + gt = [ + torch.tensor( + [ + [418.1300, 261.8500, 511.4100, 311.0500, 23.0000, 0.0000], + [269.9200, 256.9700, 311.1200, 283.3000, 23.0000, 0.0000], + [175.1900, 252.0200, 209.2000, 278.7200, 23.0000, 0.0000], + ] + ), + torch.tensor([[196.7700, 199.6900, 301.8300, 373.2700, 2.0000, 0.0000]]), + torch.tensor( + [ + [88.4500, 168.2700, 465.9800, 318.2000, 6.0000, 0.0000], + [2.2900, 224.2100, 31.3000, 282.2200, 1.0000, 0.0000], + [143.9700, 240.8300, 176.8100, 268.6500, 84.0000, 0.0000], + [234.6200, 202.0900, 264.6500, 244.5800, 84.0000, 0.0000], + [303.0400, 193.1300, 339.9900, 242.5600, 84.0000, 0.0000], + [358.0100, 195.3100, 408.9800, 261.2500, 84.0000, 0.0000], + [197.2100, 236.3500, 233.7700, 275.1300, 84.0000, 0.0000], + ] + ), + torch.tensor( + [ + [26.1900, 44.8000, 435.5700, 268.0900, 50.0000, 0.0000], + [190.5400, 244.9400, 322.0100, 309.6500, 53.0000, 0.0000], + [72.8900, 27.2800, 591.3700, 417.8600, 51.0000, 0.0000], + [22.5400, 402.4500, 253.7400, 596.7200, 51.0000, 0.0000], + [441.9100, 168.5000, 494.3700, 213.0100, 53.0000, 0.0000], + [340.0700, 330.6800, 464.6200, 373.5800, 53.0000, 0.0000], + [28.6100, 247.9800, 591.3400, 597.6900, 67.0000, 0.0000], + [313.9300, 108.4000, 441.3300, 234.2600, 57.0000, 0.0000], + ] + ), + torch.tensor([[0.9600, 74.6700, 498.7600, 261.3400, 5.0000, 0.0000]]), + torch.tensor( + [ + [1.3800, 2.7500, 361.9800, 499.6100, 17.0000, 0.0000], + [325.4600, 239.8200, 454.2700, 416.8700, 47.0000, 0.0000], + [165.3900, 307.3400, 332.2200, 548.9500, 47.0000, 0.0000], + [424.0600, 179.6600, 480.0000, 348.0600, 47.0000, 0.0000], + ] + ), + torch.tensor( + [ + [218.2600, 84.2700, 473.6000, 230.0600, 28.0000, 0.0000], + [212.6300, 220.0000, 417.2200, 312.8100, 62.0000, 0.0000], + [104.4600, 311.2200, 222.4600, 375.0000, 62.0000, 0.0000], + ] + ), + torch.tensor( + [ + [144.4400, 
132.2200, 639.7700, 253.8800, 5.0000, 0.0000], + [0.0000, 48.4300, 492.8900, 188.0200, 5.0000, 0.0000], + ] + ), + torch.tensor( + [ + [205.9900, 276.6200, 216.8500, 311.2300, 1.0000, 0.0000], + [378.9100, 65.4500, 476.1900, 185.3100, 38.0000, 0.0000], + ] + ), + torch.tensor( + [ + [392.0900, 110.3800, 417.5200, 136.7000, 85.0000, 0.0000], + [594.9000, 223.4600, 640.0000, 473.6600, 1.0000, 0.0000], + [381.1400, 225.8200, 443.3100, 253.1200, 58.0000, 0.0000], + [175.2800, 253.8000, 221.0600, 281.7400, 60.0000, 0.0000], + [96.4100, 357.6800, 168.8800, 392.0300, 60.0000, 0.0000], + [126.6500, 258.3200, 178.3800, 285.8900, 60.0000, 0.0000], + [71.0200, 247.1500, 117.2300, 275.0500, 60.0000, 0.0000], + [34.2300, 265.8500, 82.4000, 296.0700, 60.0000, 0.0000], + [32.9700, 252.5500, 69.1700, 271.6400, 60.0000, 0.0000], + [50.7800, 229.3100, 90.7400, 248.5200, 60.0000, 0.0000], + [82.8900, 218.2800, 126.5900, 237.6700, 60.0000, 0.0000], + [80.6300, 263.9200, 128.1000, 290.4100, 60.0000, 0.0000], + [277.2200, 222.9500, 323.7000, 243.6300, 60.0000, 0.0000], + [225.5200, 155.3300, 272.0700, 167.2700, 60.0000, 0.0000], + [360.6200, 178.1900, 444.4800, 207.2900, 1.0000, 0.0000], + [171.6000, 149.8100, 219.9900, 169.7400, 60.0000, 0.0000], + [398.0100, 223.2000, 461.1000, 252.6500, 58.0000, 0.0000], + [1.0000, 123.0000, 603.0000, 402.0000, 60.0000, 1.0000], + ] + ), + ] + + pred = [ + torch.tensor( + [ + [418.0000, 260.0000, 513.0000, 310.0000, 0.9975, 23.0000], + [175.0000, 252.0000, 208.0000, 280.0000, 0.9631, 20.0000], + [269.0000, 256.0000, 310.0000, 292.0000, 0.6253, 20.0000], + [269.0000, 256.0000, 311.0000, 286.0000, 0.3523, 21.0000], + [269.0000, 256.0000, 311.0000, 288.0000, 0.2267, 16.0000], + [175.0000, 252.0000, 209.0000, 280.0000, 0.1528, 23.0000], + ] + ), + torch.tensor( + [ + [197.0000, 207.0000, 299.0000, 344.0000, 0.9573, 2.0000], + [0.0000, 0.0000, 371.0000, 477.0000, 0.8672, 7.0000], + [218.0000, 250.0000, 298.0000, 369.0000, 0.6232, 2.0000], + [303.0000, 70.0000, 374.0000, 471.0000, 0.2851, 82.0000], + [200.0000, 203.0000, 282.0000, 290.0000, 0.1519, 2.0000], + [206.0000, 169.0000, 275.0000, 236.0000, 0.0581, 72.0000], + ] + ), + torch.tensor( + [ + [3.0000, 224.0000, 31.0000, 280.0000, 0.9978, 1.0000], + [86.0000, 174.0000, 457.0000, 312.0000, 0.9917, 6.0000], + [90.0000, 176.0000, 460.0000, 313.0000, 0.2907, 8.0000], + [443.0000, 198.0000, 564.0000, 255.0000, 0.2400, 3.0000], + [452.0000, 197.0000, 510.0000, 220.0000, 0.1778, 8.0000], + [462.0000, 197.0000, 513.0000, 220.0000, 0.1750, 3.0000], + [489.0000, 196.0000, 522.0000, 212.0000, 0.1480, 3.0000], + [257.0000, 165.0000, 313.0000, 182.0000, 0.1294, 8.0000], + [438.0000, 195.0000, 556.0000, 256.0000, 0.1198, 8.0000], + [555.0000, 235.0000, 575.0000, 251.0000, 0.0831, 3.0000], + [486.0000, 54.0000, 504.0000, 63.0000, 0.0634, 34.0000], + [565.0000, 245.0000, 573.0000, 251.0000, 0.0605, 3.0000], + [257.0000, 165.0000, 313.0000, 183.0000, 0.0569, 6.0000], + [0.0000, 233.0000, 42.0000, 256.0000, 0.0526, 3.0000], + [26.0000, 237.0000, 42.0000, 250.0000, 0.0515, 28.0000], + ] + ), + torch.tensor( + [ + [62.0000, 25.0000, 596.0000, 406.0000, 0.9789, 51.0000], + [24.0000, 393.0000, 250.0000, 596.0000, 0.9496, 51.0000], + [25.0000, 154.0000, 593.0000, 595.0000, 0.8013, 67.0000], + [27.0000, 112.0000, 246.0000, 263.0000, 0.7848, 50.0000], + [186.0000, 242.0000, 362.0000, 359.0000, 0.7522, 54.0000], + [448.0000, 209.0000, 512.0000, 241.0000, 0.7446, 57.0000], + [197.0000, 188.0000, 297.0000, 235.0000, 0.6616, 57.0000], + [233.0000, 
210.0000, 297.0000, 229.0000, 0.5214, 57.0000], + [519.0000, 255.0000, 589.0000, 284.0000, 0.4844, 48.0000], + [316.0000, 59.0000, 360.0000, 132.0000, 0.4211, 52.0000], + [27.0000, 104.0000, 286.0000, 267.0000, 0.3727, 48.0000], + [191.0000, 89.0000, 344.0000, 256.0000, 0.3143, 57.0000], + [182.0000, 80.0000, 529.0000, 353.0000, 0.3046, 57.0000], + [417.0000, 266.0000, 488.0000, 333.0000, 0.2602, 52.0000], + [324.0000, 110.0000, 369.0000, 157.0000, 0.2550, 57.0000], + [314.0000, 61.0000, 361.0000, 135.0000, 0.2517, 53.0000], + [252.0000, 218.0000, 336.0000, 249.0000, 0.2486, 57.0000], + [191.0000, 241.0000, 342.0000, 354.0000, 0.2384, 53.0000], + [194.0000, 174.0000, 327.0000, 247.0000, 0.2121, 57.0000], + [229.0000, 200.0000, 302.0000, 233.0000, 0.2030, 57.0000], + [439.0000, 192.0000, 526.0000, 252.0000, 0.2004, 57.0000], + [203.0000, 144.0000, 523.0000, 357.0000, 0.1937, 52.0000], + [17.0000, 90.0000, 361.0000, 283.0000, 0.1875, 50.0000], + [15.0000, 14.0000, 598.0000, 272.0000, 0.1747, 67.0000], + [319.0000, 63.0000, 433.0000, 158.0000, 0.1621, 53.0000], + [319.0000, 62.0000, 434.0000, 157.0000, 0.1602, 52.0000], + [253.0000, 85.0000, 311.0000, 147.0000, 0.1562, 57.0000], + [14.0000, 25.0000, 214.0000, 211.0000, 0.1330, 67.0000], + [147.0000, 146.0000, 545.0000, 386.0000, 0.0867, 51.0000], + [324.0000, 174.0000, 455.0000, 292.0000, 0.0761, 52.0000], + [25.0000, 480.0000, 205.0000, 594.0000, 0.0727, 59.0000], + [166.0000, 0.0000, 603.0000, 32.0000, 0.0583, 84.0000], + [519.0000, 255.0000, 589.0000, 285.0000, 0.0578, 50.0000], + ] + ), + torch.tensor( + [ + [0.0000, 58.0000, 495.0000, 258.0000, 0.9917, 5.0000], + [199.0000, 291.0000, 212.0000, 299.0000, 0.5247, 37.0000], + [0.0000, 277.0000, 307.0000, 331.0000, 0.1169, 5.0000], + [0.0000, 284.0000, 302.0000, 308.0000, 0.0984, 5.0000], + [348.0000, 231.0000, 367.0000, 244.0000, 0.0621, 15.0000], + [349.0000, 229.0000, 367.0000, 244.0000, 0.0547, 8.0000], + ] + ), + torch.tensor( + [ + [1.0000, 9.0000, 365.0000, 506.0000, 0.9980, 17.0000], + [170.0000, 304.0000, 335.0000, 542.0000, 0.9867, 47.0000], + [422.0000, 179.0000, 480.0000, 351.0000, 0.9476, 47.0000], + [329.0000, 241.0000, 449.0000, 420.0000, 0.8503, 47.0000], + [0.0000, 352.0000, 141.0000, 635.0000, 0.4145, 74.0000], + [73.0000, 277.0000, 478.0000, 628.0000, 0.3859, 67.0000], + [329.0000, 183.0000, 373.0000, 286.0000, 0.3097, 47.0000], + [0.0000, 345.0000, 145.0000, 631.0000, 0.2359, 9.0000], + [1.0000, 341.0000, 147.0000, 632.0000, 0.2259, 70.0000], + [0.0000, 338.0000, 148.0000, 632.0000, 0.1669, 62.0000], + [339.0000, 154.0000, 410.0000, 248.0000, 0.1474, 47.0000], + [422.0000, 176.0000, 479.0000, 359.0000, 0.1422, 44.0000], + [0.0000, 349.0000, 148.0000, 636.0000, 0.1369, 42.0000], + [1.0000, 347.0000, 149.0000, 633.0000, 0.1118, 1.0000], + [324.0000, 238.0000, 455.0000, 423.0000, 0.0948, 86.0000], + [0.0000, 348.0000, 146.0000, 640.0000, 0.0885, 37.0000], + [0.0000, 342.0000, 140.0000, 626.0000, 0.0812, 81.0000], + [146.0000, 0.0000, 478.0000, 217.0000, 0.0812, 62.0000], + [75.0000, 102.0000, 357.0000, 553.0000, 0.0618, 64.0000], + [2.0000, 356.0000, 145.0000, 635.0000, 0.0608, 51.0000], + [0.0000, 337.0000, 149.0000, 637.0000, 0.0544, 3.0000], + ] + ), + torch.tensor( + [ + [212.0000, 219.0000, 418.0000, 312.0000, 0.9968, 62.0000], + [218.0000, 83.0000, 477.0000, 228.0000, 0.9902, 28.0000], + [113.0000, 221.0000, 476.0000, 368.0000, 0.3940, 62.0000], + [108.0000, 309.0000, 222.0000, 371.0000, 0.2972, 62.0000], + [199.0000, 124.0000, 206.0000, 130.0000, 0.2770, 16.0000], + 
[213.0000, 154.0000, 447.0000, 301.0000, 0.2698, 28.0000], + [122.0000, 297.0000, 492.0000, 371.0000, 0.2263, 62.0000], + [111.0000, 302.0000, 500.0000, 368.0000, 0.2115, 67.0000], + [319.0000, 220.0000, 424.0000, 307.0000, 0.1761, 62.0000], + [453.0000, 0.0000, 462.0000, 8.0000, 0.1390, 38.0000], + [107.0000, 309.0000, 222.0000, 371.0000, 0.1075, 15.0000], + [109.0000, 309.0000, 225.0000, 372.0000, 0.1028, 67.0000], + [137.0000, 301.0000, 499.0000, 371.0000, 0.0945, 61.0000], + [454.0000, 0.0000, 460.0000, 6.0000, 0.0891, 16.0000], + [162.0000, 102.0000, 167.0000, 105.0000, 0.0851, 16.0000], + [395.0000, 263.0000, 500.0000, 304.0000, 0.0813, 15.0000], + [107.0000, 298.0000, 491.0000, 373.0000, 0.0727, 9.0000], + [157.0000, 78.0000, 488.0000, 332.0000, 0.0573, 28.0000], + [110.0000, 282.0000, 500.0000, 369.0000, 0.0554, 15.0000], + [377.0000, 263.0000, 500.0000, 315.0000, 0.0527, 62.0000], + ] + ), + torch.tensor( + [ + [1.0000, 48.0000, 505.0000, 184.0000, 0.9939, 5.0000], + [152.0000, 60.0000, 633.0000, 255.0000, 0.9552, 5.0000], + [0.0000, 183.0000, 20.0000, 200.0000, 0.2347, 8.0000], + [0.0000, 185.0000, 7.0000, 202.0000, 0.1005, 8.0000], + [397.0000, 255.0000, 491.0000, 276.0000, 0.0781, 42.0000], + [0.0000, 186.0000, 7.0000, 202.0000, 0.0748, 3.0000], + [259.0000, 154.0000, 640.0000, 254.0000, 0.0630, 5.0000], + ] + ), + torch.tensor( + [ + [203.0000, 277.0000, 215.0000, 312.0000, 0.9953, 1.0000], + [380.0000, 70.0000, 475.0000, 183.0000, 0.9555, 38.0000], + [439.0000, 70.0000, 471.0000, 176.0000, 0.3617, 38.0000], + [379.0000, 143.0000, 390.0000, 158.0000, 0.2418, 38.0000], + [378.0000, 140.0000, 461.0000, 184.0000, 0.1672, 38.0000], + [226.0000, 252.0000, 230.0000, 255.0000, 0.0570, 16.0000], + ] + ), + torch.tensor( + [ + [597.0000, 216.0000, 639.0000, 475.0000, 0.9783, 1.0000], + [80.0000, 263.0000, 128.0000, 291.0000, 0.9571, 60.0000], + [126.0000, 258.0000, 178.0000, 286.0000, 0.9540, 60.0000], + [174.0000, 252.0000, 221.0000, 279.0000, 0.9434, 60.0000], + [248.0000, 323.0000, 300.0000, 354.0000, 0.9359, 60.0000], + [171.0000, 150.0000, 220.0000, 166.0000, 0.9347, 60.0000], + [121.0000, 151.0000, 173.0000, 168.0000, 0.9336, 60.0000], + [394.0000, 111.0000, 417.0000, 135.0000, 0.9256, 85.0000], + [300.0000, 327.0000, 362.0000, 358.0000, 0.9058, 60.0000], + [264.0000, 149.0000, 306.0000, 166.0000, 0.8948, 60.0000], + [306.0000, 150.0000, 350.0000, 165.0000, 0.8798, 60.0000], + [70.0000, 150.0000, 127.0000, 168.0000, 0.8697, 60.0000], + [110.0000, 138.0000, 153.0000, 156.0000, 0.8586, 60.0000], + [223.0000, 154.0000, 270.0000, 166.0000, 0.8576, 60.0000], + [541.0000, 81.0000, 602.0000, 153.0000, 0.8352, 79.0000], + [34.0000, 266.0000, 82.0000, 295.0000, 0.8326, 60.0000], + [444.0000, 302.0000, 484.0000, 325.0000, 0.7900, 60.0000], + [14.0000, 152.0000, 73.0000, 169.0000, 0.7792, 60.0000], + [115.0000, 247.0000, 157.0000, 268.0000, 0.7654, 60.0000], + [168.0000, 350.0000, 237.0000, 385.0000, 0.7241, 60.0000], + [197.0000, 319.0000, 249.0000, 351.0000, 0.7062, 60.0000], + [89.0000, 331.0000, 149.0000, 366.0000, 0.6970, 61.0000], + [66.0000, 143.0000, 109.0000, 153.0000, 0.6787, 60.0000], + [152.0000, 332.0000, 217.0000, 358.0000, 0.6739, 60.0000], + [99.0000, 355.0000, 169.0000, 395.0000, 0.6582, 60.0000], + [583.0000, 205.0000, 594.0000, 218.0000, 0.6428, 47.0000], + [498.0000, 301.0000, 528.0000, 321.0000, 0.6373, 60.0000], + [255.0000, 146.0000, 274.0000, 155.0000, 0.6366, 60.0000], + [148.0000, 231.0000, 192.0000, 250.0000, 0.5984, 60.0000], + [501.0000, 140.0000, 551.0000, 
164.0000, 0.5910, 60.0000], + [156.0000, 144.0000, 193.0000, 157.0000, 0.5910, 60.0000], + [381.0000, 225.0000, 444.0000, 254.0000, 0.5737, 60.0000], + [156.0000, 243.0000, 206.0000, 264.0000, 0.5675, 60.0000], + [229.0000, 302.0000, 280.0000, 331.0000, 0.5588, 60.0000], + [492.0000, 134.0000, 516.0000, 142.0000, 0.5492, 60.0000], + [346.0000, 150.0000, 383.0000, 165.0000, 0.5481, 60.0000], + [17.0000, 143.0000, 67.0000, 154.0000, 0.5254, 60.0000], + [283.0000, 308.0000, 330.0000, 334.0000, 0.5141, 60.0000], + [421.0000, 222.0000, 489.0000, 250.0000, 0.4983, 60.0000], + [0.0000, 107.0000, 51.0000, 134.0000, 0.4978, 78.0000], + [70.0000, 248.0000, 113.0000, 270.0000, 0.4884, 60.0000], + [215.0000, 147.0000, 262.0000, 164.0000, 0.4867, 60.0000], + [293.0000, 145.0000, 315.0000, 157.0000, 0.4841, 60.0000], + [523.0000, 272.0000, 548.0000, 288.0000, 0.4728, 60.0000], + [534.0000, 152.0000, 560.0000, 164.0000, 0.4644, 60.0000], + [516.0000, 294.0000, 546.0000, 314.0000, 0.4597, 60.0000], + [352.0000, 319.0000, 395.0000, 342.0000, 0.4364, 60.0000], + [106.0000, 234.0000, 149.0000, 255.0000, 0.4317, 60.0000], + [326.0000, 136.0000, 357.0000, 147.0000, 0.4281, 60.0000], + [135.0000, 132.0000, 166.0000, 145.0000, 0.4159, 60.0000], + [63.0000, 238.0000, 104.0000, 259.0000, 0.4136, 60.0000], + [472.0000, 221.0000, 527.0000, 246.0000, 0.4090, 60.0000], + [189.0000, 137.0000, 225.0000, 154.0000, 0.4018, 60.0000], + [135.0000, 311.0000, 195.0000, 337.0000, 0.3965, 60.0000], + [9.0000, 148.0000, 68.0000, 164.0000, 0.3915, 60.0000], + [366.0000, 232.0000, 408.0000, 257.0000, 0.3858, 60.0000], + [291.0000, 243.0000, 318.0000, 266.0000, 0.3838, 60.0000], + [494.0000, 276.0000, 524.0000, 300.0000, 0.3727, 60.0000], + [97.0000, 135.0000, 122.0000, 148.0000, 0.3717, 60.0000], + [467.0000, 289.0000, 499.0000, 309.0000, 0.3710, 60.0000], + [150.0000, 134.0000, 188.0000, 146.0000, 0.3705, 60.0000], + [427.0000, 290.0000, 463.0000, 314.0000, 0.3575, 60.0000], + [38.0000, 343.0000, 101.0000, 408.0000, 0.3540, 61.0000], + [76.0000, 313.0000, 128.0000, 343.0000, 0.3429, 61.0000], + [507.0000, 146.0000, 537.0000, 163.0000, 0.3420, 60.0000], + [451.0000, 268.0000, 478.0000, 282.0000, 0.3389, 60.0000], + [545.0000, 292.0000, 578.0000, 314.0000, 0.3252, 60.0000], + [350.0000, 309.0000, 393.0000, 336.0000, 0.3246, 60.0000], + [388.0000, 307.0000, 429.0000, 337.0000, 0.3240, 60.0000], + [34.0000, 253.0000, 67.0000, 270.0000, 0.3228, 60.0000], + [402.0000, 224.0000, 462.0000, 252.0000, 0.3177, 60.0000], + [160.0000, 131.0000, 191.0000, 142.0000, 0.3104, 60.0000], + [132.0000, 310.0000, 197.0000, 340.0000, 0.2923, 61.0000], + [481.0000, 84.0000, 543.0000, 140.0000, 0.2872, 79.0000], + [13.0000, 137.0000, 62.0000, 153.0000, 0.2859, 60.0000], + [98.0000, 355.0000, 171.0000, 395.0000, 0.2843, 61.0000], + [115.0000, 149.0000, 156.0000, 160.0000, 0.2774, 60.0000], + [65.0000, 137.0000, 101.0000, 148.0000, 0.2732, 60.0000], + [314.0000, 242.0000, 341.0000, 264.0000, 0.2714, 60.0000], + [455.0000, 237.0000, 486.0000, 251.0000, 0.2630, 60.0000], + [552.0000, 146.0000, 595.0000, 164.0000, 0.2553, 60.0000], + [50.0000, 133.0000, 78.0000, 145.0000, 0.2485, 60.0000], + [544.0000, 280.0000, 570.0000, 294.0000, 0.2459, 60.0000], + [40.0000, 144.0000, 66.0000, 154.0000, 0.2453, 60.0000], + [289.0000, 254.0000, 312.0000, 268.0000, 0.2374, 60.0000], + [266.0000, 140.0000, 292.0000, 149.0000, 0.2357, 60.0000], + [504.0000, 266.0000, 525.0000, 277.0000, 0.2281, 60.0000], + [304.0000, 285.0000, 346.0000, 309.0000, 0.2256, 60.0000], + [303.0000, 
222.0000, 341.0000, 238.0000, 0.2236, 60.0000], + [498.0000, 219.0000, 549.0000, 243.0000, 0.2168, 60.0000], + [89.0000, 333.0000, 144.0000, 352.0000, 0.2159, 61.0000], + [0.0000, 108.0000, 51.0000, 135.0000, 0.2076, 79.0000], + [303.0000, 220.0000, 329.0000, 231.0000, 0.2007, 60.0000], + [0.0000, 131.0000, 38.0000, 150.0000, 0.1967, 60.0000], + [364.0000, 137.0000, 401.0000, 165.0000, 0.1958, 60.0000], + [398.0000, 95.0000, 538.0000, 139.0000, 0.1868, 79.0000], + [334.0000, 243.0000, 357.0000, 263.0000, 0.1835, 60.0000], + [480.0000, 269.0000, 503.0000, 286.0000, 0.1831, 60.0000], + [184.0000, 302.0000, 229.0000, 320.0000, 0.1784, 60.0000], + [522.0000, 286.0000, 548.0000, 300.0000, 0.1752, 60.0000], + ] + ), + ] + + return [{"bbox": p[:, :4].double(), "scores": p[:, 4].double(), "labels": p[:, 5]} for p in pred], [ + {"bbox": g[:, :4].double(), "labels": g[:, 4], "iscrowd": g[:, 5]} for g in gt + ] + + +def random_sample() -> Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, torch.Tensor]]]: + torch.manual_seed(12) + targets = [] + preds = [] + for _ in range(30): + # Generate some ground truth boxes + n_gt_box = torch.randint(50, (1,)).item() + x1 = torch.randint(641, (n_gt_box, 1)) + y1 = torch.randint(641, (n_gt_box, 1)) + w = 640 * torch.rand((n_gt_box, 1)) + h = 640 * torch.rand((n_gt_box, 1)) + x2 = (x1 + w).clip(max=640) + y2 = (y1 + h).clip(max=640) + category = torch.randint(91, (n_gt_box, 1)) + iscrowd = torch.randint(2, (n_gt_box, 1)) + targets.append(torch.cat((x1, y1, x2, y2, category, iscrowd), dim=1)) + + # Remove some of gt boxes from corresponding predictions + kept_boxes = torch.randint(2, (n_gt_box,), dtype=torch.bool) + n_predicted_box = kept_boxes.sum() + x1 = x1[kept_boxes] + y1 = y1[kept_boxes] + w = w[kept_boxes] + h = h[kept_boxes] + category = category[kept_boxes] + + # Perturb gt boxes in the prediction + perturb_x1 = 640 * (torch.rand_like(x1, dtype=torch.float) - 0.5) + perturb_y1 = 640 * (torch.rand_like(y1, dtype=torch.float) - 0.5) + perturb_w = 640 * (torch.rand_like(w, dtype=torch.float) - 0.5) + perturb_h = 640 * (torch.rand_like(h, dtype=torch.float) - 0.5) + perturb_category = Geometric(0.7).sample((n_predicted_box, 1)) * (2 * torch.randint_like(category, 2) - 1) + + x1 = (x1 + perturb_x1).clip(min=0, max=640) + y1 = (y1 + perturb_y1).clip(min=0, max=640) + w = (w + perturb_w).clip(min=0, max=640) + h = (h + perturb_h).clip(min=0, max=640) + x2 = (x1 + w).clip(max=640) + y2 = (y1 + h).clip(max=640) + category = (category + perturb_category) % 100 + confidence = torch.rand_like(category, dtype=torch.float) + perturbed_gt_boxes = torch.cat((x1, y1, x2, y2, confidence, category), dim=1) + + # Generate some additional prediction boxes + n_additional_pred_boxes = torch.randint(50, (1,)).item() + x1 = torch.randint(641, (n_additional_pred_boxes, 1)) + y1 = torch.randint(641, (n_additional_pred_boxes, 1)) + w = 640 * torch.rand((n_additional_pred_boxes, 1)) + h = 640 * torch.rand((n_additional_pred_boxes, 1)) + x2 = (x1 + w).clip(max=640) + y2 = (y1 + h).clip(max=640) + category = torch.randint(100, (n_additional_pred_boxes, 1)) + confidence = torch.rand_like(category, dtype=torch.float) + additional_pred_boxes = torch.cat((x1, y1, x2, y2, confidence, category), dim=1) + + preds.append(torch.cat((perturbed_gt_boxes, additional_pred_boxes), dim=0)) + + return [{"bbox": p[:, :4], "scores": p[:, 4], "labels": p[:, 5]} for p in preds], [ + {"bbox": g[:, :4], "labels": g[:, 4], "iscrowd": g[:, 5]} for g in targets + ] + + +def create_coco_api( + predictions: 
List[Dict[str, torch.Tensor]], targets: List[Dict[str, torch.Tensor]] +) -> Tuple[COCO, COCO]: + """Create COCO objects from predictions and targets + + Args: + predictions: list of predictions. Each one is a dict containing "bbox", "scores" and "labels" as its keys. The + associated value to "bbox" is a tensor of shape (n, 4) where n stands for the number of detections. + 4 represents top left and bottom right coordinates of a box in the form (x1, y1, x2, y2). The associated + values to "scores" and "labels" are tensors of shape (n,). + targets: list of targets. Each one is a dict containing "bbox", "labels" and "iscrowd" as its keys. The + associated values to "bbox" and "labels" are the same as those of the ``predictions``. The associated + value to "iscrowd" is a tensor of shape (n,) which determines if ground truth boxes are crowd or not. + """ + ann_id = 1 + coco_gt = COCO() + dataset = {"images": [], "categories": [], "annotations": []} + + for idx, target in enumerate(targets): + dataset["images"].append({"id": idx}) + bboxes = target["bbox"].clone() + bboxes[:, 2:4] -= bboxes[:, 0:2] + for i in range(bboxes.shape[0]): + bbox = bboxes[i].tolist() + area = bbox[2] * bbox[3] + ann = { + "image_id": idx, + "bbox": bbox, + "category_id": target["labels"][i].item(), + "area": area, + "iscrowd": target["iscrowd"][i].item(), + "id": ann_id, + } + dataset["annotations"].append(ann) + ann_id += 1 + dataset["categories"] = [{"id": i} for i in range(0, 91)] + coco_gt.dataset = dataset + coco_gt.createIndex() + + prediction_tensors = [] + for idx, prediction in enumerate(predictions): + bboxes = prediction["bbox"].clone() + bboxes[:, 2:4] -= bboxes[:, 0:2] + prediction_tensors.append( + torch.cat( + [ + torch.tensor(idx).repeat(bboxes.shape[0], 1), + bboxes, + prediction["scores"].unsqueeze(1), + prediction["labels"].unsqueeze(1), + ], + dim=1, + ) + ) + predictions = torch.cat(prediction_tensors, dim=0) + coco_dt = coco_gt.loadRes(predictions.numpy()) + return coco_dt, coco_gt + + +def pycoco_mAP( + predictions: List[Dict[str, torch.Tensor]], targets: List[Dict[str, torch.Tensor]] +) -> Tuple[float, float, float]: + """ + Returned values belong to IOU thresholds of [0.5, 0.55, ..., 0.95], [0.5] and [0.75] respectively. 
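+
+    A minimal usage sketch (variable names are illustrative; the unpacking order follows the
+    IoU-threshold order stated above):
+
+    .. code-block:: python
+
+        preds, targets = random_sample()
+        map_50_95, map_50, map_75 = pycoco_mAP(preds, targets)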
+ """ + coco_dt, coco_gt = create_coco_api(predictions, targets) + eval = COCOeval(coco_gt, coco_dt, iouType="bbox") + eval.evaluate() + eval.accumulate() + eval.summarize() + return eval.stats[0], eval.stats[1], eval.stats[2] + + +Sample = namedtuple("Sample", ["data", "mAP", "length"]) + + +@pytest.fixture( + params=[ + ("coco2017", "full"), + ("coco2017", "with_an_empty_pred"), + ("coco2017", "with_an_empty_gt"), + ("coco2017", "with_an_empty_pred_and_gt"), + ("random", "full"), + ("random", "with_an_empty_pred"), + ("random", "with_an_empty_gt"), + ("random", "with_an_empty_pred_and_gt"), + ] +) +def sample(request) -> Sample: + data = coco_val2017_sample() if request.param[0] == "coco2017" else random_sample() + if request.param[1] == "with_an_empty_pred": + data[0][1] = { + "bbox": torch.zeros(0, 4), + "scores": torch.zeros( + 0, + ), + "labels": torch.zeros( + 0, + ), + } + elif request.param[1] == "with_an_empty_gt": + data[1][0] = { + "bbox": torch.zeros(0, 4), + "labels": torch.zeros( + 0, + ), + "iscrowd": torch.zeros( + 0, + ), + } + elif request.param[1] == "with_an_empty_pred_and_gt": + data[0][0] = { + "bbox": torch.zeros(0, 4), + "scores": torch.zeros( + 0, + ), + "labels": torch.zeros( + 0, + ), + } + data[0][1] = { + "bbox": torch.zeros(0, 4), + "scores": torch.zeros( + 0, + ), + "labels": torch.zeros( + 0, + ), + } + data[1][0] = { + "bbox": torch.zeros(0, 4), + "labels": torch.zeros( + 0, + ), + "iscrowd": torch.zeros( + 0, + ), + } + data[1][2] = { + "bbox": torch.zeros(0, 4), + "labels": torch.zeros( + 0, + ), + "iscrowd": torch.zeros( + 0, + ), + } + mAP = pycoco_mAP(*data) + + return Sample(data, mAP, len(data[0])) + + +def test_wrong_input(): + with pytest.raises(ValueError, match="Currently, the only available flavor for ObjectDetectionMAP is 'COCO'"): + ObjectDetectionMAP(flavor="wrong flavor") + + +def test_empty_data(): + """ + Note that PyCOCO returns -1 when threre's no ground truth data. 
+ """ + + metric = ObjectDetectionMAP() + metric.update( + ( + {"bbox": torch.zeros((0, 4)), "scores": torch.zeros((0,)), "labels": torch.zeros((0))}, + {"bbox": torch.zeros((0, 4)), "iscrowd": torch.zeros((0,)), "labels": torch.zeros((0))}, + ) + ) + assert len(metric._tp) == 0 + assert len(metric._fp) == 0 + assert len(metric._P) == 0 + assert metric._num_classes == 0 + assert metric.compute() == -1 + + metric = ObjectDetectionMAP() + metric.update( + ( + {"bbox": torch.zeros((0, 4)), "scores": torch.zeros((0,)), "labels": torch.zeros((0))}, + { + "bbox": torch.tensor([[0.0, 0.0, 100.0, 100.0]]), + "iscrowd": torch.zeros((1,)), + "labels": torch.ones((1,)), + }, + ) + ) + assert len(metric._tp) == 0 + assert len(metric._fp) == 0 + assert len(metric._P) == 1 and metric._P[1] == 1 + assert metric._num_classes == 2 + assert metric.compute() == 0 + + metric = ObjectDetectionMAP() + pred = { + "bbox": torch.tensor([[0.0, 0.0, 100.0, 100.0]]), + "scores": torch.tensor([0.9]), + "labels": torch.tensor([5]), + } + target = {"bbox": torch.zeros((0, 4)), "iscrowd": torch.zeros((0,)), "labels": torch.zeros((0))} + metric.update((pred, target)) + assert (5 in metric._tp) and metric._tp[5][0].shape[1] == 1 + assert (5 in metric._fp) and metric._fp[5][0].shape[1] == 1 + assert len(metric._P) == 0 + assert metric._num_classes == 6 + assert metric.compute() == pycoco_mAP([pred], [target])[0] + + +def test_no_torchvision(): + with patch.dict(sys.modules, {"torchvision.ops.boxes": None}): + with pytest.raises(ModuleNotFoundError, match=r"This metric requires torchvision to be installed."): + ObjectDetectionMAP() + + +def test_iou(sample): + m = ObjectDetectionMAP() + from pycocotools.mask import iou as pycoco_iou + + for pred, tgt in zip(*sample.data): + pred_bbox = pred["bbox"].double() + tgt_bbox = tgt["bbox"].double() + if not pred_bbox.shape[0] or not tgt_bbox.shape[0]: + continue + iscrowd = tgt["iscrowd"] + + ignite_iou = m.box_iou(pred_bbox, tgt_bbox, iscrowd.bool()) + + pred_bbox[:, 2:4] -= pred_bbox[:, :2] + tgt_bbox[:, 2:4] -= tgt_bbox[:, :2] + pyc_iou = pycoco_iou(pred_bbox.numpy(), tgt_bbox.numpy(), iscrowd.int()) + + equal = ignite_iou.numpy() == pyc_iou + assert equal.all() + + +def test_iou_thresholding(): + metric = ObjectDetectionMAP(iou_thresholds=[0.0, 0.3, 0.5, 0.75]) + + pred = { + "bbox": torch.tensor([[0.0, 0.0, 100.0, 100.0]]), + "scores": torch.tensor([0.8]), + "labels": torch.tensor([1]), + } + gt = {"bbox": torch.tensor([[0.0, 0.0, 50.0, 100.0]]), "iscrowd": torch.zeros((1,)), "labels": torch.tensor([1])} + metric.update((pred, gt)) + assert (metric._tp[1][0] == torch.tensor([[True], [True], [True], [False]])).all() + + pred = { + "bbox": torch.tensor([[0.0, 0.0, 100.0, 100.0]]), + "scores": torch.tensor([0.8]), + "labels": torch.tensor([1]), + } + gt = {"bbox": torch.tensor([[100.0, 0.0, 200.0, 100.0]]), "iscrowd": torch.zeros((1,)), "labels": torch.tensor([1])} + metric.update((pred, gt)) + assert (metric._tp[1][1] == torch.tensor([[True], [False], [False], [False]])).all() + + +def test_matching(): + """ + PyCOCO matching rules: + 1. The higher confidence in a prediction, the sooner decision is made for it. + If there's equal confidence in two predictions, the dicision is first made + for the one who comes earlier. + 2. Each ground truth box is matched with at most one prediction. Crowd ground + truth is the exception. A prediction matched with a crowd gt would get ignored. + 3. 
Among many plausible ground truth boxes, a prediction is matched with the + one which has the highest mutual IOU. If two ground truth boxes have the + same IOU with a prediction, the later one is matched. + 4. A non-crowd ground truth has priority over a crowd ground truth in getting + matched with a prediction in the sense that even if the crowd ground truth + has a higher IOU, the non-crowd one gets matched if its IOU is viable. + """ + metric = ObjectDetectionMAP(iou_thresholds=[0.2]) + + rule_1_pred = { + "bbox": torch.tensor([[0.0, 0.0, 100.0, 100.0], [0.0, 0.0, 100.0, 100.0]]), + "scores": torch.tensor([0.8, 0.9]), + "labels": torch.tensor([1, 1]), + } + rule_1_gt = { + "bbox": torch.tensor([[0.0, 0.0, 100.0, 100.0]]), + "iscrowd": torch.zeros((1,)), + "labels": torch.tensor([1]), + } + metric.update((rule_1_pred, rule_1_gt)) + assert (metric._tp[1][0] == torch.tensor([[False, True]])).all() + assert (metric._fp[1][0] == torch.tensor([[True, False]])).all() + + rule_1_and_2_pred = { + "bbox": torch.tensor([[0.0, 0.0, 100.0, 100.0], [0.0, 0.0, 100.0, 100.0]]), + "scores": torch.tensor([0.9, 0.9]), + "labels": torch.tensor([1, 1]), + } + rule_1_and_2_gt = { + "bbox": torch.tensor([[0.0, 0.0, 100.0, 100.0]]), + "iscrowd": torch.zeros((1,)), + "labels": torch.tensor([1]), + } + metric.update((rule_1_and_2_pred, rule_1_and_2_gt)) + assert (metric._tp[1][1] == torch.tensor([[True, False]])).all() + assert (metric._fp[1][1] == torch.tensor([[False, True]])).all() + + rule_2_pred = { + "bbox": torch.tensor([[0.0, 0.0, 100.0, 100.0], [0.0, 0.0, 100.0, 100.0]]), + "scores": torch.tensor([0.9, 0.9]), + "labels": torch.tensor([1, 1]), + } + rule_2_gt = { + "bbox": torch.tensor([[0.0, 0.0, 100.0, 100.0]]), + "iscrowd": torch.tensor([1]), + "labels": torch.tensor([1]), + } + metric.update((rule_2_pred, rule_2_gt)) + assert (metric._tp[1][2] == torch.tensor([[False, False]])).all() + assert (metric._fp[1][2] == torch.tensor([[False, False]])).all() + + rule_2_and_3_pred = { + "bbox": torch.tensor([[0.0, 0.0, 100.0, 100.0], [100.0, 0.0, 200.0, 100.0]]), + "scores": torch.tensor([0.9, 0.9]), + "labels": torch.tensor([1, 1]), + } + rule_2_and_3_gt = { + "bbox": torch.tensor([[0.0, 0.0, 25.0, 50.0], [50.0, 0.0, 150.0, 100.0]]), + "iscrowd": torch.zeros((2,)), + "labels": torch.tensor([1, 1]), + } + metric.update((rule_2_and_3_pred, rule_2_and_3_gt)) + assert (metric._tp[1][3] == torch.tensor([[True, False]])).all() + assert (metric._fp[1][3] == torch.tensor([[False, True]])).all() + + +def test_compute(sample): + device = idist.device() + metric_50 = ObjectDetectionMAP(iou_thresholds=[0.5], device=device) + metric_75 = ObjectDetectionMAP(iou_thresholds=[0.75], device=device) + metric_50_95 = ObjectDetectionMAP(device=device) + + for prediction, target in zip(*sample.data): + metric_50.update((prediction, target)) + metric_75.update((prediction, target)) + metric_50_95.update((prediction, target)) + + res_50 = metric_50.compute().item() + res_75 = metric_75.compute().item() + res_50_95 = metric_50_95.compute().item() + + pycoco_res_50_95, pycoco_res_50, pycoco_res_75 = sample.mAP + + assert np.allclose(res_50, pycoco_res_50) + assert np.allclose(res_75, pycoco_res_75) + assert np.allclose(res_50_95, pycoco_res_50_95) + + res_50_recompute = metric_50.compute() + res_75_recompute = metric_75.compute() + res_50_95_recompute = metric_50_95.compute() + + assert res_50 == res_50_recompute + assert res_75 == res_75_recompute + assert res_50_95 == res_50_95_recompute + + +def test_integration(sample): + def 
update(engine, i): + return sample.data[0][i], sample.data[1][i] + + engine = Engine(update) + + device = idist.device() + metric_device = "cpu" if device.type == "xla" else device + metric_50_95 = ObjectDetectionMAP(device=metric_device) + metric_50_95.attach(engine, name="mAP[50-95]") + + engine.run(range(sample.length), max_epochs=1) + + res_50_95 = engine.state.metrics["mAP[50-95]"] + pycoco_res_50_95 = sample.mAP[0] + + assert np.allclose(res_50_95, pycoco_res_50_95) + + +def test_distrib_update_compute(distributed, sample): + rank_samples_cnt = ceil(sample.length / idist.get_world_size()) + rank = idist.get_rank() + rank_samples_range = slice(rank_samples_cnt * rank, rank_samples_cnt * (rank + 1)) + + device = idist.device() + metric_device = "cpu" if device.type == "xla" else device + metric_50 = ObjectDetectionMAP(iou_thresholds=[0.5], device=metric_device) + metric_75 = ObjectDetectionMAP(iou_thresholds=[0.75], device=metric_device) + metric_50_95 = ObjectDetectionMAP(device=metric_device) + + for prediction, target in zip(sample.data[0][rank_samples_range], sample.data[1][rank_samples_range]): + metric_50.update((prediction, target)) + metric_75.update((prediction, target)) + metric_50_95.update((prediction, target)) + + res_50 = metric_50.compute() + res_75 = metric_75.compute() + res_50_95 = metric_50_95.compute() + + pycoco_res_50_95, pycoco_res_50, pycoco_res_75 = sample.mAP + + assert np.allclose(res_50_95, pycoco_res_50_95) + assert np.allclose(res_50, pycoco_res_50) + assert np.allclose(res_75, pycoco_res_75) + + res_50_recompute = metric_50.compute() + res_75_recompute = metric_75.compute() + res_50_95_recompute = metric_50_95.compute() + + assert res_50_recompute == res_50 + assert res_75_recompute == res_75 + assert res_50_95_recompute == res_50_95 From 24fe980edb89ceffdf4927782c82ea06d239a5bd Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Sun, 28 May 2023 12:59:18 +0330 Subject: [PATCH 02/41] Some improvements Removed allow_multiple... Renamed average_operand Renamed _measure_recall... to _compute_recall... --- ignite/metrics/mean_average_precision.py | 156 ++++++++---------- ignite/metrics/vision/object_detection_map.py | 29 +++- .../metrics/test_mean_average_precision.py | 79 +++------ .../vision/test_object_detection_map.py | 59 +++++++ 4 files changed, 173 insertions(+), 150 deletions(-) diff --git a/ignite/metrics/mean_average_precision.py b/ignite/metrics/mean_average_precision.py index fe7c76c9f98..d4ed4294209 100644 --- a/ignite/metrics/mean_average_precision.py +++ b/ignite/metrics/mean_average_precision.py @@ -22,10 +22,9 @@ class MeanAveragePrecision(_BasePrecisionRecall): def __init__( self, rec_thresholds: Optional[Union[Sequence[float], torch.Tensor]] = None, - average_operand: Optional[Literal["precision", "max-precision"]] = "precision", + average: Optional[Literal["precision", "max-precision"]] = "precision", class_mean: Optional[Literal["micro", "macro", "weighted", "with_other_dims"]] = "macro", classification_is_multilabel: bool = False, - allow_multiple_recalls_at_single_threshold: bool = False, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu"), ) -> None: @@ -38,22 +37,23 @@ def __init__( Mean average precision is the computed by taking the mean of this average precision over different classes and possibly some additional dimensions in the detection task. 
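+
+        As a rough illustration, the per-class average precision referred to here reduces, with the
+        default ``average="precision"`` described below, to a precision-weighted sum over recall
+        increments. A minimal sketch (function name is illustrative; ``recall`` and ``precision``
+        are assumed to be 1-dim tensors sorted by increasing recall, with :math:`r_0 = 0`):
+
+        .. code-block:: python
+
+            import torch
+
+            def average_precision_sketch(recall: torch.Tensor, precision: torch.Tensor) -> torch.Tensor:
+                # AP = sum_k (r_k - r_{k-1}) * P_k
+                recall_increments = recall.diff(prepend=torch.zeros(1, dtype=recall.dtype))
+                return (recall_increments * precision).sum()
+
+            # The mean average precision is then the mean of this value over classes
+            # (and, in detection, over any additional dimensions such as IoU thresholds).
+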
- For detection tasks user must subclass this metric and implement its :meth:`do_matching` - method to provide the metric with desired matching logic. Then this method is called internally in - :meth:`update` method on prediction-target pairs. For classification, all the binary, multiclass and - multilabel data are supported. In the latter case, ``classification_is_multilabel`` should be set to true. + For detection tasks user should use downstream metrics like + :class:`~ignite.metrics.vision.object_detection_map.ObjectDetectionMAP` or subclass this metric and implement + its :meth:`_do_matching` method to provide the metric with desired matching logic. Then this method is called + internally in :meth:`update` method on prediction-target pairs. For classification, all the binary, multiclass + and multilabel data are supported. In the latter case, ``classification_is_multilabel`` should be set to true. `mean` in the mean average precision accounts for mean of the average precision across classes. ``class_mean`` determines how to take this mean. In the detection tasks, it's possible to take mean of the average precision in other respects as well e.g. IoU threshold in an object detection task. To this end, average precision - corresponding to each value of IoU thresholds should get measured in :meth:`do_matching`. Please refer to - :meth:`do_matching` for more info on this. + corresponding to each value of IoU thresholds should get measured in :meth:`_do_matching`. Please refer to + :meth:`_do_matching` for more info on this. Args: rec_thresholds: recall thresholds (sensivity levels) to be considered for computing Mean Average Precision. It could be a 1-dim tensor or a sequence of floats. Its values should be between 0 and 1 and don't need to be sorted. If missing, thresholds are considered automatically using the data. - average_operand: one of values "precision" or "max-precision". In the former case, the precision at a + average: one of values "precision" or "max-precision". In the former case, the precision at a recall threshold is used for that threshold: .. math:: @@ -62,7 +62,7 @@ def __init__( :math:`r` stands for recall thresholds and :math:`P` for precision values. :math:`r_0` is set to zero. In the latter case, the maximum precision across thresholds greater or equal a recall threshold is - considered as the summation operand; In other words, the precision peek across lower or equall + considered as the summation operand; In other words, the precision peek across lower or equal sensivity levels is used for a recall threshold: .. math:: @@ -108,11 +108,6 @@ def __init__( classification_is_multilabel: Used in classification task and determines if the data is multilabel or not. Default False. - allow_multiple_recalls_at_single_threshold: When there are predictions with the same scores, it's faster to - consider those predictions associated with different thresholds in the course of measuring recall - values, but it's not logically correct since those predictions are associated with a single threshold, - thus outputing a single recall value. This option is added mainly due to some downstream mAP metrics - which allow such a thing in their computation e.g. pycocotools' mAP. Default False. output_transform: a callable that is used to transform the :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the form expected by the metric. 
This can be useful if, for example, you have a multi-output model and @@ -127,16 +122,14 @@ def __init__( else: self.rec_thresholds = None - if average_operand not in ("precision", "max-precision"): - raise ValueError(f"Wrong `average_operand` parameter, given {average_operand}") - self.average_operand = average_operand + if average not in ("precision", "max-precision"): + raise ValueError(f"Wrong `average` parameter, given {average}") + self.average = average if class_mean is not None and class_mean not in ("micro", "macro", "weighted", "with_other_dims"): raise ValueError(f"Wrong `class_mean` parameter, given {class_mean}") self.class_mean = class_mean - self.allow_multiple_recalls_at_single_threshold = allow_multiple_recalls_at_single_threshold - super(_BasePrecisionRecall, self).__init__( output_transform=output_transform, is_multilabel=classification_is_multilabel, device=device ) @@ -169,7 +162,7 @@ def reset(self) -> None: Reset method of the metric """ super(_BasePrecisionRecall, self).reset() - if self.do_matching.__func__ == MeanAveragePrecision.do_matching: # type: ignore[attr-defined] + if self._do_matching.__func__ == MeanAveragePrecision._do_matching: # type: ignore[attr-defined] self._task: Literal["classification", "detection"] = "classification" else: self._task = "detection" @@ -202,7 +195,7 @@ def _check_matching_output_shape( ) -> None: if not (tps.keys() == fps.keys() == scores.keys()): raise ValueError( - "Returned TP, FP and scores dictionaries from do_matching should have" + "Returned TP, FP and scores dictionaries from _do_matching should have" f" the same keys (classes), given {tps.keys()}, {fps.keys()} and {scores.keys()}" ) try: @@ -228,7 +221,7 @@ def _check_matching_output_shape( else: if self_tp_or_fp[cls][-1].shape[:-1] != new_tp_or_fp[cls].shape[:-1]: raise ValueError( - f"Tensors in returned {name} from do_matching should not change in shape " + f"Tensors in returned {name} from _do_matching should not change in shape " "except possibly in the last dimension which is the dimension of samples. Given " f"{self_tp_or_fp[cls][-1].shape} and {new_tp_or_fp[cls].shape}" ) @@ -278,13 +271,13 @@ def _classification_prepare_output( return scores, P - def do_matching( + def _do_matching( self, pred: Any, target: Any ) -> Tuple[Dict[int, torch.Tensor], Dict[int, torch.Tensor], Dict[int, int], Dict[int, torch.Tensor]]: r""" Matching logic holder of the metric for detection tasks. - User must implement this method by subclassing the metric. There is no constraint on type and shape of + The developer must implement this method by subclassing the metric. There is no constraint on type and shape of ``pred`` and ``target``, but the method should return a quadrople of dictionaries containing TP, FP, P (actual positive) counts and scores for each class respectively. Please note that class numbers start from zero. @@ -315,7 +308,8 @@ def do_matching( `(TP, FP, P, scores)` A quadrople of true positives, false positives, number of actual positives and scores. """ raise NotImplementedError( - "Please subclass MeanAveragePrecision and implement `do_matching` method" " to use the metric in detection." + "Please subclass MeanAveragePrecision and implement `_do_matching` method" + " to use the metric in detection." ) @reinit__is_reduced @@ -324,7 +318,7 @@ def update(self, output: Union[Tuple[Any, Any], Tuple[torch.Tensor, torch.Tensor Args: output: a binary tuple. 
It should consist of prediction and target tensors in the classification case but - for detection it is the same as the implemented-by-user :meth:`do_matching`. + for detection it is the same as the implemented-by-user :meth:`_do_matching`. For classification, this metric follows the same rules on ``output`` members shape as the :meth:`Precision.update ` except for ``y_pred`` of binary and multilabel @@ -341,7 +335,7 @@ def update(self, output: Union[Tuple[Any, Any], Tuple[torch.Tensor, torch.Tensor P.to(self._device, dtype=torch.uint8 if self._type != "multiclass" else torch.long) ) else: - tps, fps, ps, scores_dict = self.do_matching(output[0], output[1]) + tps, fps, ps, scores_dict = self._do_matching(output[0], output[1]) self._check_matching_output_shape(tps, fps, scores_dict) for cls in tps: self._tp[cls].append(tps[cls].to(device=self._device, dtype=torch.uint8)) @@ -353,7 +347,7 @@ def update(self, output: Union[Tuple[Any, Any], Tuple[torch.Tensor, torch.Tensor if classes: self._num_classes = max(max(classes) + 1, self._num_classes) - def _measure_recall_and_precision( + def _compute_recall_and_precision( self, TP: torch.Tensor, FP: Union[torch.Tensor, None], scores: torch.Tensor, P: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: r"""Measuring recall & precision which is the common operation among different settings of the metric. @@ -363,34 +357,30 @@ def _measure_recall_and_precision( classification task. ``...`` stands for the additional dimensions in the detection task. Finally, \#unique scores represents number of unique scores in ``scores`` which is actually the number of thresholds. - This method is called on a per class basis in the detection task and if - ``allow_multiple_recalls_at_single_threshold=False``. - - =========================== ================================== =================================== + ============== ====================== Detection task - -------------------------------------------------------------------------------------------------- - **Object**/ **Condition** ``allow_multiple_recalls...=True`` ``allow_multiple_recalls...=False`` - =========================== ================================== =================================== - TP and FP (..., N\ :sub:`pred`) (..., N\ :sub:`pred`) - scores (N\ :sub:`pred`,) (N\ :sub:`pred`,) - P () (A single float) () (A single float) - recall (..., N\ :sub:`pred`) (..., \#unique scores) - precision (..., N\ :sub:`pred`) (..., \#unique scores) - =========================== ===================== - - =========================== ================================== =================================== + ------------------------------------- + **Object** **Shape** + ============== ====================== + TP and FP (..., N\ :sub:`pred`) + scores (N\ :sub:`pred`,) + P () (A single float) + recall (..., \#unique scores) + precision (..., \#unique scores) + ============== ====================== + + =================== ======================================= Classification task - -------------------------------------------------------------------------------------------------- - **Object**/ **Condition** ``allow_multiple_recalls...=True`` ``allow_multiple_recalls...=False`` - =========================== ================================== =================================== - TP (C, N\ :sub:`pred`) (N\ :sub:`pred`,) - FP (C, N\ :sub:`pred`) None (FP is computed here to be - faster) - scores (C, N\ :sub:`pred`) (N\ :sub:`pred`,) - P (C,) () (A single float) - recall (C, N\ :sub:`pred`) (\#unique scores,) - 
precision (C, N\ :sub:`pred`) (\#unique scores,) - =========================== ================================== =================================== + ----------------------------------------------------------- + **Object** **Shape** + =================== ======================================= + TP (N\ :sub:`pred`,) + FP None (FP is computed here to be faster) + scores (N\ :sub:`pred`,) + P () (A single float) + recall (\#unique scores,) + precision (\#unique scores,) + =================== ======================================= Returns: `(recall, precision)` @@ -398,32 +388,24 @@ def _measure_recall_and_precision( indices = torch.argsort(scores, dim=-1, stable=True, descending=True) tp = TP.take_along_dim(indices, dim=-1) if self._task == "classification" else TP[..., indices] tp_summation = tp.cumsum(dim=-1).double() - if self._task == "detection" or self.allow_multiple_recalls_at_single_threshold: - fp = ( - cast(torch.Tensor, FP).take_along_dim(indices, dim=-1) - if self._task == "classification" - else cast(torch.Tensor, FP)[..., indices] - ) + + # Adopted from Scikit-learn's implementation + unique_scores_indices = torch.nonzero( + scores.take_along_dim(indices).diff(append=(scores.max() + 1).unsqueeze(dim=0)), as_tuple=True + )[0] + tp_summation = tp_summation[..., unique_scores_indices] + if self._task == "classification": + fp_summation = (unique_scores_indices + 1) - tp_summation + else: + fp = cast(torch.Tensor, FP)[..., indices] fp_summation = fp.cumsum(dim=-1).double() - if not self.allow_multiple_recalls_at_single_threshold: - # Adopted from Scikit-learn's implementation - unique_scores_indices = torch.nonzero( - scores.take_along_dim(indices).diff(append=(scores.max() + 1).unsqueeze(dim=0)), as_tuple=True - )[0] - tp_summation = tp_summation[..., unique_scores_indices] - if self._task == "classification": - fp_summation = (unique_scores_indices + 1) - tp_summation - else: - fp_summation = fp_summation[..., unique_scores_indices] + fp_summation = fp_summation[..., unique_scores_indices] - if self._task == "classification" and self.allow_multiple_recalls_at_single_threshold: - recall = torch.where(P == 0, 1, tp_summation.T / P).T - elif self._task == "classification" and P == 0: + if self._task == "classification" and P == 0: recall = torch.ones_like(tp_summation, device=self._device, dtype=torch.bool) else: recall = tp_summation / P - # precision = tp_summation / (fp_summation + tp_summation + torch.finfo(torch.double).eps) - # or + predicted_positive = tp_summation + fp_summation precision = tp_summation / torch.where(predicted_positive == 0, 1, predicted_positive) return recall, precision @@ -440,7 +422,7 @@ def _measure_average_precision(self, recall: torch.Tensor, precision: torch.Tens average_precision: (n-1)-dimensional tensor containing the average precision for mean dimensions. 
""" precision_integrand = ( - precision.flip(-1).cummax(dim=-1).values.flip(-1) if self.average_operand == "max-precision" else precision + precision.flip(-1).cummax(dim=-1).values.flip(-1) if self.average == "max-precision" else precision ) if self.rec_thresholds is not None: rec_thresholds = self.rec_thresholds.repeat((*recall.shape[:-1], 1)) @@ -549,7 +531,7 @@ def compute(self) -> Union[torch.Tensor, float]: if TP[cls].size(-1) == 0: average_precisions[cls] = 0 continue - recall, precision = self._measure_recall_and_precision(TP[cls], FP[cls], scores[cls], P[cls]) + recall, precision = self._compute_recall_and_precision(TP[cls], FP[cls], scores[cls], P[cls]) average_precision_for_cls_across_other_dims = self._measure_average_precision(recall, precision) if self.class_mean != "with_other_dims": average_precisions[cls] = average_precision_for_cls_across_other_dims.mean() @@ -607,7 +589,7 @@ def compute(self) -> Union[torch.Tensor, float]: ) ) P = P.sum() - recall, precision = self._measure_recall_and_precision(TP_micro, FP_micro, scores_micro, P) + recall, precision = self._compute_recall_and_precision(TP_micro, FP_micro, scores_micro, P) return self._measure_average_precision(recall, precision).mean() else: rank_P = ( @@ -644,16 +626,12 @@ def compute(self) -> Union[torch.Tensor, float]: P = P.reshape(1, -1) scores_classification = scores_classification.view(1, -1) P_count = P.sum(dim=-1) - if self.allow_multiple_recalls_at_single_threshold: - recall, precision = self._measure_recall_and_precision(P, 1 - P, scores_classification, P_count) - average_precisions = self._measure_average_precision(recall, precision) - else: - average_precisions = torch.zeros_like(P_count, device=self._device, dtype=torch.double) - for cls in range(len(P_count)): - recall, precision = self._measure_recall_and_precision( - P[cls], None, scores_classification[cls], P_count[cls] - ) - average_precisions[cls] = self._measure_average_precision(recall, precision) + average_precisions = torch.zeros_like(P_count, device=self._device, dtype=torch.double) + for cls in range(len(P_count)): + recall, precision = self._compute_recall_and_precision( + P[cls], None, scores_classification[cls], P_count[cls] + ) + average_precisions[cls] = self._measure_average_precision(recall, precision) if self._type == "binary": return average_precisions.item() if self.class_mean is None: diff --git a/ignite/metrics/vision/object_detection_map.py b/ignite/metrics/vision/object_detection_map.py index 35bc1ba5ee1..1789a50f937 100644 --- a/ignite/metrics/vision/object_detection_map.py +++ b/ignite/metrics/vision/object_detection_map.py @@ -86,13 +86,34 @@ def box_iou(pred_boxes: torch.Tensor, gt_boxes: torch.Tensor, iscrowd: torch.Boo super(ObjectDetectionMAP, self).__init__( rec_thresholds=rec_thresholds, - average_operand="max-precision" if flavor == "COCO" else "precision", + average="max-precision" if flavor == "COCO" else "precision", class_mean="with_other_dims", - allow_multiple_recalls_at_single_threshold=flavor == "COCO", output_transform=output_transform, device=device, ) + def _compute_recall_and_precision( + self, TP: torch.Tensor, FP: Union[torch.Tensor, None], scores: torch.Tensor, P: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + r"""Measuring recall & precision + + This method is overriden since in the pycocotools reference implementation, when there are predictions with the + same scores, they're considered associated with different thresholds in the course of measuring recall + values, although it's not logically 
correct as those predictions are really associated with a single threshold, + thus outputing a single recall value. + """ + indices = torch.argsort(scores, dim=-1, stable=True, descending=True) + tp = TP[..., indices] + tp_summation = tp.cumsum(dim=-1).double() + fp = cast(torch.Tensor, FP)[..., indices] + fp_summation = fp.cumsum(dim=-1).double() + + recall = tp_summation / P + predicted_positive = tp_summation + fp_summation + precision = tp_summation / torch.where(predicted_positive == 0, 1, predicted_positive) + + return recall, precision + def _measure_average_precision(self, recall: torch.Tensor, precision: torch.Tensor) -> torch.Tensor: """Measuring average precision. This method is overriden since :math:`1/#recall_thresholds` is used instead of :math:`r_k - r_{k-1}` @@ -110,7 +131,7 @@ def _measure_average_precision(self, recall: torch.Tensor, precision: torch.Tens return super()._measure_average_precision(recall, precision) precision_integrand = ( - precision.flip(-1).cummax(dim=-1).values.flip(-1) if self.average_operand == "max-precision" else precision + precision.flip(-1).cummax(dim=-1).values.flip(-1) if self.average == "max-precision" else precision ) rec_thresholds = cast(torch.Tensor, self.rec_thresholds).repeat((*recall.shape[:-1], 1)) rec_thresh_indices = torch.searchsorted(recall, rec_thresholds) @@ -124,7 +145,7 @@ def compute(self) -> Union[torch.Tensor, float]: return -1 return super().compute() - def do_matching( + def _do_matching( self, pred: Dict[str, torch.Tensor], target: Dict[str, torch.Tensor] ) -> Tuple[Dict[int, torch.Tensor], Dict[int, torch.Tensor], Dict[int, int], Dict[int, torch.Tensor]]: """ diff --git a/tests/ignite/metrics/test_mean_average_precision.py b/tests/ignite/metrics/test_mean_average_precision.py index 4432d517a1c..8b4c9def745 100644 --- a/tests/ignite/metrics/test_mean_average_precision.py +++ b/tests/ignite/metrics/test_mean_average_precision.py @@ -4,7 +4,6 @@ import pytest import torch from sklearn.metrics import average_precision_score, precision_recall_curve -from sklearn.utils.extmath import stable_cumsum from ignite import distributed as idist from ignite.engine import Engine @@ -21,8 +20,8 @@ def test_wrong_input(): with pytest.raises(TypeError, match="rec_thresholds should be a sequence of floats or a tensor"): MeanAveragePrecision(rec_thresholds={0, 0.2, 0.4, 0.6, 0.8}) - with pytest.raises(ValueError, match="Wrong `average_operand` parameter"): - MeanAveragePrecision(average_operand=1) + with pytest.raises(ValueError, match="Wrong `average` parameter"): + MeanAveragePrecision(average=1) with pytest.raises(ValueError, match="Wrong `class_mean` parameter"): MeanAveragePrecision(class_mean="samples") @@ -50,18 +49,18 @@ def test_wrong_classification_input(): class Dummy_mAP(MeanAveragePrecision): - def do_matching(self, pred: Tuple, target: Tuple): + def _do_matching(self, pred: Tuple, target: Tuple): return *pred, *target -def test_wrong_do_matching(): +def test_wrong__do_matching(): metric = MeanAveragePrecision() with pytest.raises(NotImplementedError, match="Please subclass MeanAveragePrecision and implement"): - metric.do_matching(None, None) + metric._do_matching(None, None) metric = Dummy_mAP() - with pytest.raises(ValueError, match="Returned TP, FP and scores dictionaries from do_matching should have"): + with pytest.raises(ValueError, match="Returned TP, FP and scores dictionaries from _do_matching should have"): metric.update( ( ({1: torch.tensor([True])}, {1: torch.tensor([False])}), @@ -80,7 +79,7 @@ def 
test_wrong_do_matching(): ) metric.update((({1: torch.tensor([True])}, {1: torch.tensor([False])}), ({1: 1}, {1: torch.tensor([0.8])}))) - with pytest.raises(ValueError, match="Tensors in returned FP from do_matching should not change in shape except"): + with pytest.raises(ValueError, match="Tensors in returned FP from _do_matching should not change in shape except"): metric.update( ( ({1: torch.tensor([False, True])}, {1: torch.tensor([[True, False], [False, False]])}), @@ -127,72 +126,38 @@ def test_update(): assert metric._P[2] == 3 -def sklearn_precision_recall_curve_allowing_multiple_recalls_at_single_threshold(y_true, y_score): - y_true = y_true == 1 - - desc_score_indices = np.argsort(y_score, kind="mergesort")[::-1] - y_true = y_true[desc_score_indices] - weight = 1.0 - - tps = stable_cumsum(y_true * weight) - fps = stable_cumsum((1 - y_true) * weight) - ps = tps + fps - precision = np.zeros_like(tps) - np.divide(tps, ps, out=precision, where=(ps != 0)) - if tps[-1] == 0: - recall = np.ones_like(tps) - else: - recall = tps / tps[-1] - - sl = slice(None, None, -1) - return np.hstack((precision[sl], 1)), np.hstack((recall[sl], 0)), None - - -@pytest.mark.parametrize( - "allow_multiple_recalls_at_single_threshold, sklearn_pr_rec_curve", - [ - (False, precision_recall_curve), - (True, sklearn_precision_recall_curve_allowing_multiple_recalls_at_single_threshold), - ], -) -def test__measure_recall_and_precision(allow_multiple_recalls_at_single_threshold, sklearn_pr_rec_curve): +def test__compute_recall_and_precision(): # Classification - m = MeanAveragePrecision(allow_multiple_recalls_at_single_threshold=allow_multiple_recalls_at_single_threshold) + m = MeanAveragePrecision() scores = torch.rand((50,)) y_true = torch.randint(0, 2, (50,)).bool() - precision, recall, _ = sklearn_pr_rec_curve(y_true.numpy(), scores.numpy()) - if allow_multiple_recalls_at_single_threshold: - y_true = y_true.unsqueeze(0) - scores = scores.unsqueeze(0) - FP = ~y_true if allow_multiple_recalls_at_single_threshold else None + precision, recall, _ = precision_recall_curve(y_true.numpy(), scores.numpy()) + FP = None P = y_true.sum(dim=-1) - ignite_recall, ignite_precision = m._measure_recall_and_precision(y_true, FP, scores, P) + ignite_recall, ignite_precision = m._compute_recall_and_precision(y_true, FP, scores, P) assert (ignite_recall.squeeze().flip(0).numpy() == recall[:-1]).all() assert (ignite_precision.squeeze().flip(0).numpy() == precision[:-1]).all() # Classification, when there's no actual positive. Numpy expectedly raises warning. scores = torch.rand((50,)) y_true = torch.zeros((50,)).bool() - precision, recall, _ = sklearn_pr_rec_curve(y_true.numpy(), scores.numpy()) - if allow_multiple_recalls_at_single_threshold: - y_true = y_true.unsqueeze(0) - scores = scores.unsqueeze(0) - FP = ~y_true if allow_multiple_recalls_at_single_threshold else None - P = torch.tensor([0]) if allow_multiple_recalls_at_single_threshold else torch.tensor(0) - ignite_recall, ignite_precision = m._measure_recall_and_precision(y_true, FP, scores, P) + precision, recall, _ = precision_recall_curve(y_true.numpy(), scores.numpy()) + FP = None + P = torch.tensor(0) + ignite_recall, ignite_precision = m._compute_recall_and_precision(y_true, FP, scores, P) assert (ignite_recall.flip(0).numpy() == recall[:-1]).all() assert (ignite_precision.flip(0).numpy() == precision[:-1]).all() # Detection, in the case detector detects all gt objects but also produces some wrong predictions. 
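    # Here TP is taken to be y_true itself and FP to be its complement: every ground-truth
    # object is matched by a correct prediction and every remaining prediction is a false
    # positive, so the resulting recall curve must end at 1.0.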
scores = torch.rand((50,)) y_true = torch.randint(0, 2, (50,)) - m = Dummy_mAP(allow_multiple_recalls_at_single_threshold=allow_multiple_recalls_at_single_threshold) + m = Dummy_mAP() - ignite_recall, ignite_precision = m._measure_recall_and_precision( + ignite_recall, ignite_precision = m._compute_recall_and_precision( y_true.bool(), ~(y_true.bool()), scores, y_true.sum() ) - sklearn_precision, sklearn_recall, _ = sklearn_pr_rec_curve(y_true.numpy(), scores.numpy()) + sklearn_precision, sklearn_recall, _ = precision_recall_curve(y_true.numpy(), scores.numpy()) assert (ignite_recall.flip(0).numpy() == sklearn_recall[:-1]).all() assert (ignite_precision.flip(0).numpy() == sklearn_precision[:-1]).all() @@ -203,12 +168,12 @@ def test__measure_recall_and_precision(allow_multiple_recalls_at_single_threshol for i in range(6): for j in range(8): y_true[i, j, np.random.choice(50, size=15, replace=False)] = 1 - precision, recall, _ = sklearn_pr_rec_curve(y_true[i, j].numpy(), scores.numpy()) + precision, recall, _ = precision_recall_curve(y_true[i, j].numpy(), scores.numpy()) sklearn_precisions.append(precision[:-1]) sklearn_recalls.append(recall[:-1]) sklearn_precisions = np.array(sklearn_precisions).reshape(6, 8, -1) sklearn_recalls = np.array(sklearn_recalls).reshape(6, 8, -1) - ignite_recall, ignite_precision = m._measure_recall_and_precision( + ignite_recall, ignite_precision = m._compute_recall_and_precision( y_true.bool(), ~(y_true.bool()), scores, torch.tensor(15) ) assert (ignite_recall.flip(-1).numpy() == sklearn_recalls).all() @@ -479,7 +444,7 @@ def update(_, i): # class MatchFirstDetectionFirst_mAP(MeanAveragePrecision): -# def do_matching(self, pred: Tuple[Sequence[int], Sequence[float]] , target: Sequence[int]): +# def _do_matching(self, pred: Tuple[Sequence[int], Sequence[float]] , target: Sequence[int]): # P = dict(Counter(target)) # tp = defaultdict(lambda: []) # scores = defaultdict(lambda: []) diff --git a/tests/ignite/metrics/vision/test_object_detection_map.py b/tests/ignite/metrics/vision/test_object_detection_map.py index ae0b75da9a5..3f64e13e376 100644 --- a/tests/ignite/metrics/vision/test_object_detection_map.py +++ b/tests/ignite/metrics/vision/test_object_detection_map.py @@ -5,6 +5,7 @@ from unittest.mock import patch import numpy as np +from sklearn.utils.extmath import stable_cumsum np.float = float @@ -800,9 +801,67 @@ def test_matching(): assert (metric._fp[1][3] == torch.tensor([[False, True]])).all() +def sklearn_precision_recall_curve_allowing_multiple_recalls_at_single_threshold(y_true, y_score): + y_true = y_true == 1 + + desc_score_indices = np.argsort(y_score, kind="mergesort")[::-1] + y_true = y_true[desc_score_indices] + weight = 1.0 + + tps = stable_cumsum(y_true * weight) + fps = stable_cumsum((1 - y_true) * weight) + ps = tps + fps + precision = np.zeros_like(tps) + np.divide(tps, ps, out=precision, where=(ps != 0)) + if tps[-1] == 0: + recall = np.ones_like(tps) + else: + recall = tps / tps[-1] + + sl = slice(None, None, -1) + return np.hstack((precision[sl], 1)), np.hstack((recall[sl], 0)), None + + +def test__compute_recall_and_precision(): + # Detection, in the case detector detects all gt objects but also produces some wrong predictions. 
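    # Unlike the classification metric, the COCO-style computation keeps one recall/precision
    # pair per prediction even when several predictions share the same score, which is what
    # the reference helper defined above reproduces with plain cumulative sums.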
+ scores = torch.rand((50,)) + y_true = torch.randint(0, 2, (50,)) + m = ObjectDetectionMAP() + + ignite_recall, ignite_precision = m._compute_recall_and_precision( + y_true.bool(), ~(y_true.bool()), scores, y_true.sum() + ) + sklearn_precision, sklearn_recall, _ = sklearn_precision_recall_curve_allowing_multiple_recalls_at_single_threshold( + y_true.numpy(), scores.numpy() + ) + assert (ignite_recall.flip(0).numpy() == sklearn_recall[:-1]).all() + assert (ignite_precision.flip(0).numpy() == sklearn_precision[:-1]).all() + + # Detection like above but with two additional mean dimensions. + scores = torch.rand((50,)) + y_true = torch.zeros((6, 8, 50)) + sklearn_precisions, sklearn_recalls = [], [] + for i in range(6): + for j in range(8): + y_true[i, j, np.random.choice(50, size=15, replace=False)] = 1 + precision, recall, _ = sklearn_precision_recall_curve_allowing_multiple_recalls_at_single_threshold( + y_true[i, j].numpy(), scores.numpy() + ) + sklearn_precisions.append(precision[:-1]) + sklearn_recalls.append(recall[:-1]) + sklearn_precisions = np.array(sklearn_precisions).reshape(6, 8, -1) + sklearn_recalls = np.array(sklearn_recalls).reshape(6, 8, -1) + ignite_recall, ignite_precision = m._compute_recall_and_precision( + y_true.bool(), ~(y_true.bool()), scores, torch.tensor(15) + ) + assert (ignite_recall.flip(-1).numpy() == sklearn_recalls).all() + assert (ignite_precision.flip(-1).numpy() == sklearn_precisions).all() + + def test_compute(sample): device = idist.device() metric_50 = ObjectDetectionMAP(iou_thresholds=[0.5], device=device) + assert metric_50._task == "detection" metric_75 = ObjectDetectionMAP(iou_thresholds=[0.75], device=device) metric_50_95 = ObjectDetectionMAP(device=device) From e2ac8ee9c7f2db332789ded9a56afd153cf807a5 Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Sun, 28 May 2023 14:25:23 +0330 Subject: [PATCH 03/41] Update docs --- ignite/metrics/mean_average_precision.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/ignite/metrics/mean_average_precision.py b/ignite/metrics/mean_average_precision.py index d4ed4294209..392511ef761 100644 --- a/ignite/metrics/mean_average_precision.py +++ b/ignite/metrics/mean_average_precision.py @@ -37,17 +37,14 @@ def __init__( Mean average precision is the computed by taking the mean of this average precision over different classes and possibly some additional dimensions in the detection task. - For detection tasks user should use downstream metrics like - :class:`~ignite.metrics.vision.object_detection_map.ObjectDetectionMAP` or subclass this metric and implement - its :meth:`_do_matching` method to provide the metric with desired matching logic. Then this method is called - internally in :meth:`update` method on prediction-target pairs. For classification, all the binary, multiclass - and multilabel data are supported. In the latter case, ``classification_is_multilabel`` should be set to true. + For detection tasks, user should use downstream metrics like + :class:`~ignite.metrics.vision.object_detection_map.ObjectDetectionMAP`. For classification, all the binary, + multiclass and multilabel data are supported. In the latter case, ``classification_is_multilabel`` should be + set to true. `mean` in the mean average precision accounts for mean of the average precision across classes. ``class_mean`` determines how to take this mean. In the detection tasks, it's possible to take mean of the average precision - in other respects as well e.g. IoU threshold in an object detection task. 
To this end, average precision - corresponding to each value of IoU thresholds should get measured in :meth:`_do_matching`. Please refer to - :meth:`_do_matching` for more info on this. + in other respects as well e.g. IoU threshold in an object detection task. Args: rec_thresholds: recall thresholds (sensivity levels) to be considered for computing Mean Average Precision. @@ -317,8 +314,9 @@ def update(self, output: Union[Tuple[Any, Any], Tuple[torch.Tensor, torch.Tensor """Metric update function using prediction and target. Args: - output: a binary tuple. It should consist of prediction and target tensors in the classification case but - for detection it is the same as the implemented-by-user :meth:`_do_matching`. + output: a binary tuple. It should consist of prediction and target tensors in the classification case. + for detection, user should refer to the desired subclass metric e.g. + :meth:`~ignite.metrics.vision.object_detection_map.ObjectDetectionMAP.update` For classification, this metric follows the same rules on ``output`` members shape as the :meth:`Precision.update ` except for ``y_pred`` of binary and multilabel From 7cf53e13c50706849e3e7ef4afc6a9cb68365130 Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Mon, 29 May 2023 22:42:06 +0330 Subject: [PATCH 04/41] Fix a bug in docs Docs has some nasty errors --- ignite/metrics/mean_average_precision.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ignite/metrics/mean_average_precision.py b/ignite/metrics/mean_average_precision.py index 392511ef761..bfd3b548ba1 100644 --- a/ignite/metrics/mean_average_precision.py +++ b/ignite/metrics/mean_average_precision.py @@ -316,11 +316,11 @@ def update(self, output: Union[Tuple[Any, Any], Tuple[torch.Tensor, torch.Tensor Args: output: a binary tuple. It should consist of prediction and target tensors in the classification case. for detection, user should refer to the desired subclass metric e.g. - :meth:`~ignite.metrics.vision.object_detection_map.ObjectDetectionMAP.update` + :meth:`~ignite.metrics.vision.object_detection_map.ObjectDetectionMAP` For classification, this metric follows the same rules on ``output`` members shape as the - :meth:`Precision.update ` except for ``y_pred`` of binary and multilabel - data which should be comprised of positive class probabilities here. + :meth:`Precision.update <.metrics.precision.Precision.update>` except for ``y_pred`` of binary and + multilabel data which should be comprised of positive class probabilities here. 
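            A hypothetical binary usage sketch (assuming the metric is exported from ``ignite.metrics``
            as elsewhere in this patch), with ``y_pred`` holding positive-class probabilities:

            .. code-block:: python

                import torch
                from ignite.metrics import MeanAveragePrecision

                metric = MeanAveragePrecision()
                y_pred = torch.tensor([0.9, 0.2, 0.7, 0.4])  # P(class == 1) per sample
                y = torch.tensor([1, 0, 1, 1])
                metric.update((y_pred, y))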
""" if self._task == "classification": From 4aa9c5db934d197775560a4e176b88167e53f3c0 Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Fri, 16 Jun 2023 02:46:41 +0330 Subject: [PATCH 05/41] Fix a tiny bug related to allgather --- ignite/distributed/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ignite/distributed/utils.py b/ignite/distributed/utils.py index aaa8887ac7c..2b5749acd16 100644 --- a/ignite/distributed/utils.py +++ b/ignite/distributed/utils.py @@ -363,7 +363,7 @@ def _all_gather_tensors_with_shapes( if isinstance(_model, _SerialModel) or (group is not None and _model.get_rank() not in group): return [tensor] - max_shape = torch.tensor(shapes).amax(dim=1) + max_shape = torch.tensor(shapes).amax(dim=0) padding_sizes = (max_shape - torch.tensor(tensor.shape)).tolist() padded_tensor = torch.nn.functional.pad( tensor, tuple(itertools.chain.from_iterable(map(lambda dim_size: (0, dim_size), reversed(padding_sizes)))) From 950c38870d30f6158003979f88b5e641182fcb5f Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Fri, 16 Jun 2023 07:16:09 +0330 Subject: [PATCH 06/41] Fix a few bugs --- ignite/distributed/utils.py | 1 - ignite/metrics/mean_average_precision.py | 81 ++++++++----------- .../vision/test_object_detection_map.py | 4 +- 3 files changed, 34 insertions(+), 52 deletions(-) diff --git a/ignite/distributed/utils.py b/ignite/distributed/utils.py index 2b5749acd16..fc4cea80d88 100644 --- a/ignite/distributed/utils.py +++ b/ignite/distributed/utils.py @@ -377,7 +377,6 @@ def _all_gather_tensors_with_shapes( ] ] for rank, shape in enumerate(shapes) - if group is None or rank in group ] diff --git a/ignite/metrics/mean_average_precision.py b/ignite/metrics/mean_average_precision.py index bfd3b548ba1..b3ec508bf25 100644 --- a/ignite/metrics/mean_average_precision.py +++ b/ignite/metrics/mean_average_precision.py @@ -434,6 +434,11 @@ def _measure_average_precision(self, recall: torch.Tensor, precision: torch.Tens ) return torch.sum(recall_differential * precision_integrand, dim=-1) + def concat_dict_of_tensor_lists_in_key_order( + self, tensor_dict: Dict[int, List[torch.Tensor]], num_keys: int + ) -> torch.Tensor: + return torch.cat(list(itertools.chain(*map(tensor_dict.__getitem__, range(num_keys)))), dim=-1) + def compute(self) -> Union[torch.Tensor, float]: """ Compute method of the metric @@ -475,66 +480,55 @@ def compute(self) -> Union[torch.Tensor, float]: ).tolist() if self.class_mean != "micro": - shapes_across_ranks = { - cls: [ - (*mean_dimensions_shape, num_pred_in_rank) - for num_pred_in_rank in num_preds_per_class_across_ranks[:, cls] + average_precisions = -torch.ones( + (num_classes, *(mean_dimensions_shape if self.class_mean == "with_other_dims" else ())), + device=self._device, + dtype=torch.double, + ) + for cls in range(num_classes): + if P[cls] == 0: + continue + + num_preds_across_ranks = num_preds_per_class_across_ranks[:, [cls]] + if num_preds_across_ranks.sum() == 0: + average_precisions[cls] = 0 + continue + shape_across_ranks = [ + (*mean_dimensions_shape, num_pred_in_rank.item()) for num_pred_in_rank in num_preds_across_ranks ] - for cls in range(num_classes) - } - TP = { - cls: torch.cat( + TP = torch.cat( _all_gather_tensors_with_shapes( torch.cat(self._tp[cls], dim=-1) if self._tp[cls] else torch.empty((*mean_dimensions_shape, 0), dtype=torch.uint8, device=self._device), - shapes_across_ranks[cls], + shape_across_ranks, ), dim=-1, ) - for cls in range(num_classes) - } - FP = { - cls: torch.cat( + FP = torch.cat( 
_all_gather_tensors_with_shapes( torch.cat(self._fp[cls], dim=-1) if self._fp[cls] else torch.empty((*mean_dimensions_shape, 0), dtype=torch.uint8, device=self._device), - shapes_across_ranks[cls], + shape_across_ranks, ), dim=-1, ) - for cls in range(num_classes) - } - scores = { - cls: torch.cat( + scores = torch.cat( _all_gather_tensors_with_shapes( torch.cat(cast(List[torch.Tensor], self._scores[cls])) if self._scores[cls] else torch.tensor([], dtype=torch.double, device=self._device), - num_preds_per_class_across_ranks[:, [cls]].tolist(), + num_preds_across_ranks.tolist(), ) ) - for cls in range(num_classes) - } - - average_precisions = -torch.ones( - (num_classes, *(mean_dimensions_shape if self.class_mean == "with_other_dims" else ())), - device=self._device, - dtype=torch.double, - ) - for cls in range(num_classes): - if P[cls] == 0: - continue - if TP[cls].size(-1) == 0: - average_precisions[cls] = 0 - continue - recall, precision = self._compute_recall_and_precision(TP[cls], FP[cls], scores[cls], P[cls]) + recall, precision = self._compute_recall_and_precision(TP, FP, scores, P[cls]) average_precision_for_cls_across_other_dims = self._measure_average_precision(recall, precision) if self.class_mean != "with_other_dims": average_precisions[cls] = average_precision_for_cls_across_other_dims.mean() else: average_precisions[cls] = average_precision_for_cls_across_other_dims + if self.class_mean is None: average_precisions[average_precisions == -1] = 0 return average_precisions @@ -549,9 +543,7 @@ def compute(self) -> Union[torch.Tensor, float]: ] TP_micro = torch.cat( _all_gather_tensors_with_shapes( - torch.cat(list(itertools.chain(*map(self._tp.__getitem__, range(num_classes)))), dim=-1).to( - torch.uint8 - ) + self.concat_dict_of_tensor_lists_in_key_order(self._tp, num_classes).to(torch.uint8) if num_preds_across_ranks[idist.get_rank()] else torch.empty((*mean_dimensions_shape, 0), dtype=torch.uint8, device=self._device), shapes_across_ranks_in_micro, @@ -560,9 +552,7 @@ def compute(self) -> Union[torch.Tensor, float]: ).bool() FP_micro = torch.cat( _all_gather_tensors_with_shapes( - torch.cat(list(itertools.chain(*map(self._fp.__getitem__, range(num_classes)))), dim=-1).to( - torch.uint8 - ) + self.concat_dict_of_tensor_lists_in_key_order(self._fp, num_classes).to(torch.uint8) if num_preds_across_ranks[idist.get_rank()] else torch.empty((*mean_dimensions_shape, 0), dtype=torch.uint8, device=self._device), shapes_across_ranks_in_micro, @@ -571,15 +561,8 @@ def compute(self) -> Union[torch.Tensor, float]: ).bool() scores_micro = torch.cat( _all_gather_tensors_with_shapes( - torch.cat( - list( - itertools.chain( - *map( - cast(Dict[int, List[torch.Tensor]], self._scores).__getitem__, - range(num_classes), - ) - ) - ) + self.concat_dict_of_tensor_lists_in_key_order( + cast(Dict[int, List[torch.Tensor]], self._scores), num_classes ) if num_preds_across_ranks[idist.get_rank()] else torch.tensor([], dtype=torch.double, device=self._device), diff --git a/tests/ignite/metrics/vision/test_object_detection_map.py b/tests/ignite/metrics/vision/test_object_detection_map.py index 3f64e13e376..8e95630b1ee 100644 --- a/tests/ignite/metrics/vision/test_object_detection_map.py +++ b/tests/ignite/metrics/vision/test_object_detection_map.py @@ -449,7 +449,7 @@ def random_sample() -> Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, torch x2 = (x1 + w).clip(max=640) y2 = (y1 + h).clip(max=640) category = (category + perturb_category) % 100 - confidence = torch.rand_like(category, dtype=torch.float) 
+ confidence = torch.rand_like(category, dtype=torch.double) perturbed_gt_boxes = torch.cat((x1, y1, x2, y2, confidence, category), dim=1) # Generate some additional prediction boxes @@ -461,7 +461,7 @@ def random_sample() -> Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, torch x2 = (x1 + w).clip(max=640) y2 = (y1 + h).clip(max=640) category = torch.randint(100, (n_additional_pred_boxes, 1)) - confidence = torch.rand_like(category, dtype=torch.float) + confidence = torch.rand_like(category, dtype=torch.double) additional_pred_boxes = torch.cat((x1, y1, x2, y2, confidence, category), dim=1) preds.append(torch.cat((perturbed_gt_boxes, additional_pred_boxes), dim=0)) From 9f5f79636dac738326c4ba7e158a9b7ae949c5a5 Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Fri, 16 Jun 2023 08:28:57 +0330 Subject: [PATCH 07/41] Redesign code: Removed generic detection logics. Just that of the COCO is remained Tests are updated --- ignite/metrics/mean_average_precision.py | 603 ++++++------------ ignite/metrics/vision/object_detection_map.py | 263 ++++++-- .../metrics/test_mean_average_precision.py | 294 +-------- .../vision/test_object_detection_map.py | 59 +- 4 files changed, 508 insertions(+), 711 deletions(-) diff --git a/ignite/metrics/mean_average_precision.py b/ignite/metrics/mean_average_precision.py index b3ec508bf25..6bb75126bd9 100644 --- a/ignite/metrics/mean_average_precision.py +++ b/ignite/metrics/mean_average_precision.py @@ -1,56 +1,37 @@ -import itertools import warnings -from collections import defaultdict -from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Union +from typing import Callable, cast, List, Optional, Sequence, Tuple, Union import torch from typing_extensions import Literal import ignite.distributed as idist -from ignite.distributed.utils import _all_gather_tensors_with_shapes from ignite.metrics.metric import reinit__is_reduced from ignite.metrics.recall import _BasePrecisionRecall from ignite.utils import to_onehot -class MeanAveragePrecision(_BasePrecisionRecall): - _tp: Dict[int, List[torch.Tensor]] - _fp: Dict[int, List[torch.Tensor]] - _scores: Union[Dict[int, List[torch.Tensor]], List[torch.Tensor]] - _P: Union[Dict[int, int], List[torch.Tensor]] - +class _BaseMeanAveragePrecision(_BasePrecisionRecall): def __init__( self, rec_thresholds: Optional[Union[Sequence[float], torch.Tensor]] = None, - average: Optional[Literal["precision", "max-precision"]] = "precision", + average: Optional[str] = "precision", class_mean: Optional[Literal["micro", "macro", "weighted", "with_other_dims"]] = "macro", classification_is_multilabel: bool = False, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu"), ) -> None: - r"""Calculate the mean average precision metric i.e. mean of the averaged-over-recall precision for detection - and classification tasks. - - Mean average precision attempts to give a measure of detector or classifier precision at various - sensivity levels a.k.a recall thresholds. This is done by summing precisions at different recall - thresholds weighted by the change in recall, as if the area under precision-recall curve is being computed. - Mean average precision is the computed by taking the mean of this average precision over different classes - and possibly some additional dimensions in the detection task. - - For detection tasks, user should use downstream metrics like - :class:`~ignite.metrics.vision.object_detection_map.ObjectDetectionMAP`. 
For classification, all the binary, - multiclass and multilabel data are supported. In the latter case, ``classification_is_multilabel`` should be - set to true. + r"""Base class for Mean Average Precision in classification and detection tasks. - `mean` in the mean average precision accounts for mean of the average precision across classes. ``class_mean`` - determines how to take this mean. In the detection tasks, it's possible to take mean of the average precision - in other respects as well e.g. IoU threshold in an object detection task. + Mean average precision is computed by taking the mean of the average precision over different classes + and possibly some additional dimensions in the detection task. ``class_mean`` determines how to take this mean. + In the detection tasks, it's possible to take the mean in other respects as well e.g. IoU threshold in an + object detection task. Args: rec_thresholds: recall thresholds (sensivity levels) to be considered for computing Mean Average Precision. It could be a 1-dim tensor or a sequence of floats. Its values should be between 0 and 1 and don't need to be sorted. If missing, thresholds are considered automatically using the data. - average: one of values "precision" or "max-precision". In the former case, the precision at a + average: one of values precision or max-precision. In the former case, the precision at a recall threshold is used for that threshold: .. math:: @@ -73,7 +54,7 @@ def __init__( An 1-dimensional tensor of mean (taken across additional mean dimensions) average precision per class is returned. If there's no ground truth sample for a class, ``0`` is returned for that. - 'micro' + micro Precision is computed counting stats of classes/labels altogether. This option incorporates class in the very precision measurement. @@ -86,7 +67,7 @@ def __init__( For multiclass inputs, this is equivalent with mean average accuracy. - 'weighted' + weighted like macro but considers class/label imbalance. For multiclass input, it computes AP for each class then returns mean of them weighted by support of classes (number of actual samples in each class). For multilabel input, @@ -127,12 +108,7 @@ def __init__( raise ValueError(f"Wrong `class_mean` parameter, given {class_mean}") self.class_mean = class_mean - super(_BasePrecisionRecall, self).__init__( - output_transform=output_transform, is_multilabel=classification_is_multilabel, device=device - ) - - if self._task == "classification" and self.class_mean == "with_other_dims": - raise ValueError("class_mean 'with_other_dims' is not compatible with classification.") + super().__init__(output_transform=output_transform, is_multilabel=classification_is_multilabel, device=device) def _setup_thresholds(self, thresholds: Union[Sequence[float], torch.Tensor], threshold_type: str) -> torch.Tensor: if isinstance(thresholds, Sequence): @@ -153,25 +129,143 @@ def _setup_thresholds(self, thresholds: Union[Sequence[float], torch.Tensor], th return cast(torch.Tensor, thresholds) + def _compute_average_precision(self, recall: torch.Tensor, precision: torch.Tensor) -> torch.Tensor: + """Measuring average precision which is the common operation among different settings of the metric. + + Args: + recall: n-dimensional tensor whose last dimension is the dimension of the samples. Should be ordered in + ascending order in its last dimension. + precision: like ``recall`` in the shape. + + Returns: + average_precision: (n-1)-dimensional tensor containing the average precision for mean dimensions. 
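        As a hypothetical worked example (not taken from the reference implementation): for measured
        recalls :math:`[0.5, 0.5, 1.0]`, precisions :math:`[0.5, 1.0, 0.67]` and no ``rec_thresholds``,
        the recall differentials are :math:`[0.5, 0, 0.5]`, and for "max-precision" the precisions are
        first replaced by their right-to-left running maximum :math:`[1.0, 1.0, 0.67]`, so that

        .. math::
            \text{AP (precision)} = 0.5 \times 0.5 + 0 \times 1.0 + 0.5 \times 0.67 = 0.585

            \text{AP (max-precision)} = 0.5 \times 1.0 + 0 \times 1.0 + 0.5 \times 0.67 = 0.835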
+ """ + precision_integrand = ( + precision.flip(-1).cummax(dim=-1).values.flip(-1) if self.average == "max-precision" else precision + ) + if self.rec_thresholds is not None: + rec_thresholds = self.rec_thresholds.repeat((*recall.shape[:-1], 1)) + rec_thresh_indices = torch.searchsorted(recall, rec_thresholds) + precision_integrand = precision_integrand.take_along_dim( + rec_thresh_indices.where(rec_thresh_indices != recall.size(-1), 0), dim=-1 + ).where(rec_thresh_indices != recall.size(-1), 0) + recall = rec_thresholds + recall_differential = recall.diff( + dim=-1, prepend=torch.zeros((*recall.shape[:-1], 1), device=self._device, dtype=torch.double) + ) + return torch.sum(recall_differential * precision_integrand, dim=-1) + + +class MeanAveragePrecision(_BaseMeanAveragePrecision): + _scores: List[torch.Tensor] + _P: List[torch.Tensor] + + def __init__( + self, + rec_thresholds: Optional[Union[Sequence[float], torch.Tensor]] = None, + average: Optional["Literal['precision', 'max-precision']"] = "precision", + class_mean: Optional["Literal['micro', 'macro', 'weighted']"] = "macro", + is_multilabel: bool = False, + output_transform: Callable = lambda x: x, + device: Union[str, torch.device] = torch.device("cpu"), + ) -> None: + r"""Calculate the mean average precision metric i.e. mean of the averaged-over-recall precision for + classification task. + + Mean average precision attempts to give a measure of detector or classifier precision at various + sensivity levels a.k.a recall thresholds. This is done by summing precisions at different recall + thresholds weighted by the change in recall, as if the area under precision-recall curve is being computed. + Mean average precision is then computed by taking the mean of this average precision over different classes. + + For detection tasks, user should use downstream metrics like + :class:`~ignite.metrics.vision.object_detection_map.ObjectDetectionMAP`. For classification, all the binary, + multiclass and multilabel data are supported. In the latter case, ``classification_is_multilabel`` should be + set to true. + + `mean` in the mean average precision accounts for mean of the average precision across classes. ``class_mean`` + determines how to take this mean. + + Args: + rec_thresholds: recall thresholds (sensivity levels) to be considered for computing Mean Average Precision. + It could be a 1-dim tensor or a sequence of floats. Its values should be between 0 and 1 and don't need + to be sorted. If missing, thresholds are considered automatically using the data. + average: one of values "precision" or "max-precision". In the former case, the precision at a + recall threshold is used for that threshold: + + .. math:: + \text{Average Precision} = \sum_{k=1}^{\#rec\_thresholds} (r_k - r_{k-1}) P_k + + :math:`r` stands for recall thresholds and :math:`P` for precision values. :math:`r_0` is set to zero. + + In the latter case, the maximum precision across thresholds greater or equal a recall threshold is + considered as the summation operand; In other words, the precision peek across lower or equal + sensivity levels is used for a recall threshold: + + .. math:: + \text{Average Precision} = \sum_{k=1}^{\#rec\_thresholds} (r_k - r_{k-1}) max(P_{k:}) + + Default is "precision". + class_mean: how to compute mean of the average precision across classes or incorporate class + dimension into computing precision. It's ignored in binary classification. 
Available options are + + None + An 1-dimensional tensor of mean (taken across additional mean dimensions) average precision per class + is returned. If there's no ground truth sample for a class, ``0`` is returned for that. + + 'micro' + Precision is computed counting stats of classes/labels altogether. This option + incorporates class in the very precision measurement. + + .. math:: + \text{Micro P} = \frac{\sum_{c=1}^C TP_c}{\sum_{c=1}^C TP_c+FP_c} + + where :math:`C` is the number of classes/labels. :math:`c` in :math:`TP_c` + and :math:`FP_c` means that the terms are computed for class/label :math:`c` (in a one-vs-rest + sense in multiclass case). + + For multiclass inputs, this is equivalent to mean average accuracy. + + 'weighted' + like macro but considers class/label imbalance. For multiclass input, + it computes AP for each class then returns mean of them weighted by + support of classes (number of actual samples in each class). For multilabel input, + it computes AP for each label then returns mean of them weighted by support + of labels (number of actual positive samples in each label). + + 'macro' + computes macro precision which is unweighted mean of AP computed across classes/labels. Default. + + is_multilabel: determines if the data is multilabel or not. Default False. + output_transform: a callable that is used to transform the + :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the + form expected by the metric. This can be useful if, for example, you have a multi-output model and + you want to compute the metric with respect to one of the outputs. This metric requires the output + as ``(y_pred, y)``. + device: specifies which device updates are accumulated on. Setting the + metric's device to be the same as your ``update`` arguments ensures the ``update`` method is + non-blocking. By default, CPU. + """ + + super().__init__( + rec_thresholds=rec_thresholds, + average=average, + class_mean=class_mean, + output_transform=output_transform, + classification_is_multilabel=is_multilabel, + device=device, + ) + + if self.class_mean == "with_other_dims": + raise ValueError("class_mean 'with_other_dims' is not compatible with this class.") + @reinit__is_reduced def reset(self) -> None: """ Reset method of the metric """ super(_BasePrecisionRecall, self).reset() - if self._do_matching.__func__ == MeanAveragePrecision._do_matching: # type: ignore[attr-defined] - self._task: Literal["classification", "detection"] = "classification" - else: - self._task = "detection" - self._tp = defaultdict(lambda: []) - self._fp = defaultdict(lambda: []) - if self._task == "detection": - self._scores = defaultdict(lambda: []) - self._P = defaultdict(lambda: 0) - self._num_classes = 0 - else: - self._scores = [] - self._P = [] + self._scores = [] + self._P = [] def _check_binary_multilabel_cases(self, output: Sequence[torch.Tensor]) -> None: # Ignore the check in `_BaseClassification` since `y_pred` consists of probabilities here. 
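A minimal sketch of how the ``class_mean`` options above behave on multiclass input, assuming the metric is
exported from ``ignite.metrics`` as elsewhere in this patch (values below are random and purely illustrative):

    import torch
    from ignite.metrics import MeanAveragePrecision

    y_pred = torch.softmax(torch.randn(10, 3), dim=1)  # per-class probabilities for 10 samples
    y = torch.randint(0, 3, (10,))                     # integer targets in {0, 1, 2}

    for class_mean in (None, "macro", "weighted", "micro"):
        metric = MeanAveragePrecision(class_mean=class_mean)
        metric.update((y_pred, y))
        # class_mean=None yields a per-class tensor; the other options reduce to a single value
        print(class_mean, metric.compute())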
@@ -187,45 +281,7 @@ def _check_type(self, output: Sequence[torch.Tensor]) -> None: if self._type == "multiclass" and y.dtype != torch.long: warnings.warn("`y` should be of dtype long when entry type is multiclass", RuntimeWarning) - def _check_matching_output_shape( - self, tps: Dict[int, torch.Tensor], fps: Dict[int, torch.Tensor], scores: Dict[int, torch.Tensor] - ) -> None: - if not (tps.keys() == fps.keys() == scores.keys()): - raise ValueError( - "Returned TP, FP and scores dictionaries from _do_matching should have" - f" the same keys (classes), given {tps.keys()}, {fps.keys()} and {scores.keys()}" - ) - try: - cls = list(tps.keys()).pop() - except IndexError: # No prediction - pass - else: - if tps[cls].dtype not in (torch.bool, torch.uint8): - raise TypeError(f"Tensors in TP and FP dictionaries should be boolean or uint8, given {tps[cls].dtype}") - - if tps[cls].size(-1) != fps[cls].size(-1) != scores[cls].size(0): - raise ValueError( - "Sample dimension of tensors in TP, FP and scores should have equal size per class," - f"given {tps[cls].size(-1)}, {fps[cls].size(-1)} and {scores[cls].size(-1)} for class {cls}" - " respectively." - ) - for self_tp_or_fp, new_tp_or_fp, name in [(self._tp, tps, "TP"), (self._fp, fps, "FP")]: - new_tp_or_fp.keys() - try: - cls = (self_tp_or_fp.keys() & new_tp_or_fp.keys()).pop() - except KeyError: - pass - else: - if self_tp_or_fp[cls][-1].shape[:-1] != new_tp_or_fp[cls].shape[:-1]: - raise ValueError( - f"Tensors in returned {name} from _do_matching should not change in shape " - "except possibly in the last dimension which is the dimension of samples. Given " - f"{self_tp_or_fp[cls][-1].shape} and {new_tp_or_fp[cls].shape}" - ) - - def _classification_prepare_output( - self, y_pred: torch.Tensor, y: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: + def _prepare_output(self, output: Sequence[torch.Tensor]) -> Sequence[torch.Tensor]: """Prepares and returns scores and P tensor. Input and output shapes of the method is as follows. ========== =========== ============ @@ -248,6 +304,7 @@ def _classification_prepare_output( Multiclass (N, ...) (N * ...) ========== =========== ============ """ + y_pred, y = output[0].detach(), output[1].detach() if self._type == "multilabel": num_classes = y_pred.size(1) @@ -268,113 +325,37 @@ def _classification_prepare_output( return scores, P - def _do_matching( - self, pred: Any, target: Any - ) -> Tuple[Dict[int, torch.Tensor], Dict[int, torch.Tensor], Dict[int, int], Dict[int, torch.Tensor]]: - r""" - Matching logic holder of the metric for detection tasks. - - The developer must implement this method by subclassing the metric. There is no constraint on type and shape of - ``pred`` and ``target``, but the method should return a quadrople of dictionaries containing TP, FP, - P (actual positive) counts and scores for each class respectively. Please note that class numbers start from - zero. - - Values in TP and FP are (m+1)-dimensional tensors of type ``bool`` or ``uint8`` and shape - (D\ :sub:`1`, D\ :sub:`2`, ..., D\ :sub:`m`, n\ :sub:`cls`) in which D\ :sub:`i`\ 's are possible additional - dimensions (excluding the class dimension) mean of the average precision is taken over. n\ :sub:`cls` is the - number of predictions for class `cls` which is the same for TP and FP. - - Note: - TP and FP values are stored as uint8 tensors internally to avoid bool-to-uint8 copies before collective - operations, as PyTorch colective operations `do not `_ - support boolean tensors, at least on Gloo backend. 
- - - P counts contains the number of ground truth samples for each class. Finally, the values in scores are 1-dim - tensors of shape (n\ :sub:`cls`,) containing score or confidence of the predictions (doesn't need to be in - [0,1]). If there is no prediction or ground truth for a class, it could be absent from (TP, FP, scores) and P - dictionaries respectively. - - Args: - pred: First member of :meth:`update`'s input is given as this argument. There's no constraint on its type - and shape. - target: Second member of :meth:`update`'s input is given as this argument. There's no constraint on its type - and shape. - - Returns: - `(TP, FP, P, scores)` A quadrople of true positives, false positives, number of actual positives and scores. - """ - raise NotImplementedError( - "Please subclass MeanAveragePrecision and implement `_do_matching` method" - " to use the metric in detection." - ) - @reinit__is_reduced - def update(self, output: Union[Tuple[Any, Any], Tuple[torch.Tensor, torch.Tensor]]) -> None: + def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None: """Metric update function using prediction and target. Args: - output: a binary tuple. It should consist of prediction and target tensors in the classification case. - for detection, user should refer to the desired subclass metric e.g. - :meth:`~ignite.metrics.vision.object_detection_map.ObjectDetectionMAP` + output: a binary tuple consisting of prediction and target tensors - For classification, this metric follows the same rules on ``output`` members shape as the + This metric follows the same rules on ``output`` members shape as the :meth:`Precision.update <.metrics.precision.Precision.update>` except for ``y_pred`` of binary and multilabel data which should be comprised of positive class probabilities here. """ - - if self._task == "classification": - self._check_shape(output) - prediction, target = output[0].detach(), output[1].detach() - self._check_type((prediction, target)) - scores, P = self._classification_prepare_output(prediction, target) - cast(List[torch.Tensor], self._scores).append(scores.to(self._device)) - cast(List[torch.Tensor], self._P).append( - P.to(self._device, dtype=torch.uint8 if self._type != "multiclass" else torch.long) - ) - else: - tps, fps, ps, scores_dict = self._do_matching(output[0], output[1]) - self._check_matching_output_shape(tps, fps, scores_dict) - for cls in tps: - self._tp[cls].append(tps[cls].to(device=self._device, dtype=torch.uint8)) - self._fp[cls].append(fps[cls].to(device=self._device, dtype=torch.uint8)) - cast(Dict[int, List[torch.Tensor]], self._scores)[cls].append(scores_dict[cls].to(self._device)) - for cls in ps: - cast(Dict[int, int], self._P)[cls] += ps[cls] - classes = tps.keys() | ps.keys() - if classes: - self._num_classes = max(max(classes) + 1, self._num_classes) + self._check_shape(output) + self._check_type(output) + scores, P = self._prepare_output(output) + self._scores.append(scores.to(self._device)) + self._P.append(P.to(self._device, dtype=torch.uint8 if self._type != "multiclass" else torch.long)) def _compute_recall_and_precision( - self, TP: torch.Tensor, FP: Union[torch.Tensor, None], scores: torch.Tensor, P: torch.Tensor + self, TP: torch.Tensor, scores: torch.Tensor, P: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: - r"""Measuring recall & precision which is the common operation among different settings of the metric. + r"""Measuring recall & precision. Shape of function inputs and return values follow the table below. 
C is the number of classes, 1 for binary - data. N\ :sub:`pred` is the number of detections or predictions which is the same as the number of samples in - classification task. ``...`` stands for the additional dimensions in the detection task. Finally, - \#unique scores represents number of unique scores in ``scores`` which is actually the number of thresholds. - - ============== ====================== - Detection task - ------------------------------------- - **Object** **Shape** - ============== ====================== - TP and FP (..., N\ :sub:`pred`) - scores (N\ :sub:`pred`,) - P () (A single float) - recall (..., \#unique scores) - precision (..., \#unique scores) - ============== ====================== + data. N is the number of samples. Finally, \#unique scores represents number of unique scores in ``scores`` + which is actually the number of thresholds. =================== ======================================= - Classification task - ----------------------------------------------------------- **Object** **Shape** =================== ======================================= - TP (N\ :sub:`pred`,) - FP None (FP is computed here to be faster) - scores (N\ :sub:`pred`,) + TP (N,) + scores (N,) P () (A single float) recall (\#unique scores,) precision (\#unique scores,) @@ -384,23 +365,18 @@ def _compute_recall_and_precision( `(recall, precision)` """ indices = torch.argsort(scores, dim=-1, stable=True, descending=True) - tp = TP.take_along_dim(indices, dim=-1) if self._task == "classification" else TP[..., indices] - tp_summation = tp.cumsum(dim=-1).double() + tp_summation = TP[..., indices].cumsum(dim=-1).double() # Adopted from Scikit-learn's implementation unique_scores_indices = torch.nonzero( scores.take_along_dim(indices).diff(append=(scores.max() + 1).unsqueeze(dim=0)), as_tuple=True )[0] tp_summation = tp_summation[..., unique_scores_indices] - if self._task == "classification": - fp_summation = (unique_scores_indices + 1) - tp_summation - else: - fp = cast(torch.Tensor, FP)[..., indices] - fp_summation = fp.cumsum(dim=-1).double() - fp_summation = fp_summation[..., unique_scores_indices] + fp_summation = (unique_scores_indices + 1) - tp_summation - if self._task == "classification" and P == 0: - recall = torch.ones_like(tp_summation, device=self._device, dtype=torch.bool) + if P == 0: + # To be aligned with Scikit-Learn + recall = torch.ones_like(tp_summation, device=self._device, dtype=torch.float) else: recall = tp_summation / P @@ -408,216 +384,57 @@ def _compute_recall_and_precision( precision = tp_summation / torch.where(predicted_positive == 0, 1, predicted_positive) return recall, precision - def _measure_average_precision(self, recall: torch.Tensor, precision: torch.Tensor) -> torch.Tensor: - """Measuring average precision which is the common operation among different settings of the metric. - - Args: - recall: n-dimensional tensor whose last dimension is the dimension of the samples. Should be ordered in - ascending order in its last dimension. - precision: like ``recall`` in the shape. - - Returns: - average_precision: (n-1)-dimensional tensor containing the average precision for mean dimensions. 
- """ - precision_integrand = ( - precision.flip(-1).cummax(dim=-1).values.flip(-1) if self.average == "max-precision" else precision - ) - if self.rec_thresholds is not None: - rec_thresholds = self.rec_thresholds.repeat((*recall.shape[:-1], 1)) - rec_thresh_indices = torch.searchsorted(recall, rec_thresholds) - precision_integrand = precision_integrand.take_along_dim( - rec_thresh_indices.where(rec_thresh_indices != recall.size(-1), 0), dim=-1 - ).where(rec_thresh_indices != recall.size(-1), 0) - recall = rec_thresholds - recall_differential = recall.diff( - dim=-1, prepend=torch.zeros((*recall.shape[:-1], 1), device=self._device, dtype=torch.double) - ) - return torch.sum(recall_differential * precision_integrand, dim=-1) - - def concat_dict_of_tensor_lists_in_key_order( - self, tensor_dict: Dict[int, List[torch.Tensor]], num_keys: int - ) -> torch.Tensor: - return torch.cat(list(itertools.chain(*map(tensor_dict.__getitem__, range(num_keys)))), dim=-1) - def compute(self) -> Union[torch.Tensor, float]: """ Compute method of the metric """ - num_classes = int(idist.all_reduce(self._num_classes or 0, "MAX")) - if not num_classes: - return 0.0 - - if self._task == "detection": - P = cast( - torch.Tensor, - idist.all_reduce(torch.tensor(list(map(self._P.__getitem__, range(num_classes))), device=self._device)), - ) - num_preds = torch.tensor( - [sum([tp.shape[-1] for tp in self._tp[cls]]) if self._tp[cls] else 0 for cls in range(num_classes)], - device=self._device, - ) - num_preds_per_class_across_ranks = torch.stack( - cast(torch.Tensor, idist.all_gather(num_preds)).split(split_size=num_classes) - ) - if num_preds_per_class_across_ranks.sum() == 0: - return ( - 0.0 - if self.class_mean is not None - else torch.zeros((num_classes,), dtype=torch.double, device=self._device) - ) - a_nonempty_rank, its_class_with_pred = list(zip(*torch.where(num_preds_per_class_across_ranks != 0))).pop(0) - a_nonempty_rank = a_nonempty_rank.item() - its_class_with_pred = its_class_with_pred.item() - mean_dimensions_shape = cast( - torch.Tensor, - idist.broadcast( - torch.tensor(self._tp[its_class_with_pred][-1].shape[:-1], device=self._device) - if idist.get_rank() == a_nonempty_rank - else None, - a_nonempty_rank, - safe_mode=True, - ), - ).tolist() - - if self.class_mean != "micro": - average_precisions = -torch.ones( - (num_classes, *(mean_dimensions_shape if self.class_mean == "with_other_dims" else ())), - device=self._device, - dtype=torch.double, - ) - for cls in range(num_classes): - if P[cls] == 0: - continue - - num_preds_across_ranks = num_preds_per_class_across_ranks[:, [cls]] - if num_preds_across_ranks.sum() == 0: - average_precisions[cls] = 0 - continue - shape_across_ranks = [ - (*mean_dimensions_shape, num_pred_in_rank.item()) for num_pred_in_rank in num_preds_across_ranks - ] - TP = torch.cat( - _all_gather_tensors_with_shapes( - torch.cat(self._tp[cls], dim=-1) - if self._tp[cls] - else torch.empty((*mean_dimensions_shape, 0), dtype=torch.uint8, device=self._device), - shape_across_ranks, - ), - dim=-1, - ) - FP = torch.cat( - _all_gather_tensors_with_shapes( - torch.cat(self._fp[cls], dim=-1) - if self._fp[cls] - else torch.empty((*mean_dimensions_shape, 0), dtype=torch.uint8, device=self._device), - shape_across_ranks, - ), - dim=-1, - ) - scores = torch.cat( - _all_gather_tensors_with_shapes( - torch.cat(cast(List[torch.Tensor], self._scores[cls])) - if self._scores[cls] - else torch.tensor([], dtype=torch.double, device=self._device), - num_preds_across_ranks.tolist(), - ) - ) - recall, 
precision = self._compute_recall_and_precision(TP, FP, scores, P[cls]) - average_precision_for_cls_across_other_dims = self._measure_average_precision(recall, precision) - if self.class_mean != "with_other_dims": - average_precisions[cls] = average_precision_for_cls_across_other_dims.mean() - else: - average_precisions[cls] = average_precision_for_cls_across_other_dims - - if self.class_mean is None: - average_precisions[average_precisions == -1] = 0 - return average_precisions - elif self.class_mean == "weighted": - return torch.dot(P.double(), average_precisions) / P.sum() - else: - return average_precisions[average_precisions > -1].mean() - else: - num_preds_across_ranks = num_preds_per_class_across_ranks.sum(dim=1) - shapes_across_ranks_in_micro = [ - (*mean_dimensions_shape, num_preds_in_rank.item()) for num_preds_in_rank in num_preds_across_ranks - ] - TP_micro = torch.cat( - _all_gather_tensors_with_shapes( - self.concat_dict_of_tensor_lists_in_key_order(self._tp, num_classes).to(torch.uint8) - if num_preds_across_ranks[idist.get_rank()] - else torch.empty((*mean_dimensions_shape, 0), dtype=torch.uint8, device=self._device), - shapes_across_ranks_in_micro, - ), - dim=-1, - ).bool() - FP_micro = torch.cat( - _all_gather_tensors_with_shapes( - self.concat_dict_of_tensor_lists_in_key_order(self._fp, num_classes).to(torch.uint8) - if num_preds_across_ranks[idist.get_rank()] - else torch.empty((*mean_dimensions_shape, 0), dtype=torch.uint8, device=self._device), - shapes_across_ranks_in_micro, - ), - dim=-1, - ).bool() - scores_micro = torch.cat( - _all_gather_tensors_with_shapes( - self.concat_dict_of_tensor_lists_in_key_order( - cast(Dict[int, List[torch.Tensor]], self._scores), num_classes - ) - if num_preds_across_ranks[idist.get_rank()] - else torch.tensor([], dtype=torch.double, device=self._device), - num_preds_across_ranks.unsqueeze(dim=-1).tolist(), - ) - ) - P = P.sum() - recall, precision = self._compute_recall_and_precision(TP_micro, FP_micro, scores_micro, P) - return self._measure_average_precision(recall, precision).mean() - else: - rank_P = ( - torch.cat(cast(List[torch.Tensor], self._P), dim=-1) - if self._P - else ( - torch.empty((num_classes, 0), dtype=torch.uint8, device=self._device) - if self._type == "multilabel" - else torch.tensor( - [], dtype=torch.long if self._type == "multiclass" else torch.uint8, device=self._device - ) + if self._num_classes is None: + raise RuntimeError("Metric could not be computed without any update method call") + num_classes = cast(int, self._num_classes) + + rank_P = ( + torch.cat(self._P, dim=-1) + if self._P + else ( + torch.empty((num_classes, 0), dtype=torch.uint8, device=self._device) + if self._type == "multilabel" + else torch.tensor( + [], dtype=torch.long if self._type == "multiclass" else torch.uint8, device=self._device ) ) - P = torch.cat(cast(List[torch.Tensor], idist.all_gather(rank_P, tensor_different_shape=True)), dim=-1) - scores_classification = torch.cat( - cast( - List[torch.Tensor], - idist.all_gather( - torch.cat(cast(List[torch.Tensor], self._scores), dim=-1) - if self._scores - else ( - torch.tensor([], device=self._device) - if self._type == "binary" - else torch.empty((num_classes, 0), dtype=torch.double, device=self._device) - ), - tensor_different_shape=True, + ) + P = torch.cat(cast(List[torch.Tensor], idist.all_gather(rank_P, tensor_different_shape=True)), dim=-1) + scores_classification = torch.cat( + cast( + List[torch.Tensor], + idist.all_gather( + torch.cat(self._scores, dim=-1) + if self._scores + 
else ( + torch.tensor([], device=self._device) + if self._type == "binary" + else torch.empty((num_classes, 0), dtype=torch.double, device=self._device) ), + tensor_different_shape=True, ), - dim=-1, - ) - if self._type == "multiclass": - P = to_onehot(P, num_classes=self._num_classes).T - if self.class_mean == "micro": - P = P.reshape(1, -1) - scores_classification = scores_classification.view(1, -1) - P_count = P.sum(dim=-1) - average_precisions = torch.zeros_like(P_count, device=self._device, dtype=torch.double) - for cls in range(len(P_count)): - recall, precision = self._compute_recall_and_precision( - P[cls], None, scores_classification[cls], P_count[cls] - ) - average_precisions[cls] = self._measure_average_precision(recall, precision) - if self._type == "binary": - return average_precisions.item() - if self.class_mean is None: - return average_precisions - elif self.class_mean == "weighted": - return torch.sum(P_count * average_precisions) / P_count.sum() - else: - return average_precisions.mean() + ), + dim=-1, + ) + if self._type == "multiclass": + P = to_onehot(P, num_classes=num_classes).T + if self.class_mean == "micro": + P = P.reshape(1, -1) + scores_classification = scores_classification.view(1, -1) + P_count = P.sum(dim=-1) + average_precisions = torch.zeros_like(P_count, device=self._device, dtype=torch.double) + for cls in range(len(P_count)): + recall, precision = self._compute_recall_and_precision(P[cls], scores_classification[cls], P_count[cls]) + average_precisions[cls] = self._compute_average_precision(recall, precision) + if self._type == "binary": + return average_precisions.item() + if self.class_mean is None: + return average_precisions + elif self.class_mean == "weighted": + return torch.sum(P_count * average_precisions) / P_count.sum() + else: + return average_precisions.mean() diff --git a/ignite/metrics/vision/object_detection_map.py b/ignite/metrics/vision/object_detection_map.py index 1789a50f937..7023e7a8593 100644 --- a/ignite/metrics/vision/object_detection_map.py +++ b/ignite/metrics/vision/object_detection_map.py @@ -1,11 +1,22 @@ -from typing import Callable, cast, Dict, Literal, Optional, Sequence, Tuple, Union +from collections import defaultdict +from typing import Callable, cast, Dict, List, Literal, Optional, Sequence, Tuple, Union import torch -from ignite.metrics.mean_average_precision import MeanAveragePrecision +import ignite.distributed as idist +from ignite.distributed.utils import _all_gather_tensors_with_shapes +from ignite.metrics.mean_average_precision import _BaseMeanAveragePrecision +from ignite.metrics.metric import reinit__is_reduced +from ignite.metrics.recall import _BasePrecisionRecall + + +class ObjectDetectionMAP(_BaseMeanAveragePrecision): + _tp: Dict[int, List[torch.Tensor]] + _fp: Dict[int, List[torch.Tensor]] + _scores: Dict[int, List[torch.Tensor]] + _P: Dict[int, int] -class ObjectDetectionMAP(MeanAveragePrecision): def __init__( self, iou_thresholds: Optional[Union[Sequence[float], torch.Tensor]] = None, @@ -16,32 +27,6 @@ def __init__( ) -> None: r"""Calculate the mean average precision for evaluating an object detector. - The input to metric's ``update`` method should be a binary tuple of str-to-tensor dictionaries, (y_pred, y), - which their items are as follows. N\ :sub:`det` and N\ :sub:`gt` are number of detections and ground truths - respectively. 
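For illustration, a minimal sketch of one such `(y_pred, y)` pair (boxes, scores and labels here are made up; key names follow what the implementation's `_do_matching` reads, i.e. `bbox`, `scores`, `labels` and the optional `iscrowd`):

import torch
from ignite.metrics import ObjectDetectionMAP

metric = ObjectDetectionMAP()
y_pred = {
    "bbox": torch.tensor([[10.0, 10.0, 50.0, 50.0], [30.0, 20.0, 90.0, 80.0]]),  # (N_det, 4), (x1, y1, x2, y2)
    "scores": torch.tensor([0.9, 0.6]),  # (N_det,)
    "labels": torch.tensor([1, 3]),  # (N_det,)
}
y = {
    "bbox": torch.tensor([[12.0, 11.0, 52.0, 49.0]]),  # (N_gt, 4)
    "labels": torch.tensor([1]),  # (N_gt,)
    "iscrowd": torch.tensor([False]),  # (N_gt,), optional
}
metric.update((y_pred, y))  # one image per call in this sketch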
- - ======= ================== ================================================= - **y_pred items** - ------------------------------------------------------------------------------ - Key Value shape Description - ======= ================== ================================================= - 'bbox' (N\ :sub:`det`, 4) Bounding boxes of form (x1, y1, x2, y2) - containing top left and bottom right coordinates. - 'score' (N\ :sub:`det`,) Confidence score of detections. - 'label' (N\ :sub:`det`,) Predicted category number of detections. - ======= ================== ================================================= - - ========= ================== ================================================= - **y items** - ------------------------------------------------------------------------------ - Key Value shape Description - ========= ================== ================================================= - 'bbox' (N\ :sub:`gt`, 4) Bounding boxes of form (x1, y1, x2, y2) - containing top left and bottom right coordinates. - 'label' (N\ :sub:`gt`,) Category number of ground truths. - 'iscrowd' (N\ :sub:`gt`,) Whether ground truth boxes are crowd ones or not. - ========= ================== ================================================= - Args: iou_thresholds: sequence of IoU thresholds to be considered for computing mean average precision. Values should be between 0 and 1. If not given, it's determined by ``flavor`` argument. @@ -84,7 +69,7 @@ def box_iou(pred_boxes: torch.Tensor, gt_boxes: torch.Tensor, iscrowd: torch.Boo if rec_thresholds is None: rec_thresholds = torch.linspace(0, 1, 101, device=device, dtype=torch.double) - super(ObjectDetectionMAP, self).__init__( + super().__init__( rec_thresholds=rec_thresholds, average="max-precision" if flavor == "COCO" else "precision", class_mean="with_other_dims", @@ -92,8 +77,36 @@ def box_iou(pred_boxes: torch.Tensor, gt_boxes: torch.Tensor, iscrowd: torch.Boo device=device, ) + @reinit__is_reduced + def reset(self) -> None: + """ + Reset method of the metric + """ + super(_BasePrecisionRecall, self).reset() + + self._tp = defaultdict(lambda: []) + self._fp = defaultdict(lambda: []) + self._scores = defaultdict(lambda: []) + self._P = defaultdict(lambda: 0) + self._num_classes: int = 0 + + def _check_matching_input(self, output: Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]) -> None: + y_pred_keys = {"bbox", "scores", "labels"} + if (output[0].keys() & y_pred_keys) != y_pred_keys: + raise ValueError( + "y_pred dict in update's input should have 'bbox', 'scores'" + f" and 'labels' keys. It has {output[0].keys()}" + ) + + y_keys = {"bbox", "labels"} + if (output[1].keys() & y_keys) != y_keys: + raise ValueError( + "y dict in update's input should have 'bbox', 'labels'" + f" and optionaly 'iscrowd' keys. It has {output[1].keys()}" + ) + def _compute_recall_and_precision( - self, TP: torch.Tensor, FP: Union[torch.Tensor, None], scores: torch.Tensor, P: torch.Tensor + self, TP: torch.Tensor, FP: torch.Tensor, scores: torch.Tensor, P: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: r"""Measuring recall & precision @@ -101,11 +114,29 @@ def _compute_recall_and_precision( same scores, they're considered associated with different thresholds in the course of measuring recall values, although it's not logically correct as those predictions are really associated with a single threshold, thus outputing a single recall value. + + Shape of function inputs and return values follow the table below. 
N\ :sub:`pred` is the number of detections + or predictions. ``...`` stands for the possible additional dimensions. Finally, \#unique scores represents + number of unique scores in ``scores`` which is actually the number of thresholds. + + ============== ====================== + **Object** **Shape** + ============== ====================== + TP and FP (..., N\ :sub:`pred`) + scores (N\ :sub:`pred`,) + P () (A single float, + greater than zero) + recall (..., \#unique scores) + precision (..., \#unique scores) + ============== ====================== + + Returns: + `(recall, precision)` """ indices = torch.argsort(scores, dim=-1, stable=True, descending=True) tp = TP[..., indices] tp_summation = tp.cumsum(dim=-1).double() - fp = cast(torch.Tensor, FP)[..., indices] + fp = FP[..., indices] fp_summation = fp.cumsum(dim=-1).double() recall = tp_summation / P @@ -114,7 +145,7 @@ def _compute_recall_and_precision( return recall, precision - def _measure_average_precision(self, recall: torch.Tensor, precision: torch.Tensor) -> torch.Tensor: + def _compute_average_precision(self, recall: torch.Tensor, precision: torch.Tensor) -> torch.Tensor: """Measuring average precision. This method is overriden since :math:`1/#recall_thresholds` is used instead of :math:`r_k - r_{k-1}` as the recall differential in COCO flavor. @@ -128,7 +159,7 @@ def _measure_average_precision(self, recall: torch.Tensor, precision: torch.Tens average_precision: (n-1)-dimensional tensor containing the average precision for mean dimensions. """ if self.flavor != "COCO": - return super()._measure_average_precision(recall, precision) + return super()._compute_average_precision(recall, precision) precision_integrand = ( precision.flip(-1).cummax(dim=-1).values.flip(-1) if self.average == "max-precision" else precision @@ -140,16 +171,36 @@ def _measure_average_precision(self, recall: torch.Tensor, precision: torch.Tens ).where(rec_thresh_indices != recall.size(-1), 0) return torch.sum(precision_integrand, dim=-1) / len(cast(torch.Tensor, self.rec_thresholds)) - def compute(self) -> Union[torch.Tensor, float]: - if not sum(cast(Dict[int, int], self._P).values()) and self.flavor == "COCO": - return -1 - return super().compute() - def _do_matching( self, pred: Dict[str, torch.Tensor], target: Dict[str, torch.Tensor] ) -> Tuple[Dict[int, torch.Tensor], Dict[int, torch.Tensor], Dict[int, int], Dict[int, torch.Tensor]]: - """ + r""" Matching logic of object detection mAP. + + The method returns a quadrople of dictionaries containing TP, FP, P (actual positive) counts and scores for + each class respectively. Please note that class numbers start from zero. + + Values in TP and FP are (m+1)-dimensional tensors of type ``uint8`` and shape + (D\ :sub:`1`, D\ :sub:`2`, ..., D\ :sub:`m`, n\ :sub:`cls`) in which D\ :sub:`i`\ 's are possible additional + dimensions (excluding the class dimension) mean of the average precision is taken over. n\ :sub:`cls` is the + number of predictions for class `cls` which is the same for TP and FP. + + Note: + TP and FP values are stored as uint8 tensors internally to avoid bool-to-uint8 copies before collective + operations, as PyTorch colective operations `do not `_ + support boolean tensors, at least on Gloo backend. + + P counts contains the number of ground truth samples for each class. Finally, the values in scores are 1-dim + tensors of shape (n\ :sub:`cls`,) containing score or confidence of the predictions (doesn't need to be in + [0,1]). 
If there is no prediction or ground truth for a class, it is absent from (TP, FP, scores) and P + dictionaries respectively. + + Args: + pred: First member of :meth:`update`'s input is given as this argument. + target: Second member of :meth:`update`'s input is given as this argument. + + Returns: + `(TP, FP, P, scores)` A quadrople of true positives, false positives, number of actual positives and scores. """ labels = target["labels"].detach() pred_labels = pred["labels"].detach() @@ -159,7 +210,7 @@ def _do_matching( pred_boxes = pred["bbox"] gt_boxes = target["bbox"] - is_crowd = target["iscrowd"] + is_crowd = target["iscrowd"] if "iscrowd" in target else torch.zeros_like(target["labels"], dtype=torch.bool) tp: Dict[int, torch.Tensor] = {} fp: Dict[int, torch.Tensor] = {} @@ -227,3 +278,133 @@ def _do_matching( fp[category] = category_fp return tp, fp, P, scores + + @reinit__is_reduced + def update(self, output: Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]) -> None: + r"""Metric update function using prediction and target. + + Args: + output: a binary tuple of str-to-tensor dictionaries, (y_pred, y), which their items + are as follows. N\ :sub:`det` and N\ :sub:`gt` are number of detections and + ground truths respectively. + + ======= ================== ================================================= + **y_pred items** + ------------------------------------------------------------------------------ + Key Value shape Description + ======= ================== ================================================= + 'bbox' (N\ :sub:`det`, 4) Bounding boxes of form (x1, y1, x2, y2) + containing top left and bottom right coordinates. + 'score' (N\ :sub:`det`,) Confidence score of detections. + 'label' (N\ :sub:`det`,) Predicted category number of detections. + ======= ================== ================================================= + + ========= ================== ================================================= + **y items** + ------------------------------------------------------------------------------ + Key Value shape Description + ========= ================== ================================================= + 'bbox' (N\ :sub:`gt`, 4) Bounding boxes of form (x1, y1, x2, y2) + containing top left and bottom right coordinates. + 'label' (N\ :sub:`gt`,) Category number of ground truths. + 'iscrowd' (N\ :sub:`gt`,) Whether ground truth boxes are crowd ones or not. + It's optional with default value of ``False``. 
+ ========= ================== ================================================= + """ + self._check_matching_input(output) + tps, fps, ps, scores_dict = self._do_matching(output[0], output[1]) + for cls in tps: + self._tp[cls].append(tps[cls].to(device=self._device, dtype=torch.uint8)) + self._fp[cls].append(fps[cls].to(device=self._device, dtype=torch.uint8)) + self._scores[cls].append(scores_dict[cls].to(self._device)) + for cls in ps: + self._P[cls] += ps[cls] + classes = tps.keys() | ps.keys() + if classes: + self._num_classes = max(max(classes) + 1, self._num_classes) + + def compute(self) -> Union[torch.Tensor, float]: + """ + Compute method of the metric + """ + if sum(self._P.values()) < 1 and self.flavor == "COCO": + return -1 + + num_classes = int(idist.all_reduce(self._num_classes or 0, "MAX")) + if num_classes < 1: + return 0.0 + + P = cast( + torch.Tensor, + idist.all_reduce(torch.tensor(list(map(self._P.__getitem__, range(num_classes))), device=self._device)), + ) + num_preds = torch.tensor( + [sum([tp.shape[-1] for tp in self._tp[cls]]) if self._tp[cls] else 0 for cls in range(num_classes)], + device=self._device, + ) + num_preds_per_class_across_ranks = torch.stack( + cast(torch.Tensor, idist.all_gather(num_preds)).split(split_size=num_classes) + ) + if num_preds_per_class_across_ranks.sum() == 0: + return 0.0 + a_nonempty_rank, its_class_with_pred = list(zip(*torch.where(num_preds_per_class_across_ranks != 0))).pop(0) + a_nonempty_rank = a_nonempty_rank.item() + its_class_with_pred = its_class_with_pred.item() + mean_dimensions_shape = cast( + torch.Tensor, + idist.broadcast( + torch.tensor(self._tp[its_class_with_pred][-1].shape[:-1], device=self._device) + if idist.get_rank() == a_nonempty_rank + else None, + a_nonempty_rank, + safe_mode=True, + ), + ).tolist() + + average_precisions = -torch.ones( + (num_classes, *mean_dimensions_shape), + device=self._device, + dtype=torch.double, + ) + for cls in range(num_classes): + if P[cls] == 0: + continue + + num_preds_across_ranks = num_preds_per_class_across_ranks[:, [cls]] + if num_preds_across_ranks.sum() == 0: + average_precisions[cls] = 0 + continue + shape_across_ranks = [ + (*mean_dimensions_shape, num_pred_in_rank.item()) for num_pred_in_rank in num_preds_across_ranks + ] + TP = torch.cat( + _all_gather_tensors_with_shapes( + torch.cat(self._tp[cls], dim=-1) + if self._tp[cls] + else torch.empty((*mean_dimensions_shape, 0), dtype=torch.uint8, device=self._device), + shape_across_ranks, + ), + dim=-1, + ) + FP = torch.cat( + _all_gather_tensors_with_shapes( + torch.cat(self._fp[cls], dim=-1) + if self._fp[cls] + else torch.empty((*mean_dimensions_shape, 0), dtype=torch.uint8, device=self._device), + shape_across_ranks, + ), + dim=-1, + ) + scores = torch.cat( + _all_gather_tensors_with_shapes( + torch.cat(self._scores[cls]) + if self._scores[cls] + else torch.tensor([], dtype=torch.double, device=self._device), + num_preds_across_ranks.tolist(), + ) + ) + recall, precision = self._compute_recall_and_precision(TP, FP, scores, P[cls]) + average_precision_for_cls_across_other_dims = self._compute_average_precision(recall, precision) + average_precisions[cls] = average_precision_for_cls_across_other_dims + + return average_precisions[average_precisions > -1].mean() diff --git a/tests/ignite/metrics/test_mean_average_precision.py b/tests/ignite/metrics/test_mean_average_precision.py index 8b4c9def745..3125aeb13e0 100644 --- a/tests/ignite/metrics/test_mean_average_precision.py +++ 
b/tests/ignite/metrics/test_mean_average_precision.py @@ -1,5 +1,3 @@ -from typing import Tuple - import numpy as np import pytest import torch @@ -29,13 +27,16 @@ def test_wrong_input(): with pytest.raises(ValueError, match="rec_thresholds values should be between 0 and 1"): MeanAveragePrecision(rec_thresholds=(0.0, 0.5, 1.0, 1.5)) - with pytest.raises(ValueError, match="class_mean 'with_other_dims' is not compatible with classification"): + with pytest.raises(ValueError, match="class_mean 'with_other_dims' is not compatible with this class"): MeanAveragePrecision(class_mean="with_other_dims") + metric = MeanAveragePrecision() + with pytest.raises(RuntimeError, match="Metric could not be computed without any update method call"): + metric.compute() + def test_wrong_classification_input(): metric = MeanAveragePrecision() - assert metric._task == "classification" with pytest.raises(TypeError, match="`y_pred` should be a float tensor"): metric.update((torch.tensor([0, 1, 0]), torch.tensor([1, 0, 1]))) @@ -48,63 +49,19 @@ def test_wrong_classification_input(): metric.update((torch.tensor([[0.5, 0.4, 0.1]]), torch.tensor([3]))) -class Dummy_mAP(MeanAveragePrecision): - def _do_matching(self, pred: Tuple, target: Tuple): - return *pred, *target - - -def test_wrong__do_matching(): - metric = MeanAveragePrecision() - with pytest.raises(NotImplementedError, match="Please subclass MeanAveragePrecision and implement"): - metric._do_matching(None, None) - - metric = Dummy_mAP() - - with pytest.raises(ValueError, match="Returned TP, FP and scores dictionaries from _do_matching should have"): - metric.update( - ( - ({1: torch.tensor([True])}, {1: torch.tensor([False])}), - ({1: 1}, {1: torch.tensor([0.8]), 2: torch.tensor([0.9])}), - ) - ) - - with pytest.raises(TypeError, match="Tensors in TP and FP dictionaries should be boolean or uint8"): - metric.update((({1: torch.tensor([1])}, {1: torch.tensor([False])}), ({1: 1}, {1: torch.tensor([0.8])}))) - - with pytest.raises( - ValueError, match="Sample dimension of tensors in TP, FP and scores should have equal size per class" - ): - metric.update( - (({1: torch.tensor([True])}, {1: torch.tensor([False, False])}), ({1: 1}, {1: torch.tensor([0.8])})) - ) - - metric.update((({1: torch.tensor([True])}, {1: torch.tensor([False])}), ({1: 1}, {1: torch.tensor([0.8])}))) - with pytest.raises(ValueError, match="Tensors in returned FP from _do_matching should not change in shape except"): - metric.update( - ( - ({1: torch.tensor([False, True])}, {1: torch.tensor([[True, False], [False, False]])}), - ({1: 1}, {1: torch.tensor([0.8, 0.9])}), - ) - ) - - -def test__classification_prepare_output(): +def test__prepare_output(): metric = MeanAveragePrecision() metric._type = "binary" - scores, y = metric._classification_prepare_output( - torch.rand((5, 4, 3, 2)), torch.randint(0, 2, (5, 4, 3, 2)).bool() - ) + scores, y = metric._prepare_output((torch.rand((5, 4, 3, 2)), torch.randint(0, 2, (5, 4, 3, 2)).bool())) assert scores.shape == y.shape == (1, 120) metric._type = "multiclass" - scores, y = metric._classification_prepare_output(torch.rand((5, 4, 3, 2)), torch.randint(0, 4, (5, 3, 2))) + scores, y = metric._prepare_output((torch.rand((5, 4, 3, 2)), torch.randint(0, 4, (5, 3, 2)))) assert scores.shape == (4, 30) and y.shape == (30,) metric._type = "multilabel" - scores, y = metric._classification_prepare_output( - torch.rand((5, 4, 3, 2)), torch.randint(0, 2, (5, 4, 3, 2)).bool() - ) + scores, y = metric._prepare_output((torch.rand((5, 4, 3, 2)), torch.randint(0, 2, 
(5, 4, 3, 2)).bool())) assert scores.shape == y.shape == (4, 30) @@ -114,73 +71,29 @@ def test_update(): metric.update((torch.rand((5, 4)), torch.randint(0, 2, (5, 4)).bool())) assert len(metric._scores) == len(metric._P) == 1 - metric = Dummy_mAP() - assert len(metric._tp) == len(metric._fp) == len(metric._scores) == len(metric._P) == metric._num_classes == 0 - - metric.update((({1: torch.tensor([True])}, {1: torch.tensor([False])}), ({1: 1, 2: 1}, {1: torch.tensor([0.8])}))) - assert len(metric._tp[1]) == len(metric._fp[1]) == len(metric._scores[1]) == 1 - assert len(metric._P) == 2 and metric._P[2] == 1 - assert metric._num_classes == 3 - - metric.update((({}, {}), ({2: 2}, {}))) - assert metric._P[2] == 3 - def test__compute_recall_and_precision(): - # Classification m = MeanAveragePrecision() scores = torch.rand((50,)) y_true = torch.randint(0, 2, (50,)).bool() precision, recall, _ = precision_recall_curve(y_true.numpy(), scores.numpy()) - FP = None P = y_true.sum(dim=-1) - ignite_recall, ignite_precision = m._compute_recall_and_precision(y_true, FP, scores, P) + ignite_recall, ignite_precision = m._compute_recall_and_precision(y_true, scores, P) assert (ignite_recall.squeeze().flip(0).numpy() == recall[:-1]).all() assert (ignite_precision.squeeze().flip(0).numpy() == precision[:-1]).all() - # Classification, when there's no actual positive. Numpy expectedly raises warning. + # When there's no actual positive. Numpy expectedly raises warning. scores = torch.rand((50,)) y_true = torch.zeros((50,)).bool() precision, recall, _ = precision_recall_curve(y_true.numpy(), scores.numpy()) - FP = None P = torch.tensor(0) - ignite_recall, ignite_precision = m._compute_recall_and_precision(y_true, FP, scores, P) + ignite_recall, ignite_precision = m._compute_recall_and_precision(y_true, scores, P) assert (ignite_recall.flip(0).numpy() == recall[:-1]).all() assert (ignite_precision.flip(0).numpy() == precision[:-1]).all() - # Detection, in the case detector detects all gt objects but also produces some wrong predictions. - scores = torch.rand((50,)) - y_true = torch.randint(0, 2, (50,)) - m = Dummy_mAP() - - ignite_recall, ignite_precision = m._compute_recall_and_precision( - y_true.bool(), ~(y_true.bool()), scores, y_true.sum() - ) - sklearn_precision, sklearn_recall, _ = precision_recall_curve(y_true.numpy(), scores.numpy()) - assert (ignite_recall.flip(0).numpy() == sklearn_recall[:-1]).all() - assert (ignite_precision.flip(0).numpy() == sklearn_precision[:-1]).all() - - # Detection like above but with two additional mean dimensions. 
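# A standalone sketch of the cumulative-sum recipe that test__compute_recall_and_precision
# exercises above, assuming untied random scores; the lowest-scoring sample is forced to be
# positive because sklearn's precision_recall_curve stops once full recall is reached:
import torch
from sklearn.metrics import precision_recall_curve

scores = torch.rand(50)
y_true = torch.randint(0, 2, (50,)).bool()
y_true[scores.argmin()] = True

order = torch.argsort(scores, stable=True, descending=True)
tp_cum = y_true[order].cumsum(dim=-1).double()
fp_cum = (~y_true)[order].cumsum(dim=-1).double()
recall = tp_cum / y_true.sum()
precision = tp_cum / (tp_cum + fp_cum)

sk_precision, sk_recall, _ = precision_recall_curve(y_true.numpy(), scores.numpy())
assert torch.allclose(recall.flip(0), torch.from_numpy(sk_recall[:-1]))
assert torch.allclose(precision.flip(0), torch.from_numpy(sk_precision[:-1]))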
- scores = torch.rand((50,)) - y_true = torch.zeros((6, 8, 50)) - sklearn_precisions, sklearn_recalls = [], [] - for i in range(6): - for j in range(8): - y_true[i, j, np.random.choice(50, size=15, replace=False)] = 1 - precision, recall, _ = precision_recall_curve(y_true[i, j].numpy(), scores.numpy()) - sklearn_precisions.append(precision[:-1]) - sklearn_recalls.append(recall[:-1]) - sklearn_precisions = np.array(sklearn_precisions).reshape(6, 8, -1) - sklearn_recalls = np.array(sklearn_recalls).reshape(6, 8, -1) - ignite_recall, ignite_precision = m._compute_recall_and_precision( - y_true.bool(), ~(y_true.bool()), scores, torch.tensor(15) - ) - assert (ignite_recall.flip(-1).numpy() == sklearn_recalls).all() - assert (ignite_precision.flip(-1).numpy() == sklearn_precisions).all() - -def test__measure_average_precision(): +def test__compute_average_precision(): m = MeanAveragePrecision() # Binary data @@ -188,7 +101,7 @@ def test__measure_average_precision(): y_true = np.random.randint(0, 2, 50) ap = average_precision_score(y_true, scores) precision, recall, _ = precision_recall_curve(y_true, scores) - ignite_ap = m._measure_average_precision( + ignite_ap = m._compute_average_precision( torch.from_numpy(recall[:-1]).flip(-1), torch.from_numpy(precision[:-1]).flip(-1) ) assert np.allclose(ignite_ap.item(), ap) @@ -201,7 +114,7 @@ def test__measure_average_precision(): for cls in range(scores.shape[1]): precision, recall, _ = precision_recall_curve(y_true[:, cls], scores[:, cls]) ignite_ap.append( - m._measure_average_precision( + m._compute_average_precision( torch.from_numpy(recall[:-1]).flip(-1), torch.from_numpy(precision[:-1]).flip(-1) ).item() ) @@ -209,7 +122,7 @@ def test__measure_average_precision(): assert np.allclose(ignite_ap, ap) -def test_compute_classification_binary_data(): +def test_compute_binary_data(): m = MeanAveragePrecision() scores = torch.rand((130,)) y_true = torch.randint(0, 2, (130,)) @@ -224,7 +137,7 @@ def test_compute_classification_binary_data(): @pytest.mark.parametrize("class_mean", [None, "macro", "micro", "weighted"]) -def test_compute_classification_nonbinary_data(class_mean): +def test_compute_nonbinary_data(class_mean): scores = torch.rand((130, 5, 2, 2)) sklearn_scores = scores.transpose(1, -1).reshape(-1, 5).numpy() @@ -241,7 +154,7 @@ def test_compute_classification_nonbinary_data(class_mean): assert np.allclose(sklearn_map, ignite_map) # Multilabel - m = MeanAveragePrecision(classification_is_multilabel=True, class_mean=class_mean) + m = MeanAveragePrecision(is_multilabel=True, class_mean=class_mean) y_true = torch.randint(0, 2, (130, 5, 2, 2)).bool() m.update((scores[:50], y_true[:50])) m.update((scores[50:], y_true[50:])) @@ -253,73 +166,8 @@ def test_compute_classification_nonbinary_data(class_mean): assert np.allclose(sklearn_map, ignite_map) -@pytest.mark.parametrize("class_mean", ["macro", None, "micro", "weighted", "with_other_dims"]) -def test_compute_detection(class_mean): - m = Dummy_mAP(class_mean=class_mean) - - # The case in which, detector detects all gt objects but also produces some wrong predictions. Also classes - # have the same number of predictions. 
- - y_true = torch.randint(0, 2, (40, 5)) - scores = torch.rand((40, 5)) - - for s in [slice(20), slice(20, 40)]: - tp = {c: y_true[s, c].bool() for c in range(5)} - fp = {c: ~(y_true[s, c].bool()) for c in range(5)} - p = dict(enumerate(y_true[s].sum(dim=0).tolist())) - score = {c: scores[s, c] for c in range(5)} - m.update(((tp, fp), (p, score))) - - ignite_map = m.compute().numpy() - - sklearn_class_mean = class_mean if class_mean != "with_other_dims" else "macro" - sklearn_map = average_precision_score(y_true.numpy(), scores.numpy(), average=sklearn_class_mean) - assert np.allclose(sklearn_map, ignite_map) - - # Like above but with two additional mean dimensions. - m.reset() - y_true = torch.zeros((5, 6, 8, 50)) - scores = torch.rand((50, 5)) - P_counts = np.random.choice(50, size=5) - sklearn_aps = [] - for c in range(5): - for i in range(6): - for j in range(8): - y_true[c, i, j, np.random.choice(50, size=P_counts[c], replace=False)] = 1 - if class_mean != "micro": - sklearn_aps.append( - average_precision_score( - y_true[c].view(6 * 8, 50).T.numpy(), scores[:, c].repeat(6 * 8, 1).T.numpy(), average=None - ) - ) - if class_mean == "micro": - sklearn_aps = average_precision_score( - torch.cat(y_true.unbind(0), dim=-1).view(6 * 8, 5 * 50).T.numpy(), - scores.T.reshape(5 * 50).repeat(6 * 8, 1).T.numpy(), - average=None, - ) - sklearn_aps = np.array(sklearn_aps) - if class_mean in (None, "micro"): - sklearn_map = sklearn_aps.mean(axis=-1) - elif class_mean == "macro": - sklearn_map = sklearn_aps.mean(axis=-1)[P_counts != 0].mean() - elif class_mean == "with_other_dims": - sklearn_map = sklearn_aps[P_counts != 0].mean() - else: - sklearn_map = np.dot(P_counts, sklearn_aps.mean(axis=-1)) / P_counts.sum() - - for s in [slice(0, 20), slice(20, 50)]: - tp = {c: y_true[c, :, :, s].bool() for c in range(5)} - fp = {c: ~(y_true[c, :, :, s].bool()) for c in range(5)} - p = dict(enumerate(y_true[:, 0, 0, s].sum(dim=-1).tolist())) - score = {c: scores[s, c] for c in range(5)} - m.update(((tp, fp), (p, score))) - ignite_map = m.compute().numpy() - assert np.allclose(ignite_map, sklearn_map) - - @pytest.mark.parametrize("data_type", ["binary", "multiclass", "multilabel"]) -def test_distrib_integration_classification(distributed, data_type): +def test_distrib_integration(distributed, data_type): rank = idist.get_rank() world_size = idist.get_world_size() device = idist.device() @@ -332,7 +180,7 @@ def update(_, i): ) engine = Engine(update) - mAP = MeanAveragePrecision(classification_is_multilabel=data_type == "multilabel", device=metric_device) + mAP = MeanAveragePrecision(is_multilabel=data_type == "multilabel", device=metric_device) mAP.attach(engine, "mAP") y_true_size = (10 * 2 * world_size, 3, 2) if data_type != "multilabel" else (10 * 2 * world_size, 4, 3, 2) @@ -361,105 +209,3 @@ def update(_, i): metric_devices.append(idist.device()) for metric_device in metric_devices: _test(metric_device) - - -@pytest.mark.parametrize("class_mean", [None, "macro", "micro", "weighted", "with_other_dims"]) -def test_distrib_integration_detection(distributed, class_mean): - rank = idist.get_rank() - device = idist.device() - world_size = idist.get_world_size() - - def _test(metric_device): - def update(_, i): - y_true_batch = y_true[..., (2 * rank + i) * 10 : (2 * rank + i + 1) * 10] - scores_batch = scores[..., (2 * rank + i) * 10 : (2 * rank + i + 1) * 10] - return ( - ({c: y_true_batch[c].bool() for c in range(4)}, {c: ~(y_true_batch[c].bool()) for c in range(4)}), - ( - dict( - enumerate( - 
(y_true_batch[:, 0, 0] if y_true_batch.ndim == 4 else y_true_batch).sum(dim=-1).tolist() - ) - ), - {c: scores_batch[c] for c in range(4)}, - ), - ) - - engine = Engine(update) - # The case in which, detector detects all gt objects but also produces some wrong predictions. Also classes - # have the same number of predictions. - mAP = Dummy_mAP(device=metric_device, class_mean=class_mean) - mAP.attach(engine, "mAP") - - y_true = torch.randint(0, 2, size=(4, 10 * 2 * world_size)).to(device) - scores = torch.rand((4, 10 * 2 * world_size)).to(device) - engine.run(range(2), max_epochs=1) - assert "mAP" in engine.state.metrics - sklearn_class_mean = class_mean if class_mean != "with_other_dims" else "macro" - sklearn_map = average_precision_score(y_true.T.numpy(), scores.T.numpy(), average=sklearn_class_mean) - assert np.allclose(sklearn_map, engine.state.metrics["mAP"]) - - # Like above but with two additional mean dimensions. - y_true = torch.zeros((4, 6, 8, 10 * 2 * world_size)) - - P_counts = np.random.choice(10 * 2 * world_size, size=4) - sklearn_aps = [] - for c in range(4): - for i in range(6): - for j in range(8): - y_true[c, i, j, np.random.choice(10 * 2 * world_size, size=P_counts[c], replace=False)] = 1 - if class_mean != "micro": - sklearn_aps.append( - average_precision_score( - y_true[c].view(6 * 8, 10 * 2 * world_size).T.numpy(), - scores[c].repeat(6 * 8, 1).T.numpy(), - average=None, - ) - ) - if class_mean == "micro": - sklearn_aps = average_precision_score( - torch.cat(y_true.unbind(0), dim=-1).view(6 * 8, 4 * 10 * 2 * world_size).T.numpy(), - scores.reshape(4 * 10 * 2 * world_size).repeat(6 * 8, 1).T.numpy(), - average=None, - ) - sklearn_aps = np.array(sklearn_aps) - if class_mean in (None, "micro"): - sklearn_map = sklearn_aps.mean(axis=-1) - elif class_mean == "macro": - sklearn_map = sklearn_aps.mean(axis=-1)[P_counts != 0].mean() - elif class_mean == "with_other_dims": - sklearn_map = sklearn_aps[P_counts != 0].mean() - else: - sklearn_map = np.dot(P_counts, sklearn_aps.mean(axis=-1)) / P_counts.sum() - - engine.run(range(2), max_epochs=1) - - assert np.allclose(sklearn_map, engine.state.metrics["mAP"]) - - metric_devices = [torch.device("cpu")] - if device.type != "xla": - metric_devices.append(idist.device()) - for metric_device in metric_devices: - _test(metric_device) - - -# class MatchFirstDetectionFirst_mAP(MeanAveragePrecision): -# def _do_matching(self, pred: Tuple[Sequence[int], Sequence[float]] , target: Sequence[int]): -# P = dict(Counter(target)) -# tp = defaultdict(lambda: []) -# scores = defaultdict(lambda: []) - -# target = torch.tensor(target) -# matched = torch.zeros((len(target),)).bool() -# for label, score in zip(*pred): -# try: -# matched[torch.logical_and(target == label, ~matched).tolist().index(True)] = True -# tp[label].append(True) -# except ValueError: -# tp[label].append(False) -# scores[label].append(score) - -# tp = {label: torch.tensor(_tp) for label, _tp in tp.items()} -# fp = {label: ~_tp for label, _tp in tp.items()} -# scores = {label: torch.tensor(_scores) for label, _scores in scores.items()} -# return tp, fp, P, scores diff --git a/tests/ignite/metrics/vision/test_object_detection_map.py b/tests/ignite/metrics/vision/test_object_detection_map.py index 8e95630b1ee..2ea6675d552 100644 --- a/tests/ignite/metrics/vision/test_object_detection_map.py +++ b/tests/ignite/metrics/vision/test_object_detection_map.py @@ -627,6 +627,12 @@ def test_wrong_input(): with pytest.raises(ValueError, match="Currently, the only available flavor for 
ObjectDetectionMAP is 'COCO'"): ObjectDetectionMAP(flavor="wrong flavor") + m = ObjectDetectionMAP() + with pytest.raises(ValueError, match="y_pred dict in update's input should have 'bbox', 'scores'"): + m.update(({"bbox": None, "scores": None}, {"bbox": None, "labels": None})) + with pytest.raises(ValueError, match="y dict in update's input should have 'bbox', 'labels'"): + m.update(({"bbox": None, "scores": None, "labels": None}, {"labels": None})) + def test_empty_data(): """ @@ -727,6 +733,54 @@ def test_iou_thresholding(): assert (metric._tp[1][1] == torch.tensor([[True], [False], [False], [False]])).all() +def test__do_matching_output(sample): + metric = ObjectDetectionMAP() + + for pred, target in zip(*sample.data): + tps, fps, _, scores = metric._do_matching(pred, target) + + assert tps.keys() == fps.keys() == scores.keys() + + try: + cls = list(tps.keys()).pop() + except IndexError: # No prediction + pass + else: + assert tps[cls].dtype in (torch.bool, torch.uint8) + assert tps[cls].size(-1) == fps[cls].size(-1) == scores[cls].size(0) + + metric.update((pred, target)) + + for metric_tp_or_fp, new_tp_or_fp in [(metric._tp, tps), (metric._fp, fps)]: + try: + cls = (metric_tp_or_fp.keys() & new_tp_or_fp.keys()).pop() + except KeyError: + pass + else: + assert metric_tp_or_fp[cls][-1].shape[:-1] == new_tp_or_fp[cls].shape[:-1] + + +class Dummy_mAP(ObjectDetectionMAP): + def _do_matching(self, pred: Tuple, target: Tuple): + return *pred, *target + + def _check_matching_input(self, output: Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]): + pass + + +def test_update(): + metric = Dummy_mAP() + assert len(metric._tp) == len(metric._fp) == len(metric._scores) == len(metric._P) == metric._num_classes == 0 + + metric.update((({1: torch.tensor([True])}, {1: torch.tensor([False])}), ({1: 1, 2: 1}, {1: torch.tensor([0.8])}))) + assert len(metric._tp[1]) == len(metric._fp[1]) == len(metric._scores[1]) == 1 + assert len(metric._P) == 2 and metric._P[2] == 1 + assert metric._num_classes == 3 + + metric.update((({}, {}), ({2: 2}, {}))) + assert metric._P[2] == 3 + + def test_matching(): """ PyCOCO matching rules: @@ -823,7 +877,7 @@ def sklearn_precision_recall_curve_allowing_multiple_recalls_at_single_threshold def test__compute_recall_and_precision(): - # Detection, in the case detector detects all gt objects but also produces some wrong predictions. + # The case in which detector detects all gt objects but also produces some wrong predictions. scores = torch.rand((50,)) y_true = torch.randint(0, 2, (50,)) m = ObjectDetectionMAP() @@ -837,7 +891,7 @@ def test__compute_recall_and_precision(): assert (ignite_recall.flip(0).numpy() == sklearn_recall[:-1]).all() assert (ignite_precision.flip(0).numpy() == sklearn_precision[:-1]).all() - # Detection like above but with two additional mean dimensions. + # Like above but with two additional mean dimensions. 
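# (In ObjectDetectionMAP itself the extra leading dimension is the IoU-threshold axis:
# TP and FP carry one row per IoU threshold while sharing the same scores, and
# recall/precision are computed independently along that axis, cf. test_iou_thresholding
# above. The (6, 8) leading shape used below is arbitrary and only exercises that broadcasting.)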
scores = torch.rand((50,)) y_true = torch.zeros((6, 8, 50)) sklearn_precisions, sklearn_recalls = [], [] @@ -861,7 +915,6 @@ def test__compute_recall_and_precision(): def test_compute(sample): device = idist.device() metric_50 = ObjectDetectionMAP(iou_thresholds=[0.5], device=device) - assert metric_50._task == "detection" metric_75 = ObjectDetectionMAP(iou_thresholds=[0.75], device=device) metric_50_95 = ObjectDetectionMAP(device=device) From 65cdd08c503b2ea74caf5a7e822262e42db52383 Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Sat, 17 Jun 2023 09:13:50 +0330 Subject: [PATCH 08/41] Remove all_gather with different shape --- ignite/distributed/utils.py | 28 ++----------- ignite/metrics/mean_average_precision.py | 39 ++++++++++--------- ignite/metrics/vision/object_detection_map.py | 10 ++--- tests/ignite/distributed/utils/__init__.py | 16 -------- 4 files changed, 29 insertions(+), 64 deletions(-) diff --git a/ignite/distributed/utils.py b/ignite/distributed/utils.py index fc4cea80d88..8340c631306 100644 --- a/ignite/distributed/utils.py +++ b/ignite/distributed/utils.py @@ -351,7 +351,7 @@ def all_reduce( return _model.all_reduce(tensor, op, group=group) -def _all_gather_tensors_with_shapes( +def all_gather_tensors_with_shapes( tensor: torch.Tensor, shapes: Sequence[Sequence[int]], group: Optional[Union[Any, List[int]]] = None ) -> List[torch.Tensor]: if _need_to_sync and isinstance(_model, _SerialModel): @@ -381,34 +381,22 @@ def _all_gather_tensors_with_shapes( def all_gather( - tensor: Union[torch.Tensor, float, str], - group: Optional[Union[Any, List[int]]] = None, - tensor_different_shape: bool = False, -) -> Union[torch.Tensor, float, List[float], List[str], List[torch.Tensor]]: + tensor: Union[torch.Tensor, float, str], group: Optional[Union[Any, List[int]]] = None +) -> Union[torch.Tensor, float, List[float], List[str]]: """Helper method to perform all gather operation. Args: tensor: tensor or number or str to collect across participating processes. If tensor, it should have - the same number of dimensions across processes. + the same shape across processes. group: list of integer or the process group for each backend. If None, the default process group will be used. - tensor_different_shape: If True, it accounts for difference in input shape across processes. In this case, it - induces more collective operations. If False, `tensor` should have the same shape across processes. - Ignored when `tensor` is not a tensor. Default False. - Returns: If input is a tensor, returns a torch.Tensor of shape ``(world_size * tensor.shape[0], tensor.shape[1], ...)`` - if ``tensor_different_shape = False``, otherwise a list of tensors with length ``world_size``(if ``group`` - is `None`) or `len(group)`. If current process does not belong to `group`, a list with `tensor` as its only - item is retured. If input is a number, a torch.Tensor of shape ``(world_size, )`` is returned and finally a list of strings is returned if input is a string. .. versionchanged:: 0.4.11 added ``group`` - - .. 
versionchanged:: 0.5.1 - added ``tensor_different_shape`` """ if _need_to_sync and isinstance(_model, _SerialModel): sync(temporary=True) @@ -416,14 +404,6 @@ def all_gather( if isinstance(group, list) and all(isinstance(item, int) for item in group): group = _model.new_group(group) - if isinstance(tensor, torch.Tensor) and tensor_different_shape: - if isinstance(_model, _SerialModel) or (group is not None and _model.get_rank() not in group): - return [tensor] - all_shapes: torch.Tensor = _model.all_gather(torch.tensor(tensor.shape), group=group).view( - -1, len(tensor.shape) - ) - return _all_gather_tensors_with_shapes(tensor, all_shapes.tolist(), group=group) - return _model.all_gather(tensor, group=group) diff --git a/ignite/metrics/mean_average_precision.py b/ignite/metrics/mean_average_precision.py index 6bb75126bd9..d9be7e49593 100644 --- a/ignite/metrics/mean_average_precision.py +++ b/ignite/metrics/mean_average_precision.py @@ -5,6 +5,7 @@ from typing_extensions import Literal import ignite.distributed as idist +from ignite.distributed.utils import all_gather_tensors_with_shapes from ignite.metrics.metric import reinit__is_reduced from ignite.metrics.recall import _BasePrecisionRecall from ignite.utils import to_onehot @@ -390,7 +391,7 @@ def compute(self) -> Union[torch.Tensor, float]: """ if self._num_classes is None: raise RuntimeError("Metric could not be computed without any update method call") - num_classes = cast(int, self._num_classes) + num_classes = self._num_classes rank_P = ( torch.cat(self._P, dim=-1) @@ -403,32 +404,32 @@ def compute(self) -> Union[torch.Tensor, float]: ) ) ) - P = torch.cat(cast(List[torch.Tensor], idist.all_gather(rank_P, tensor_different_shape=True)), dim=-1) - scores_classification = torch.cat( - cast( - List[torch.Tensor], - idist.all_gather( - torch.cat(self._scores, dim=-1) - if self._scores - else ( - torch.tensor([], device=self._device) - if self._type == "binary" - else torch.empty((num_classes, 0), dtype=torch.double, device=self._device) - ), - tensor_different_shape=True, - ), - ), - dim=-1, + rank_P_shapes = cast(torch.Tensor, idist.all_gather(torch.tensor(rank_P.shape))).view(-1, len(rank_P.shape)) + P = torch.cat(all_gather_tensors_with_shapes(rank_P, rank_P_shapes.tolist()), dim=-1) + + rank_scores = ( + torch.cat(self._scores, dim=-1) + if self._scores + else ( + torch.tensor([], device=self._device) + if self._type == "binary" + else torch.empty((num_classes, 0), dtype=torch.double, device=self._device) + ) ) + rank_scores_shapes = cast(torch.Tensor, idist.all_gather(torch.tensor(rank_scores.shape))).view( + -1, len(rank_scores.shape) + ) + scores = torch.cat(all_gather_tensors_with_shapes(rank_scores, rank_scores_shapes.tolist()), dim=-1) + if self._type == "multiclass": P = to_onehot(P, num_classes=num_classes).T if self.class_mean == "micro": P = P.reshape(1, -1) - scores_classification = scores_classification.view(1, -1) + scores = scores.view(1, -1) P_count = P.sum(dim=-1) average_precisions = torch.zeros_like(P_count, device=self._device, dtype=torch.double) for cls in range(len(P_count)): - recall, precision = self._compute_recall_and_precision(P[cls], scores_classification[cls], P_count[cls]) + recall, precision = self._compute_recall_and_precision(P[cls], scores[cls], P_count[cls]) average_precisions[cls] = self._compute_average_precision(recall, precision) if self._type == "binary": return average_precisions.item() diff --git a/ignite/metrics/vision/object_detection_map.py b/ignite/metrics/vision/object_detection_map.py 
index 7023e7a8593..16d5bd42680 100644 --- a/ignite/metrics/vision/object_detection_map.py +++ b/ignite/metrics/vision/object_detection_map.py @@ -4,7 +4,7 @@ import torch import ignite.distributed as idist -from ignite.distributed.utils import _all_gather_tensors_with_shapes +from ignite.distributed.utils import all_gather_tensors_with_shapes from ignite.metrics.mean_average_precision import _BaseMeanAveragePrecision from ignite.metrics.metric import reinit__is_reduced @@ -20,7 +20,7 @@ class ObjectDetectionMAP(_BaseMeanAveragePrecision): def __init__( self, iou_thresholds: Optional[Union[Sequence[float], torch.Tensor]] = None, - flavor: Optional[Literal["COCO",]] = "COCO", + flavor: Optional["Literal['COCO']"] = "COCO", rec_thresholds: Optional[Union[Sequence[float], torch.Tensor]] = None, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu"), @@ -378,7 +378,7 @@ def compute(self) -> Union[torch.Tensor, float]: (*mean_dimensions_shape, num_pred_in_rank.item()) for num_pred_in_rank in num_preds_across_ranks ] TP = torch.cat( - _all_gather_tensors_with_shapes( + all_gather_tensors_with_shapes( torch.cat(self._tp[cls], dim=-1) if self._tp[cls] else torch.empty((*mean_dimensions_shape, 0), dtype=torch.uint8, device=self._device), @@ -387,7 +387,7 @@ def compute(self) -> Union[torch.Tensor, float]: dim=-1, ) FP = torch.cat( - _all_gather_tensors_with_shapes( + all_gather_tensors_with_shapes( torch.cat(self._fp[cls], dim=-1) if self._fp[cls] else torch.empty((*mean_dimensions_shape, 0), dtype=torch.uint8, device=self._device), @@ -396,7 +396,7 @@ def compute(self) -> Union[torch.Tensor, float]: dim=-1, ) scores = torch.cat( - _all_gather_tensors_with_shapes( + all_gather_tensors_with_shapes( torch.cat(self._scores[cls]) if self._scores[cls] else torch.tensor([], dtype=torch.double, device=self._device), diff --git a/tests/ignite/distributed/utils/__init__.py b/tests/ignite/distributed/utils/__init__.py index d2a9f177d9f..040954a2795 100644 --- a/tests/ignite/distributed/utils/__init__.py +++ b/tests/ignite/distributed/utils/__init__.py @@ -192,14 +192,6 @@ def _test_distrib_all_gather(device): true_res[i * 4 : (i + 1) * 4, ...] 
= torch.arange(100, device=device).reshape(4, 25) * (i + 1) assert (res == true_res).all() - t = torch.full((rank + 1, (rank + 1) * 2, idist.get_world_size() - rank), rank) - in_dtype = t.dtype - res = idist.all_gather(t, tensor_different_shape=True) - assert res[rank].shape == (rank + 1, (rank + 1) * 2, idist.get_world_size() - rank) - assert type(res) == list and res[0].dtype == in_dtype - for i in range(idist.get_world_size()): - assert (res[i] == torch.full((i + 1, (i + 1) * 2, idist.get_world_size() - i), i)).all() - if idist.get_world_size() > 1: with pytest.raises(TypeError, match=r"Unhandled input type"): idist.all_reduce([0, 1, 2]) @@ -228,14 +220,6 @@ def _test_distrib_all_gather_group(device): res = idist.all_gather(t, group=ranks) assert torch.equal(res, torch.tensor(ranks, device=device)) - t = torch.tensor([rank], device=device) - if bnd not in ("horovod"): - res = idist.all_gather(t, group=ranks, tensor_different_shape=True) - if rank not in ranks: - assert res == [t] - else: - assert torch.equal(res[rank], torch.tensor(ranks, device=device)) - if bnd in ("nccl", "gloo", "mpi"): with pytest.raises(ValueError, match=r"Argument group should be list of int or ProcessGroup"): res = idist.all_gather(t, group="abc") From aac2e558420cde480b284318711ce1a42e96118e Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Wed, 21 Jun 2023 12:46:04 +0330 Subject: [PATCH 09/41] Add test for all_gather_with_different_shape func --- ignite/distributed/comp_models/native.py | 7 ++++--- ignite/distributed/utils.py | 4 +++- tests/ignite/distributed/utils/__init__.py | 17 ++++++++++++++++- 3 files changed, 23 insertions(+), 5 deletions(-) diff --git a/ignite/distributed/comp_models/native.py b/ignite/distributed/comp_models/native.py index e1b783f7159..ee0e22073ca 100644 --- a/ignite/distributed/comp_models/native.py +++ b/ignite/distributed/comp_models/native.py @@ -436,14 +436,15 @@ def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> t group_size = self.get_world_size() elif isinstance(group, dist.ProcessGroup): group_size = group.size() - elif isinstance(group, list): - group_size = len(group) else: raise ValueError("Argument group should be list of int or ProcessGroup") if tensor.ndimension() == 0: tensor = tensor.unsqueeze(0) output = [torch.zeros_like(tensor) for _ in range(group_size)] - dist.all_gather(output, tensor, group=group) + if group is None: + dist.all_gather(output, tensor) + else: + dist.all_gather(output, tensor, group=group) return torch.cat(output, dim=0) def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any: diff --git a/ignite/distributed/utils.py b/ignite/distributed/utils.py index b683a1fe21b..cccb31e40b2 100644 --- a/ignite/distributed/utils.py +++ b/ignite/distributed/utils.py @@ -5,6 +5,7 @@ from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Union import torch +from torch import distributed as dist from ignite.distributed.comp_models import ( _SerialModel, @@ -44,6 +45,7 @@ "one_rank_only", "new_group", "one_rank_first", + "all_gather_tensors_with_shapes", ] _model = _SerialModel() @@ -360,7 +362,7 @@ def all_gather_tensors_with_shapes( if isinstance(group, list) and all(isinstance(item, int) for item in group): group = _model.new_group(group) - if isinstance(_model, _SerialModel) or (group is not None and _model.get_rank() not in group): + if isinstance(_model, _SerialModel) or group == dist.GroupMember.NON_GROUP_MEMBER: return [tensor] max_shape = torch.tensor(shapes).amax(dim=0) diff --git 
a/tests/ignite/distributed/utils/__init__.py b/tests/ignite/distributed/utils/__init__.py index 3b14f1908b9..c9c14e399dd 100644 --- a/tests/ignite/distributed/utils/__init__.py +++ b/tests/ignite/distributed/utils/__init__.py @@ -3,9 +3,11 @@ import torch.distributed as dist import ignite.distributed as idist -from ignite.distributed.utils import sync +from ignite.distributed.utils import all_gather_tensors_with_shapes, sync from ignite.engine import Engine, Events +torch.manual_seed(41) + def _sanity_check(): from ignite.distributed.utils import _model @@ -192,6 +194,11 @@ def _test_distrib_all_gather(device): true_res[i * 4 : (i + 1) * 4, ...] = torch.arange(100, device=device).reshape(4, 25) * (i + 1) assert (res == true_res).all() + ts = [torch.randn(tuple(torch.randint(1, 10, (3,))), device=device) for _ in range(idist.get_world_size())] + ts_gathered = all_gather_tensors_with_shapes(ts[rank], [list(t.shape) for t in ts]) + for t, t_gathered in zip(ts, ts_gathered): + assert (t == t_gathered).all() + if idist.get_world_size() > 1: with pytest.raises(TypeError, match=r"Unhandled input type"): idist.all_reduce([0, 1, 2]) @@ -226,6 +233,14 @@ def _test_distrib_all_gather_group(device): else: assert res == t + ts = [torch.randn(tuple(torch.randint(1, 10, (3,))), device=device) for _ in range(idist.get_world_size())] + ts_gathered = all_gather_tensors_with_shapes(ts[rank], [list(t.shape) for t in ts], ranks) + if rank in ranks: + for i, r in enumerate(ranks): + assert (ts[r] == ts_gathered[i]).all() + else: + assert ts_gathered == [ts[rank]] + if bnd in ("nccl", "gloo", "mpi"): with pytest.raises(ValueError, match=r"Argument group should be list of int or ProcessGroup"): res = idist.all_gather(t, group="abc") From 6070e18c12376064638d65d1b5ab79ef569089d5 Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Wed, 23 Aug 2023 20:41:22 +0330 Subject: [PATCH 10/41] A few improvements --- ignite/metrics/metric.py | 16 +++---- ignite/metrics/vision/object_detection_map.py | 46 +++++++++---------- 2 files changed, 29 insertions(+), 33 deletions(-) diff --git a/ignite/metrics/metric.py b/ignite/metrics/metric.py index 330311c8a78..59e0a835d61 100644 --- a/ignite/metrics/metric.py +++ b/ignite/metrics/metric.py @@ -295,17 +295,17 @@ def iteration_completed(self, engine: Engine) -> None: ) output = tuple(output[k] for k in self.required_output_keys) - if isinstance(output, Sequence) and all([_is_list_of_tensors_or_numbers(o) for o in output]): + if isinstance(output, Sequence) and all([isinstance(o, (list, tuple)) for o in output]): if not (len(output) == 2 and len(output[0]) == len(output[1])): raise ValueError( f"Output should have 2 items of the same length, " f"got {len(output)} and {len(output[0])}, {len(output[1])}" ) for o1, o2 in zip(output[0], output[1]): - # o1 and o2 are list of tensors or numbers - tensor_o1 = _to_batched_tensor(o1) - tensor_o2 = _to_batched_tensor(o2, device=tensor_o1.device) - self.update((tensor_o1, tensor_o2)) + if isinstance(o1, (torch.Tensor, Number)) and isinstance(o2, (torch.Tensor, Number)): + o1 = _to_batched_tensor(o1) + o2 = _to_batched_tensor(o2, device=o1.device) + self.update((o1, o2)) else: self.update(output) @@ -612,11 +612,7 @@ def wrapper(self: Metric, *args: Any, **kwargs: Any) -> None: return wrapper -def _is_list_of_tensors_or_numbers(x: Sequence[Union[torch.Tensor, float]]) -> bool: - return isinstance(x, Sequence) and all([isinstance(t, (torch.Tensor, Number)) for t in x]) - - -def _to_batched_tensor(x: Union[torch.Tensor, float], device: 
Optional[torch.device] = None) -> torch.Tensor: +def _to_batched_tensor(x: Union[torch.Tensor, Number], device: Optional[torch.device] = None) -> torch.Tensor: if isinstance(x, torch.Tensor): return x.unsqueeze(dim=0) return torch.tensor([x], device=device) diff --git a/ignite/metrics/vision/object_detection_map.py b/ignite/metrics/vision/object_detection_map.py index 16d5bd42680..a9e99c31d8f 100644 --- a/ignite/metrics/vision/object_detection_map.py +++ b/ignite/metrics/vision/object_detection_map.py @@ -30,7 +30,7 @@ def __init__( Args: iou_thresholds: sequence of IoU thresholds to be considered for computing mean average precision. Values should be between 0 and 1. If not given, it's determined by ``flavor`` argument. - flavor: string values so that metric computation recipe correspond to its respective flavor. For now, only + flavor: string value so that metric computation recipe correspond to its respective flavor. For now, only available option is 'COCO'. Default 'COCO'. rec_thresholds: sequence of recall thresholds to be considered for computing mean average precision. Values should be between 0 and 1. If not given, it's determined by ``flavor`` argument. @@ -202,9 +202,9 @@ def _do_matching( Returns: `(TP, FP, P, scores)` A quadrople of true positives, false positives, number of actual positives and scores. """ - labels = target["labels"].detach() - pred_labels = pred["labels"].detach() - pred_scores = pred["scores"].detach() + labels = target["labels"] + pred_labels = pred["labels"] + pred_scores = pred["scores"] categories = list(set(labels.int().tolist() + pred_labels.int().tolist())) pred_boxes = pred["bbox"] @@ -288,28 +288,28 @@ def update(self, output: Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]] are as follows. N\ :sub:`det` and N\ :sub:`gt` are number of detections and ground truths respectively. - ======= ================== ================================================= + ======== ================== ================================================= **y_pred items** - ------------------------------------------------------------------------------ - Key Value shape Description - ======= ================== ================================================= - 'bbox' (N\ :sub:`det`, 4) Bounding boxes of form (x1, y1, x2, y2) - containing top left and bottom right coordinates. - 'score' (N\ :sub:`det`,) Confidence score of detections. - 'label' (N\ :sub:`det`,) Predicted category number of detections. - ======= ================== ================================================= - - ========= ================== ================================================= + ----------------------------------------------------------------------------- + Key Value shape Description + ======== ================== ================================================= + 'bbox' (N\ :sub:`det`, 4) Bounding boxes of form (x1, y1, x2, y2) + containing top left and bottom right coordinates. + 'scores' (N\ :sub:`det`,) Confidence score of detections. + 'labels' (N\ :sub:`det`,) Predicted category number of detections. 
+ ======== ================== ================================================= + + ========= ================= ================================================= **y items** - ------------------------------------------------------------------------------ + ----------------------------------------------------------------------------- Key Value shape Description - ========= ================== ================================================= - 'bbox' (N\ :sub:`gt`, 4) Bounding boxes of form (x1, y1, x2, y2) - containing top left and bottom right coordinates. - 'label' (N\ :sub:`gt`,) Category number of ground truths. - 'iscrowd' (N\ :sub:`gt`,) Whether ground truth boxes are crowd ones or not. - It's optional with default value of ``False``. - ========= ================== ================================================= + ========= ================= ================================================= + 'bbox' (N\ :sub:`gt`, 4) Bounding boxes of form (x1, y1, x2, y2) + containing top left and bottom right coordinates. + 'labels' (N\ :sub:`gt`,) Category number of ground truths. + 'iscrowd' (N\ :sub:`gt`,) Whether ground truth boxes are crowd ones or not. + This key is optional. + ========= ================= ================================================= """ self._check_matching_input(output) tps, fps, ps, scores_dict = self._do_matching(output[0], output[1]) From deebbdec820ede36400175468ab016c779c2d27c Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Fri, 1 Sep 2023 01:09:05 +0330 Subject: [PATCH 11/41] Add an output transform and apply a review comment --- ignite/contrib/handlers/tqdm_logger.py | 2 +- ignite/metrics/mean_average_precision.py | 8 ++-- ignite/metrics/vision/object_detection_map.py | 43 ++++++++++++++++--- 3 files changed, 42 insertions(+), 11 deletions(-) diff --git a/ignite/contrib/handlers/tqdm_logger.py b/ignite/contrib/handlers/tqdm_logger.py index 6704981ace5..7f38c31fb02 100644 --- a/ignite/contrib/handlers/tqdm_logger.py +++ b/ignite/contrib/handlers/tqdm_logger.py @@ -90,7 +90,7 @@ class ProgressBar(BaseLogger): Note: - When adding attaching the progress bar to an engine, it is recommend that you replace + When attaching the progress bar to an engine, it is recommend that you replace every print operation in the engine's handlers triggered every iteration with ``pbar.log_message`` to guarantee the correct format of the stdout. diff --git a/ignite/metrics/mean_average_precision.py b/ignite/metrics/mean_average_precision.py index d9be7e49593..e017aea0677 100644 --- a/ignite/metrics/mean_average_precision.py +++ b/ignite/metrics/mean_average_precision.py @@ -17,7 +17,7 @@ def __init__( rec_thresholds: Optional[Union[Sequence[float], torch.Tensor]] = None, average: Optional[str] = "precision", class_mean: Optional[Literal["micro", "macro", "weighted", "with_other_dims"]] = "macro", - classification_is_multilabel: bool = False, + is_multilabel: bool = False, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu"), ) -> None: @@ -85,7 +85,7 @@ def __init__( Note: Please note that classes with no ground truth are not considered into the mean in detection. - classification_is_multilabel: Used in classification task and determines if the data + is_multilabel: Used in classification task and determines if the data is multilabel or not. Default False. 
output_transform: a callable that is used to transform the :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the @@ -109,7 +109,7 @@ def __init__( raise ValueError(f"Wrong `class_mean` parameter, given {class_mean}") self.class_mean = class_mean - super().__init__(output_transform=output_transform, is_multilabel=classification_is_multilabel, device=device) + super().__init__(output_transform=output_transform, is_multilabel=is_multilabel, device=device) def _setup_thresholds(self, thresholds: Union[Sequence[float], torch.Tensor], threshold_type: str) -> torch.Tensor: if isinstance(thresholds, Sequence): @@ -252,7 +252,7 @@ def __init__( average=average, class_mean=class_mean, output_transform=output_transform, - classification_is_multilabel=is_multilabel, + is_multilabel=is_multilabel, device=device, ) diff --git a/ignite/metrics/vision/object_detection_map.py b/ignite/metrics/vision/object_detection_map.py index a9e99c31d8f..ef1a44f34c7 100644 --- a/ignite/metrics/vision/object_detection_map.py +++ b/ignite/metrics/vision/object_detection_map.py @@ -11,6 +11,37 @@ from ignite.metrics.recall import _BasePrecisionRecall +def tensor_list_to_dict_list( + output: Tuple[ + Union[List[torch.Tensor], List[Dict[str, torch.Tensor]]], + Union[List[torch.Tensor], List[Dict[str, torch.Tensor]]], + ] +) -> Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, torch.Tensor]]]: + """Convert either of output's `y_pred` or `y` from list of `(N, 6)` tensors to list of str-to-tensor dictionaries, + or keep them unchanged if they're already in the deisred format. + + Input format is a `(N, 6)` or (`N, 5)` tensor which `N` is the number of predicted/target bounding boxes for the + image and the second dimension contains `(x1, y1, x2, y2, confidence, class)`/`(x1, y1, x2, y2, class[, iscrowd])`. + Output format is a str-to-tensor dictionary containing 'bbox' and `class` keys, plus `confidence` key for `y_pred` + and possibly `iscrowd` for `y`. + + Args: + output: `(y_pred,y)` tuple whose members are either list of tensors or list of dicts. + + Returns: + `(y_pred,y)` tuple whose members are list of str-to-tensor dictionaries. + """ + y_pred, y = output + if len(y_pred) > 0 and isinstance(y_pred[0], torch.Tensor): + y_pred = [{"bbox": t[:, :4], "confidence": t[:, 4], "class": t[:, 5]} for t in cast(List[torch.Tensor], y_pred)] + if len(y) > 0 and isinstance(y[0], torch.Tensor): + if y[0].size(1) == 5: + y = [{"bbox": t[:, :4], "class": t[:, 4]} for t in cast(List[torch.Tensor], y)] + else: + y = [{"bbox": t[:, :4], "class": t[:, 4], "iscrowd": t[:, 5]} for t in cast(List[torch.Tensor], y)] + return cast(Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, torch.Tensor]]], (y_pred, y)) + + class ObjectDetectionMAP(_BaseMeanAveragePrecision): _tp: Dict[int, List[torch.Tensor]] _fp: Dict[int, List[torch.Tensor]] @@ -25,7 +56,7 @@ def __init__( output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu"), ) -> None: - r"""Calculate the mean average precision for evaluating an object detector. + """Calculate the mean average precision for evaluating an object detector. Args: iou_thresholds: sequence of IoU thresholds to be considered for computing mean average precision. @@ -34,11 +65,11 @@ def __init__( available option is 'COCO'. Default 'COCO'. rec_thresholds: sequence of recall thresholds to be considered for computing mean average precision. Values should be between 0 and 1. If not given, it's determined by ``flavor`` argument. 
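Since the helper above only rearranges tensors, a small sketch of what it returns may be useful; the input values are invented, and the key names ('bbox', 'confidence', 'class') are those produced by the function as added in this commit::

    import torch
    from ignite.metrics.vision.object_detection_map import tensor_list_to_dict_list

    # One image with two detections: rows are (x1, y1, x2, y2, confidence, class).
    y_pred = [torch.tensor([[10.0, 10.0, 100.0, 100.0, 0.9, 1.0],
                            [20.0, 30.0,  80.0,  90.0, 0.4, 2.0]])]
    # Ground truth without the optional iscrowd column: rows are (x1, y1, x2, y2, class).
    y = [torch.tensor([[12.0, 11.0, 98.0, 102.0, 1.0]])]

    y_pred_dicts, y_dicts = tensor_list_to_dict_list((y_pred, y))
    # y_pred_dicts[0] -> {"bbox": (2, 4), "confidence": (2,), "class": (2,)} tensors
    # y_dicts[0]      -> {"bbox": (1, 4), "class": (1,)} tensors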
- output_transform: a callable that is used to transform the - :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the - form expected by the metric. This can be useful if, for example, you have a multi-output model and - you want to compute the metric with respect to one of the outputs. - By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``. + output_transform: a callable that is used to transform the :class:`~ignite.engine.engine.Engine`'s + ``process_function``'s output into the form expected by the metric. An already + provided example is :func:`~ignite.metrics.vision.object_detection_map.tensor_list_to_dict_list` + which accepts `y_pred` and `y` as lists of tensors and transforms them to the expected format. + Default is the identity function. device: specifies which device updates are accumulated on. Setting the metric's device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By default, CPU. From 62ca5fb256eccca27543503739a0a1eb566ae8a4 Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Fri, 1 Sep 2023 02:04:13 +0330 Subject: [PATCH 12/41] Add a test for the output_transform --- .../vision/test_object_detection_map.py | 43 +++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/tests/ignite/metrics/vision/test_object_detection_map.py b/tests/ignite/metrics/vision/test_object_detection_map.py index 2ea6675d552..d586998a1a8 100644 --- a/tests/ignite/metrics/vision/test_object_detection_map.py +++ b/tests/ignite/metrics/vision/test_object_detection_map.py @@ -18,6 +18,7 @@ import ignite.distributed as idist from ignite.engine import Engine from ignite.metrics import ObjectDetectionMAP +from ignite.metrics.vision.object_detection_map import tensor_list_to_dict_list from ignite.utils import manual_seed torch.set_printoptions(linewidth=200) @@ -961,6 +962,48 @@ def update(engine, i): assert np.allclose(res_50_95, pycoco_res_50_95) +def test_tensor_list_to_dict_list(): + y_preds = [ + [torch.randn((2, 6)), torch.randn((0, 6)), torch.randn((5, 6))], + [ + {"bbox": torch.randn((2, 4)), "confidence": torch.randn((2,)), "class": torch.randn((2,))}, + {"bbox": torch.randn((5, 4)), "confidence": torch.randn((5,)), "class": torch.randn((5,))}, + ], + ] + ys = [ + [torch.randn((2, 6)), torch.randn((0, 6)), torch.randn((5, 6))], + [ + {"bbox": torch.randn((2, 4)), "class": torch.randn((2,))}, + {"bbox": torch.randn((5, 4)), "class": torch.randn((5,))}, + ], + ] + for y_pred in y_preds: + for y in ys: + y_pred_new, y_new = tensor_list_to_dict_list((y_pred, y)) + if isinstance(y_pred[0], dict): + assert y_pred_new is y_pred + else: + assert all( + [ + (ypn["bbox"] == yp[:, :4]).all() + & (ypn["confidence"] == yp[:, 4]).all() + & (ypn["class"] == yp[:, 5]).all() + for yp, ypn in zip(y_pred, y_pred_new) + ] + ) + if isinstance(y[0], dict): + assert y_new is y + else: + assert all( + [ + (ytn["bbox"] == yt[:, :4]).all() + & (ytn["class"] == yt[:, 4]).all() + & (ytn["iscrowd"] == yt[:, 5]).all() + for yt, ytn in zip(y, y_new) + ] + ) + + def test_distrib_update_compute(distributed, sample): rank_samples_cnt = ceil(sample.length / idist.get_world_size()) rank = idist.get_rank() From 418fcf4b2b70ade2cd60c4b1ca356d40388ae5d6 Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Fri, 1 Sep 2023 14:32:32 +0330 Subject: [PATCH 13/41] Remove 'flavor' because all DeciAI, Ultralytics, Detectron and pycocotools use the 'max-precision' approach --- ignite/metrics/vision/object_detection_map.py | 11 
++--------- .../metrics/vision/test_object_detection_map.py | 3 --- 2 files changed, 2 insertions(+), 12 deletions(-) diff --git a/ignite/metrics/vision/object_detection_map.py b/ignite/metrics/vision/object_detection_map.py index ef1a44f34c7..01b9665e7d1 100644 --- a/ignite/metrics/vision/object_detection_map.py +++ b/ignite/metrics/vision/object_detection_map.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import Callable, cast, Dict, List, Literal, Optional, Sequence, Tuple, Union +from typing import Callable, cast, Dict, List, Optional, Sequence, Tuple, Union import torch @@ -51,7 +51,6 @@ class ObjectDetectionMAP(_BaseMeanAveragePrecision): def __init__( self, iou_thresholds: Optional[Union[Sequence[float], torch.Tensor]] = None, - flavor: Optional["Literal['COCO']"] = "COCO", rec_thresholds: Optional[Union[Sequence[float], torch.Tensor]] = None, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu"), @@ -61,8 +60,6 @@ def __init__( Args: iou_thresholds: sequence of IoU thresholds to be considered for computing mean average precision. Values should be between 0 and 1. If not given, it's determined by ``flavor`` argument. - flavor: string value so that metric computation recipe correspond to its respective flavor. For now, only - available option is 'COCO'. Default 'COCO'. rec_thresholds: sequence of recall thresholds to be considered for computing mean average precision. Values should be between 0 and 1. If not given, it's determined by ``flavor`` argument. output_transform: a callable that is used to transform the :class:`~ignite.engine.engine.Engine`'s @@ -88,10 +85,6 @@ def box_iou(pred_boxes: torch.Tensor, gt_boxes: torch.Tensor, iscrowd: torch.Boo except ImportError: raise ModuleNotFoundError("This metric requires torchvision to be installed.") - if flavor != "COCO": - raise ValueError(f"Currently, the only available flavor for ObjectDetectionMAP is 'COCO', given {flavor}") - self.flavor = flavor - if iou_thresholds is None: iou_thresholds = torch.linspace(0.5, 0.95, 10, dtype=torch.double) @@ -102,7 +95,7 @@ def box_iou(pred_boxes: torch.Tensor, gt_boxes: torch.Tensor, iscrowd: torch.Boo super().__init__( rec_thresholds=rec_thresholds, - average="max-precision" if flavor == "COCO" else "precision", + average="max-precision", class_mean="with_other_dims", output_transform=output_transform, device=device, diff --git a/tests/ignite/metrics/vision/test_object_detection_map.py b/tests/ignite/metrics/vision/test_object_detection_map.py index d586998a1a8..8c5fa0ca5e7 100644 --- a/tests/ignite/metrics/vision/test_object_detection_map.py +++ b/tests/ignite/metrics/vision/test_object_detection_map.py @@ -625,9 +625,6 @@ def sample(request) -> Sample: def test_wrong_input(): - with pytest.raises(ValueError, match="Currently, the only available flavor for ObjectDetectionMAP is 'COCO'"): - ObjectDetectionMAP(flavor="wrong flavor") - m = ObjectDetectionMAP() with pytest.raises(ValueError, match="y_pred dict in update's input should have 'bbox', 'scores'"): m.update(({"bbox": None, "scores": None}, {"bbox": None, "labels": None})) From 79fa1e29fae7c4cc1423853b98fb14dfeea21596 Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Mon, 11 Sep 2023 02:23:45 +0330 Subject: [PATCH 14/41] Revert Metric change and a few bug fix --- ignite/metrics/metric.py | 14 ++- ignite/metrics/vision/object_detection_map.py | 69 +++++++------- .../vision/test_object_detection_map.py | 93 +++++++++++-------- 3 files changed, 101 insertions(+), 75 
deletions(-) diff --git a/ignite/metrics/metric.py b/ignite/metrics/metric.py index f6c2a8b23c7..eddec7f3a06 100644 --- a/ignite/metrics/metric.py +++ b/ignite/metrics/metric.py @@ -389,17 +389,17 @@ def iteration_completed(self, engine: Engine) -> None: ) output = tuple(output[k] for k in self.required_output_keys) - if isinstance(output, Sequence) and all([isinstance(o, (list, tuple)) for o in output]): + if isinstance(output, Sequence) and all([_is_list_of_tensors_or_numbers(o) for o in output]): if not (len(output) == 2 and len(output[0]) == len(output[1])): raise ValueError( f"Output should have 2 items of the same length, " f"got {len(output)} and {len(output[0])}, {len(output[1])}" ) for o1, o2 in zip(output[0], output[1]): - if isinstance(o1, (torch.Tensor, Number)) and isinstance(o2, (torch.Tensor, Number)): - o1 = _to_batched_tensor(o1) - o2 = _to_batched_tensor(o2, device=o1.device) - self.update((o1, o2)) + # o1 and o2 are list of tensors or numbers + tensor_o1 = _to_batched_tensor(o1) + tensor_o2 = _to_batched_tensor(o2, device=tensor_o1.device) + self.update((tensor_o1, tensor_o2)) else: self.update(output) @@ -757,6 +757,10 @@ def wrapper(self: Metric, *args: Any, **kwargs: Any) -> None: return wrapper +def _is_list_of_tensors_or_numbers(x: Sequence[Union[torch.Tensor, float]]) -> bool: + return isinstance(x, Sequence) and all([isinstance(t, (torch.Tensor, Number)) for t in x]) + + def _to_batched_tensor(x: Union[torch.Tensor, Number], device: Optional[torch.device] = None) -> torch.Tensor: if isinstance(x, torch.Tensor): return x.unsqueeze(dim=0) diff --git a/ignite/metrics/vision/object_detection_map.py b/ignite/metrics/vision/object_detection_map.py index 01b9665e7d1..51383e7b6bc 100644 --- a/ignite/metrics/vision/object_detection_map.py +++ b/ignite/metrics/vision/object_detection_map.py @@ -59,9 +59,9 @@ def __init__( Args: iou_thresholds: sequence of IoU thresholds to be considered for computing mean average precision. - Values should be between 0 and 1. If not given, it's determined by ``flavor`` argument. + Values should be between 0 and 1. If not given, COCO's default (.5, .55, ..., .95) would be used. rec_thresholds: sequence of recall thresholds to be considered for computing mean average precision. - Values should be between 0 and 1. If not given, it's determined by ``flavor`` argument. + Values should be between 0 and 1. If not given, COCO's default (.0, .01, .02, ..., 1.) would be used. output_transform: a callable that is used to transform the :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the form expected by the metric. 
An already provided example is :func:`~ignite.metrics.vision.object_detection_map.tensor_list_to_dict_list` @@ -114,19 +114,27 @@ def reset(self) -> None: self._P = defaultdict(lambda: 0) self._num_classes: int = 0 - def _check_matching_input(self, output: Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]) -> None: + def _check_matching_input( + self, output: Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, torch.Tensor]]] + ) -> None: + y_pred, y = output + if len(y_pred) != len(y): + raise ValueError(f"y_pred and y should have the same number of samples, given {len(y_pred)} and {len(y)}.") + if len(y_pred) == 0: + raise ValueError("y_pred and y should contain at least one sample.") + y_pred_keys = {"bbox", "scores", "labels"} - if (output[0].keys() & y_pred_keys) != y_pred_keys: + if (y_pred[0].keys() & y_pred_keys) != y_pred_keys: raise ValueError( - "y_pred dict in update's input should have 'bbox', 'scores'" - f" and 'labels' keys. It has {output[0].keys()}" + "y_pred sample dictionaries should have 'bbox', 'scores'" + f" and 'labels' keys, given keys: {y_pred[0].keys()}" ) y_keys = {"bbox", "labels"} - if (output[1].keys() & y_keys) != y_keys: + if (y[0].keys() & y_keys) != y_keys: raise ValueError( - "y dict in update's input should have 'bbox', 'labels'" - f" and optionaly 'iscrowd' keys. It has {output[1].keys()}" + "y sample dictionaries should have 'bbox', 'labels'" + f" and optionally 'iscrowd' keys, given keys: {y[0].keys()}" ) def _compute_recall_and_precision( @@ -172,7 +180,7 @@ def _compute_recall_and_precision( def _compute_average_precision(self, recall: torch.Tensor, precision: torch.Tensor) -> torch.Tensor: """Measuring average precision. This method is overriden since :math:`1/#recall_thresholds` is used instead of :math:`r_k - r_{k-1}` - as the recall differential in COCO flavor. + as the recall differential in COCO's reference implementation i.e., pycocotools. Args: recall: n-dimensional tensor whose last dimension is the dimension of the samples. Should be ordered in @@ -182,9 +190,6 @@ def _compute_average_precision(self, recall: torch.Tensor, precision: torch.Tens Returns: average_precision: (n-1)-dimensional tensor containing the average precision for mean dimensions. """ - if self.flavor != "COCO": - return super()._compute_average_precision(recall, precision) - precision_integrand = ( precision.flip(-1).cummax(dim=-1).values.flip(-1) if self.average == "max-precision" else precision ) @@ -304,13 +309,14 @@ def _do_matching( return tp, fp, P, scores @reinit__is_reduced - def update(self, output: Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]) -> None: + def update(self, output: Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, torch.Tensor]]]) -> None: r"""Metric update function using prediction and target. Args: - output: a binary tuple of str-to-tensor dictionaries, (y_pred, y), which their items - are as follows. N\ :sub:`det` and N\ :sub:`gt` are number of detections and - ground truths respectively. + output: a tuple, (y_pred, y), of two same-length lists, each one containing + str-to-tensor dictionaries whose items is as follows. N\ :sub:`det` and + N\ :sub:`gt` are number of detections and ground truths for a sample + respectively. 
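Before the key tables that follow, a minimal sketch of an ``update`` call in the new list-of-samples format (all values invented for illustration)::

    import torch
    from ignite.metrics import ObjectDetectionMAP

    metric = ObjectDetectionMAP()
    y_pred = [{
        "bbox": torch.tensor([[10.0, 10.0, 100.0, 100.0]]),  # (N_det, 4) as (x1, y1, x2, y2)
        "scores": torch.tensor([0.8]),                        # (N_det,)
        "labels": torch.tensor([1]),                          # (N_det,)
    }]
    y = [{
        "bbox": torch.tensor([[12.0, 11.0, 98.0, 102.0]]),    # (N_gt, 4)
        "labels": torch.tensor([1]),                          # (N_gt,)
        # "iscrowd" is optional and treated as all-False when absent
    }]
    metric.update((y_pred, y))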
======== ================== ================================================= **y_pred items** @@ -336,22 +342,23 @@ def update(self, output: Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]] ========= ================= ================================================= """ self._check_matching_input(output) - tps, fps, ps, scores_dict = self._do_matching(output[0], output[1]) - for cls in tps: - self._tp[cls].append(tps[cls].to(device=self._device, dtype=torch.uint8)) - self._fp[cls].append(fps[cls].to(device=self._device, dtype=torch.uint8)) - self._scores[cls].append(scores_dict[cls].to(self._device)) - for cls in ps: - self._P[cls] += ps[cls] - classes = tps.keys() | ps.keys() - if classes: - self._num_classes = max(max(classes) + 1, self._num_classes) - - def compute(self) -> Union[torch.Tensor, float]: + for y_pred, y in zip(*output): + tps, fps, ps, scores_dict = self._do_matching(y_pred, y) + for cls in tps: + self._tp[cls].append(tps[cls].to(device=self._device, dtype=torch.uint8)) + self._fp[cls].append(fps[cls].to(device=self._device, dtype=torch.uint8)) + self._scores[cls].append(scores_dict[cls].to(self._device)) + for cls in ps: + self._P[cls] += ps[cls] + classes = tps.keys() | ps.keys() + if classes: + self._num_classes = max(max(classes) + 1, self._num_classes) + + def compute(self) -> float: """ Compute method of the metric """ - if sum(self._P.values()) < 1 and self.flavor == "COCO": + if sum(self._P.values()) < 1: return -1 num_classes = int(idist.all_reduce(self._num_classes or 0, "MAX")) @@ -431,4 +438,4 @@ def compute(self) -> Union[torch.Tensor, float]: average_precision_for_cls_across_other_dims = self._compute_average_precision(recall, precision) average_precisions[cls] = average_precision_for_cls_across_other_dims - return average_precisions[average_precisions > -1].mean() + return average_precisions[average_precisions > -1].mean().item() diff --git a/tests/ignite/metrics/vision/test_object_detection_map.py b/tests/ignite/metrics/vision/test_object_detection_map.py index 8c5fa0ca5e7..2c60458b355 100644 --- a/tests/ignite/metrics/vision/test_object_detection_map.py +++ b/tests/ignite/metrics/vision/test_object_detection_map.py @@ -1,7 +1,7 @@ import sys from collections import namedtuple from math import ceil -from typing import Dict, List, Tuple +from typing import Any, Dict, List, Tuple from unittest.mock import patch import numpy as np @@ -626,10 +626,15 @@ def sample(request) -> Sample: def test_wrong_input(): m = ObjectDetectionMAP() - with pytest.raises(ValueError, match="y_pred dict in update's input should have 'bbox', 'scores'"): - m.update(({"bbox": None, "scores": None}, {"bbox": None, "labels": None})) - with pytest.raises(ValueError, match="y dict in update's input should have 'bbox', 'labels'"): - m.update(({"bbox": None, "scores": None, "labels": None}, {"labels": None})) + + with pytest.raises(ValueError, match="y_pred and y should have the same number of samples"): + m.update(([{"bbox": None, "scores": None}], [])) + with pytest.raises(ValueError, match="y_pred and y should contain at least one sample."): + m.update(([], [])) + with pytest.raises(ValueError, match="y_pred sample dictionaries should have 'bbox', 'scores'"): + m.update(([{"bbox": None, "scores": None}], [{"bbox": None, "labels": None}])) + with pytest.raises(ValueError, match="y sample dictionaries should have 'bbox', 'labels'"): + m.update(([{"bbox": None, "scores": None, "labels": None}], [{"labels": None}])) def test_empty_data(): @@ -640,8 +645,8 @@ def test_empty_data(): 
metric = ObjectDetectionMAP() metric.update( ( - {"bbox": torch.zeros((0, 4)), "scores": torch.zeros((0,)), "labels": torch.zeros((0))}, - {"bbox": torch.zeros((0, 4)), "iscrowd": torch.zeros((0,)), "labels": torch.zeros((0))}, + [{"bbox": torch.zeros((0, 4)), "scores": torch.zeros((0,)), "labels": torch.zeros((0))}], + [{"bbox": torch.zeros((0, 4)), "iscrowd": torch.zeros((0,)), "labels": torch.zeros((0))}], ) ) assert len(metric._tp) == 0 @@ -653,12 +658,14 @@ def test_empty_data(): metric = ObjectDetectionMAP() metric.update( ( - {"bbox": torch.zeros((0, 4)), "scores": torch.zeros((0,)), "labels": torch.zeros((0))}, - { - "bbox": torch.tensor([[0.0, 0.0, 100.0, 100.0]]), - "iscrowd": torch.zeros((1,)), - "labels": torch.ones((1,)), - }, + [{"bbox": torch.zeros((0, 4)), "scores": torch.zeros((0,)), "labels": torch.zeros((0))}], + [ + { + "bbox": torch.tensor([[0.0, 0.0, 100.0, 100.0]]), + "iscrowd": torch.zeros((1,)), + "labels": torch.ones((1,)), + } + ], ) ) assert len(metric._tp) == 0 @@ -674,7 +681,7 @@ def test_empty_data(): "labels": torch.tensor([5]), } target = {"bbox": torch.zeros((0, 4)), "iscrowd": torch.zeros((0,)), "labels": torch.zeros((0))} - metric.update((pred, target)) + metric.update(([pred], [target])) assert (5 in metric._tp) and metric._tp[5][0].shape[1] == 1 assert (5 in metric._fp) and metric._fp[5][0].shape[1] == 1 assert len(metric._P) == 0 @@ -718,7 +725,7 @@ def test_iou_thresholding(): "labels": torch.tensor([1]), } gt = {"bbox": torch.tensor([[0.0, 0.0, 50.0, 100.0]]), "iscrowd": torch.zeros((1,)), "labels": torch.tensor([1])} - metric.update((pred, gt)) + metric.update(([pred], [gt])) assert (metric._tp[1][0] == torch.tensor([[True], [True], [True], [False]])).all() pred = { @@ -727,7 +734,7 @@ def test_iou_thresholding(): "labels": torch.tensor([1]), } gt = {"bbox": torch.tensor([[100.0, 0.0, 200.0, 100.0]]), "iscrowd": torch.zeros((1,)), "labels": torch.tensor([1])} - metric.update((pred, gt)) + metric.update(([pred], [gt])) assert (metric._tp[1][1] == torch.tensor([[True], [False], [False], [False]])).all() @@ -747,7 +754,7 @@ def test__do_matching_output(sample): assert tps[cls].dtype in (torch.bool, torch.uint8) assert tps[cls].size(-1) == fps[cls].size(-1) == scores[cls].size(0) - metric.update((pred, target)) + metric.update(([pred], [target])) for metric_tp_or_fp, new_tp_or_fp in [(metric._tp, tps), (metric._fp, fps)]: try: @@ -759,10 +766,12 @@ def test__do_matching_output(sample): class Dummy_mAP(ObjectDetectionMAP): - def _do_matching(self, pred: Tuple, target: Tuple): - return *pred, *target + def _do_matching(self, tup1: Tuple, tup2: Tuple): + tp, fp = tup1 + p, score = tup2 + return tp, fp, p, score - def _check_matching_input(self, output: Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]): + def _check_matching_input(self, output: Any): pass @@ -770,12 +779,14 @@ def test_update(): metric = Dummy_mAP() assert len(metric._tp) == len(metric._fp) == len(metric._scores) == len(metric._P) == metric._num_classes == 0 - metric.update((({1: torch.tensor([True])}, {1: torch.tensor([False])}), ({1: 1, 2: 1}, {1: torch.tensor([0.8])}))) + metric.update( + ([({1: torch.tensor([True])}, {1: torch.tensor([False])})], [({1: 1, 2: 1}, {1: torch.tensor([0.8])})]) + ) assert len(metric._tp[1]) == len(metric._fp[1]) == len(metric._scores[1]) == 1 assert len(metric._P) == 2 and metric._P[2] == 1 assert metric._num_classes == 3 - metric.update((({}, {}), ({2: 2}, {}))) + metric.update(([({}, {})], [({2: 2}, {})])) assert metric._P[2] == 3 @@ -806,7 +817,7 
@@ def test_matching(): "iscrowd": torch.zeros((1,)), "labels": torch.tensor([1]), } - metric.update((rule_1_pred, rule_1_gt)) + metric.update(([rule_1_pred], [rule_1_gt])) assert (metric._tp[1][0] == torch.tensor([[False, True]])).all() assert (metric._fp[1][0] == torch.tensor([[True, False]])).all() @@ -820,7 +831,7 @@ def test_matching(): "iscrowd": torch.zeros((1,)), "labels": torch.tensor([1]), } - metric.update((rule_1_and_2_pred, rule_1_and_2_gt)) + metric.update(([rule_1_and_2_pred], [rule_1_and_2_gt])) assert (metric._tp[1][1] == torch.tensor([[True, False]])).all() assert (metric._fp[1][1] == torch.tensor([[False, True]])).all() @@ -834,7 +845,7 @@ def test_matching(): "iscrowd": torch.tensor([1]), "labels": torch.tensor([1]), } - metric.update((rule_2_pred, rule_2_gt)) + metric.update(([rule_2_pred], [rule_2_gt])) assert (metric._tp[1][2] == torch.tensor([[False, False]])).all() assert (metric._fp[1][2] == torch.tensor([[False, False]])).all() @@ -848,7 +859,7 @@ def test_matching(): "iscrowd": torch.zeros((2,)), "labels": torch.tensor([1, 1]), } - metric.update((rule_2_and_3_pred, rule_2_and_3_gt)) + metric.update(([rule_2_and_3_pred], [rule_2_and_3_gt])) assert (metric._tp[1][3] == torch.tensor([[True, False]])).all() assert (metric._fp[1][3] == torch.tensor([[False, True]])).all() @@ -916,14 +927,13 @@ def test_compute(sample): metric_75 = ObjectDetectionMAP(iou_thresholds=[0.75], device=device) metric_50_95 = ObjectDetectionMAP(device=device) - for prediction, target in zip(*sample.data): - metric_50.update((prediction, target)) - metric_75.update((prediction, target)) - metric_50_95.update((prediction, target)) + metric_50.update(sample.data) + metric_75.update(sample.data) + metric_50_95.update(sample.data) - res_50 = metric_50.compute().item() - res_75 = metric_75.compute().item() - res_50_95 = metric_50_95.compute().item() + res_50 = metric_50.compute() + res_75 = metric_75.compute() + res_50_95 = metric_50_95.compute() pycoco_res_50_95, pycoco_res_50, pycoco_res_75 = sample.mAP @@ -941,8 +951,11 @@ def test_compute(sample): def test_integration(sample): + bs = 3 + def update(engine, i): - return sample.data[0][i], sample.data[1][i] + b = slice(i * bs, (i + 1) * bs) + return sample.data[0][b], sample.data[1][b] engine = Engine(update) @@ -951,7 +964,8 @@ def update(engine, i): metric_50_95 = ObjectDetectionMAP(device=metric_device) metric_50_95.attach(engine, name="mAP[50-95]") - engine.run(range(sample.length), max_epochs=1) + n_iter = ceil(sample.length / bs) + engine.run(range(n_iter), max_epochs=1) res_50_95 = engine.state.metrics["mAP[50-95]"] pycoco_res_50_95 = sample.mAP[0] @@ -1012,10 +1026,11 @@ def test_distrib_update_compute(distributed, sample): metric_75 = ObjectDetectionMAP(iou_thresholds=[0.75], device=metric_device) metric_50_95 = ObjectDetectionMAP(device=metric_device) - for prediction, target in zip(sample.data[0][rank_samples_range], sample.data[1][rank_samples_range]): - metric_50.update((prediction, target)) - metric_75.update((prediction, target)) - metric_50_95.update((prediction, target)) + y_pred_rank = sample.data[0][rank_samples_range] + y_rank = sample.data[1][rank_samples_range] + metric_50.update((y_pred_rank, y_rank)) + metric_75.update((y_pred_rank, y_rank)) + metric_50_95.update((y_pred_rank, y_rank)) res_50 = metric_50.compute() res_75 = metric_75.compute() From 26c96b8800e27473e8378f5fddaab9e4c3c2cfc9 Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Sat, 16 Sep 2023 01:30:46 +0330 Subject: [PATCH 15/41] A tiny improvement in local 
variable names --- ignite/metrics/vision/object_detection_map.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ignite/metrics/vision/object_detection_map.py b/ignite/metrics/vision/object_detection_map.py index 51383e7b6bc..324fc445f79 100644 --- a/ignite/metrics/vision/object_detection_map.py +++ b/ignite/metrics/vision/object_detection_map.py @@ -378,13 +378,13 @@ def compute(self) -> float: ) if num_preds_per_class_across_ranks.sum() == 0: return 0.0 - a_nonempty_rank, its_class_with_pred = list(zip(*torch.where(num_preds_per_class_across_ranks != 0))).pop(0) + a_nonempty_rank, a_nonempty_class = list(zip(*torch.where(num_preds_per_class_across_ranks != 0))).pop(0) a_nonempty_rank = a_nonempty_rank.item() - its_class_with_pred = its_class_with_pred.item() + a_nonempty_class = a_nonempty_class.item() mean_dimensions_shape = cast( torch.Tensor, idist.broadcast( - torch.tensor(self._tp[its_class_with_pred][-1].shape[:-1], device=self._device) + torch.tensor(self._tp[a_nonempty_class][-1].shape[:-1], device=self._device) if idist.get_rank() == a_nonempty_rank else None, a_nonempty_rank, From a361ca8b080d00761a24c4f19309a9fcf586121b Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Mon, 4 Dec 2023 18:01:09 +0330 Subject: [PATCH 16/41] Add max_dep and area_range --- ignite/metrics/mean_average_precision.py | 4 +- ignite/metrics/vision/object_detection_map.py | 115 ++++++++----- .../vision/test_object_detection_map.py | 151 ++++++++---------- 3 files changed, 144 insertions(+), 126 deletions(-) diff --git a/ignite/metrics/mean_average_precision.py b/ignite/metrics/mean_average_precision.py index e017aea0677..e98666ba8d5 100644 --- a/ignite/metrics/mean_average_precision.py +++ b/ignite/metrics/mean_average_precision.py @@ -15,7 +15,7 @@ class _BaseMeanAveragePrecision(_BasePrecisionRecall): def __init__( self, rec_thresholds: Optional[Union[Sequence[float], torch.Tensor]] = None, - average: Optional[str] = "precision", + average: Optional[Literal["precision", "max-precision"]] = "precision", class_mean: Optional[Literal["micro", "macro", "weighted", "with_other_dims"]] = "macro", is_multilabel: bool = False, output_transform: Callable = lambda x: x, @@ -32,7 +32,7 @@ def __init__( rec_thresholds: recall thresholds (sensivity levels) to be considered for computing Mean Average Precision. It could be a 1-dim tensor or a sequence of floats. Its values should be between 0 and 1 and don't need to be sorted. If missing, thresholds are considered automatically using the data. - average: one of values precision or max-precision. In the former case, the precision at a + average: one of values "precision" or "max-precision". In the former case, the precision at a recall threshold is used for that threshold: .. 
math:: diff --git a/ignite/metrics/vision/object_detection_map.py b/ignite/metrics/vision/object_detection_map.py index 324fc445f79..44fdb1a38db 100644 --- a/ignite/metrics/vision/object_detection_map.py +++ b/ignite/metrics/vision/object_detection_map.py @@ -2,6 +2,7 @@ from typing import Callable, cast, Dict, List, Optional, Sequence, Tuple, Union import torch +from typing_extensions import Literal import ignite.distributed as idist from ignite.distributed.utils import all_gather_tensors_with_shapes @@ -52,6 +53,8 @@ def __init__( self, iou_thresholds: Optional[Union[Sequence[float], torch.Tensor]] = None, rec_thresholds: Optional[Union[Sequence[float], torch.Tensor]] = None, + max_detections_per_image: Optional[int] = 100, + area_range: Optional[Literal["small", "medium", "large", "all"]] = "all", output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu"), ) -> None: @@ -93,6 +96,9 @@ def box_iou(pred_boxes: torch.Tensor, gt_boxes: torch.Tensor, iscrowd: torch.Boo if rec_thresholds is None: rec_thresholds = torch.linspace(0, 1, 101, device=device, dtype=torch.double) + self.area_range = area_range + self.max_detections_per_image = max_detections_per_image + super().__init__( rec_thresholds=rec_thresholds, average="max-precision", @@ -114,6 +120,24 @@ def reset(self) -> None: self._P = defaultdict(lambda: 0) self._num_classes: int = 0 + def _match_area_range(self, bboxes: torch.Tensor) -> torch.Tensor: + from torchvision.ops.boxes import box_area + + areas = box_area(bboxes) + if self.area_range == "all": + min_area = 0 + max_area = 1e10 + elif self.area_range == "small": + min_area = 0 + max_area = 1024 + elif self.area_range == "medium": + min_area = 1024 + max_area = 9216 + elif self.area_range == "large": + min_area = 9216 + max_area = 1e10 + return torch.logical_and(areas >= min_area, areas <= max_area) + def _check_matching_input( self, output: Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, torch.Tensor]]] ) -> None: @@ -232,14 +256,20 @@ def _do_matching( `(TP, FP, P, scores)` A quadrople of true positives, false positives, number of actual positives and scores. 
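The area-range buckets introduced above mirror pycocotools' small/medium/large split; a standalone sketch of the same check, with the 32**2 and 96**2 pixel thresholds taken from the code above and made-up boxes::

    import torch
    from torchvision.ops.boxes import box_area

    boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0],      # area   100 -> small
                          [0.0, 0.0, 50.0, 50.0],      # area  2500 -> medium
                          [0.0, 0.0, 100.0, 100.0]])   # area 10000 -> large
    areas = box_area(boxes)
    in_medium_range = (areas >= 32 ** 2) & (areas <= 96 ** 2)
    # tensor([False,  True, False])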
""" labels = target["labels"] - pred_labels = pred["labels"] - pred_scores = pred["scores"] - categories = list(set(labels.int().tolist() + pred_labels.int().tolist())) - - pred_boxes = pred["bbox"] gt_boxes = target["bbox"] + gt_is_crowd = ( + target["iscrowd"].bool() if "iscrowd" in target else torch.zeros_like(target["labels"], dtype=torch.bool) + ) + gt_ignore = ~self._match_area_range(gt_boxes) | gt_is_crowd + + best_detections_index = torch.argsort(pred["scores"], stable=True, descending=True)[ + : self.max_detections_per_image + ] + pred_scores = pred["scores"][best_detections_index] + pred_labels = pred["labels"][best_detections_index] + pred_boxes = pred["bbox"][best_detections_index] - is_crowd = target["iscrowd"] if "iscrowd" in target else torch.zeros_like(target["labels"], dtype=torch.bool) + categories = list(set(labels.int().tolist() + pred_labels.int().tolist())) tp: Dict[int, torch.Tensor] = {} fp: Dict[int, torch.Tensor] = {} @@ -247,61 +277,62 @@ def _do_matching( scores: Dict[int, torch.Tensor] = {} for category in categories: - class_index_gt = labels == category - num_category_gt = class_index_gt.sum() - category_is_crowd = is_crowd[class_index_gt] - if num_category_gt: - P[category] = num_category_gt - category_is_crowd.sum() - - class_index_dt = pred_labels == category - if not class_index_dt.any(): + category_index_gt = labels == category + num_category_gt = category_index_gt.sum() + category_is_crowd = gt_is_crowd[category_index_gt] + category_gt_ignore = gt_ignore[category_index_gt] + if num_category_gt: # what if P[c] becomes 0 ? + P[category] = num_category_gt - category_gt_ignore.sum() + + category_index_dt = pred_labels == category + if not category_index_dt.any(): continue - scores[category] = pred_scores[class_index_dt] - + scores[category] = pred_scores[category_index_dt] + pred_match_area_range = self._match_area_range(pred_boxes[category_index_dt]) + num_category_dt = category_index_dt.sum().item() category_tp = torch.zeros( - (len(self.iou_thresholds), class_index_dt.sum().item()), dtype=torch.uint8, device=self._device + (len(self.iou_thresholds), num_category_dt), dtype=torch.uint8, device=self._device ) category_fp = torch.zeros( - (len(self.iou_thresholds), class_index_dt.sum().item()), dtype=torch.uint8, device=self._device + (len(self.iou_thresholds), num_category_dt), dtype=torch.uint8, device=self._device ) if num_category_gt: - class_iou = self.box_iou( - pred_boxes[class_index_dt], - gt_boxes[class_index_gt], - cast(torch.BoolTensor, category_is_crowd.bool()), + category_iou = self.box_iou( + pred_boxes[category_index_dt], + gt_boxes[category_index_gt], + cast(torch.BoolTensor, category_is_crowd), ) - class_maximum_iou = class_iou.max() - category_pred_idx_sorted_by_decreasing_score = torch.argsort( - pred_scores[class_index_dt], stable=True, descending=True - ).tolist() + category_maximum_iou = category_iou.max() for thres_idx, iou_thres in enumerate(self.iou_thresholds): - if iou_thres <= class_maximum_iou: + if iou_thres <= category_maximum_iou: matched_gt_indices = set() - for pred_idx in category_pred_idx_sorted_by_decreasing_score: + for pred_idx in range(num_category_dt): match_iou, match_idx = min(iou_thres, 1 - 1e-10), -1 for gt_idx in range(num_category_gt): - if (class_iou[pred_idx][gt_idx] < iou_thres) or ( - gt_idx in matched_gt_indices and torch.logical_not(category_is_crowd[gt_idx]) + if (category_iou[pred_idx, gt_idx] < iou_thres) or ( + gt_idx in matched_gt_indices and ~category_is_crowd[gt_idx] ): continue - if match_idx == 
-1 or ( - class_iou[pred_idx][gt_idx] >= match_iou - and torch.logical_or( - torch.logical_not(category_is_crowd[gt_idx]), category_is_crowd[match_idx] + if ( + match_idx == -1 + or (category_gt_ignore[match_idx] & ~category_gt_ignore[gt_idx]) + or ( + (category_gt_ignore[match_idx] | ~category_gt_ignore[gt_idx]) + and category_iou[pred_idx][gt_idx] >= match_iou ) ): - match_iou = class_iou[pred_idx][gt_idx] + match_iou = category_iou[pred_idx][gt_idx] match_idx = gt_idx if match_idx != -1: matched_gt_indices.add(match_idx) - category_tp[thres_idx][pred_idx] = torch.logical_not(category_is_crowd[match_idx]) + category_tp[thres_idx][pred_idx] = ~category_gt_ignore[match_idx] else: - category_fp[thres_idx][pred_idx] = 1 + category_fp[thres_idx][pred_idx] = pred_match_area_range[pred_idx] else: - category_fp[thres_idx] = 1 + category_fp[thres_idx] = pred_match_area_range else: - category_fp[:, :] = 1 + category_fp[:, :] = pred_match_area_range tp[category] = category_tp fp[category] = category_fp @@ -358,9 +389,6 @@ def compute(self) -> float: """ Compute method of the metric """ - if sum(self._P.values()) < 1: - return -1 - num_classes = int(idist.all_reduce(self._num_classes or 0, "MAX")) if num_classes < 1: return 0.0 @@ -369,6 +397,9 @@ def compute(self) -> float: torch.Tensor, idist.all_reduce(torch.tensor(list(map(self._P.__getitem__, range(num_classes))), device=self._device)), ) + if P.sum() < 1: + return -1 + num_preds = torch.tensor( [sum([tp.shape[-1] for tp in self._tp[cls]]) if self._tp[cls] else 0 for cls in range(num_classes)], device=self._device, diff --git a/tests/ignite/metrics/vision/test_object_detection_map.py b/tests/ignite/metrics/vision/test_object_detection_map.py index 2c60458b355..9abd9214b12 100644 --- a/tests/ignite/metrics/vision/test_object_detection_map.py +++ b/tests/ignite/metrics/vision/test_object_detection_map.py @@ -531,18 +531,16 @@ def create_coco_api( return coco_dt, coco_gt -def pycoco_mAP( - predictions: List[Dict[str, torch.Tensor]], targets: List[Dict[str, torch.Tensor]] -) -> Tuple[float, float, float]: +def pycoco_mAP(predictions: List[Dict[str, torch.Tensor]], targets: List[Dict[str, torch.Tensor]]) -> np.array: """ - Returned values belong to IOU thresholds of [0.5, 0.55, ..., 0.95], [0.5] and [0.75] respectively. + Returned values are AP@.5...95, AP@.5, AP@.75, AP-S, AP-M, AP-L """ coco_dt, coco_gt = create_coco_api(predictions, targets) eval = COCOeval(coco_gt, coco_dt, iouType="bbox") eval.evaluate() eval.accumulate() eval.summarize() - return eval.stats[0], eval.stats[1], eval.stats[2] + return eval.stats[:6] Sample = namedtuple("Sample", ["data", "mAP", "length"]) @@ -653,6 +651,19 @@ def test_empty_data(): assert len(metric._fp) == 0 assert len(metric._P) == 0 assert metric._num_classes == 0 + assert metric.compute() == 0 + metric.update( + ( + [ + { + "bbox": torch.tensor([[0.0, 0.0, 100.0, 100.0]]), + "scores": torch.ones((1,)), + "labels": torch.ones((1,)), + } + ], + [{"bbox": torch.zeros((0, 4)), "iscrowd": torch.zeros((0,)), "labels": torch.zeros((0))}], + ) + ) assert metric.compute() == -1 metric = ObjectDetectionMAP() @@ -801,68 +812,59 @@ def test_matching(): 3. Among many plausible ground truth boxes, a prediction is matched with the one which has the highest mutual IOU. If two ground truth boxes have the same IOU with a prediction, the later one is matched. - 4. A non-crowd ground truth has priority over a crowd ground truth in getting + 4. 
A prediction is matched with an out-of-area-range ground truth box only if there's no + plausible within-area-range ground truth box. In that case the prediction would get ignored. + 5. An unmatched prediction would get ignored if it's out of area range. + 6. A non-crowd ground truth has priority over a crowd ground truth in getting matched with a prediction in the sense that even if the crowd ground truth has a higher IOU, the non-crowd one gets matched if its IOU is viable. """ metric = ObjectDetectionMAP(iou_thresholds=[0.2]) - rule_1_pred = { + pred = { "bbox": torch.tensor([[0.0, 0.0, 100.0, 100.0], [0.0, 0.0, 100.0, 100.0]]), "scores": torch.tensor([0.8, 0.9]), "labels": torch.tensor([1, 1]), } - rule_1_gt = { + gt = { "bbox": torch.tensor([[0.0, 0.0, 100.0, 100.0]]), "iscrowd": torch.zeros((1,)), "labels": torch.tensor([1]), } - metric.update(([rule_1_pred], [rule_1_gt])) - assert (metric._tp[1][0] == torch.tensor([[False, True]])).all() - assert (metric._fp[1][0] == torch.tensor([[True, False]])).all() + metric.update(([pred], [gt])) + assert (metric._tp[1][0] == torch.tensor([[True, False]])).all() + assert (metric._fp[1][0] == torch.tensor([[False, True]])).all() + assert (metric._scores[1][0] == torch.tensor([[0.9, 0.8]])).all() - rule_1_and_2_pred = { - "bbox": torch.tensor([[0.0, 0.0, 100.0, 100.0], [0.0, 0.0, 100.0, 100.0]]), - "scores": torch.tensor([0.9, 0.9]), - "labels": torch.tensor([1, 1]), - } - rule_1_and_2_gt = { - "bbox": torch.tensor([[0.0, 0.0, 100.0, 100.0]]), - "iscrowd": torch.zeros((1,)), - "labels": torch.tensor([1]), - } - metric.update(([rule_1_and_2_pred], [rule_1_and_2_gt])) + pred["scores"] = torch.tensor([0.9, 0.9]) + metric.update(([pred], [gt])) assert (metric._tp[1][1] == torch.tensor([[True, False]])).all() assert (metric._fp[1][1] == torch.tensor([[False, True]])).all() - rule_2_pred = { - "bbox": torch.tensor([[0.0, 0.0, 100.0, 100.0], [0.0, 0.0, 100.0, 100.0]]), - "scores": torch.tensor([0.9, 0.9]), - "labels": torch.tensor([1, 1]), - } - rule_2_gt = { - "bbox": torch.tensor([[0.0, 0.0, 100.0, 100.0]]), - "iscrowd": torch.tensor([1]), - "labels": torch.tensor([1]), - } - metric.update(([rule_2_pred], [rule_2_gt])) + gt["iscrowd"] = torch.tensor([1]) + metric.update(([pred], [gt])) assert (metric._tp[1][2] == torch.tensor([[False, False]])).all() assert (metric._fp[1][2] == torch.tensor([[False, False]])).all() - rule_2_and_3_pred = { - "bbox": torch.tensor([[0.0, 0.0, 100.0, 100.0], [100.0, 0.0, 200.0, 100.0]]), - "scores": torch.tensor([0.9, 0.9]), - "labels": torch.tensor([1, 1]), - } - rule_2_and_3_gt = { - "bbox": torch.tensor([[0.0, 0.0, 25.0, 50.0], [50.0, 0.0, 150.0, 100.0]]), - "iscrowd": torch.zeros((2,)), - "labels": torch.tensor([1, 1]), - } - metric.update(([rule_2_and_3_pred], [rule_2_and_3_gt])) + pred["bbox"] = torch.tensor([[0.0, 0.0, 100.0, 100.0], [100.0, 0.0, 200.0, 100.0]]) + gt["bbox"] = torch.tensor([[0.0, 0.0, 25.0, 50.0], [50.0, 0.0, 150.0, 100.0]]) + gt["iscrowd"] = torch.zeros((2,)) + gt["labels"] = torch.tensor([1, 1]) + metric.update(([pred], [gt])) assert (metric._tp[1][3] == torch.tensor([[True, False]])).all() assert (metric._fp[1][3] == torch.tensor([[False, True]])).all() + metric.area_range = "small" + pred["bbox"] = torch.tensor( + [[0.0, 0.0, 100.0, 10.0], [0.0, 0.0, 100.0, 10.0], [0.0, 0.0, 100.0, 11.0], [0.0, 0.0, 100.0, 10.0]] + ) + pred["scores"] = torch.tensor([0.9, 0.9, 0.9, 0.9]) + pred["labels"] = torch.tensor([1, 1, 1, 1]) + gt["bbox"] = torch.tensor([[0.0, 0.0, 100.0, 11.0], [0.0, 0.0, 100.0, 5.0]]) + 
metric.update(([pred], [gt])) + assert (metric._tp[1][4] == torch.tensor([[True, False, False, False]])).all() + assert (metric._fp[1][4] == torch.tensor([[False, False, False, True]])).all() + def sklearn_precision_recall_curve_allowing_multiple_recalls_at_single_threshold(y_true, y_score): y_true = y_true == 1 @@ -923,31 +925,24 @@ def test__compute_recall_and_precision(): def test_compute(sample): device = idist.device() + metric_50_95 = ObjectDetectionMAP(device=device) metric_50 = ObjectDetectionMAP(iou_thresholds=[0.5], device=device) metric_75 = ObjectDetectionMAP(iou_thresholds=[0.75], device=device) - metric_50_95 = ObjectDetectionMAP(device=device) + metric_S = ObjectDetectionMAP(device=device, area_range="small") + metric_M = ObjectDetectionMAP(device=device, area_range="medium") + metric_L = ObjectDetectionMAP(device=device, area_range="large") - metric_50.update(sample.data) - metric_75.update(sample.data) - metric_50_95.update(sample.data) + metrics = [metric_50_95, metric_50, metric_75, metric_S, metric_M, metric_L] - res_50 = metric_50.compute() - res_75 = metric_75.compute() - res_50_95 = metric_50_95.compute() + for metric in metrics: + metric.update(sample.data) - pycoco_res_50_95, pycoco_res_50, pycoco_res_75 = sample.mAP + ignite_res = [metric.compute() for metric in metrics] + assert all([np.allclose(re, pycoco_res) for re, pycoco_res in zip(ignite_res, sample.mAP)]) - assert np.allclose(res_50, pycoco_res_50) - assert np.allclose(res_75, pycoco_res_75) - assert np.allclose(res_50_95, pycoco_res_50_95) - - res_50_recompute = metric_50.compute() - res_75_recompute = metric_75.compute() - res_50_95_recompute = metric_50_95.compute() + ignite_res_recompute = [metric.compute() for metric in metrics] - assert res_50 == res_50_recompute - assert res_75 == res_75_recompute - assert res_50_95 == res_50_95_recompute + assert all([r1 == r2 for r1, r2 in zip(ignite_res, ignite_res_recompute)]) def test_integration(sample): @@ -1022,30 +1017,22 @@ def test_distrib_update_compute(distributed, sample): device = idist.device() metric_device = "cpu" if device.type == "xla" else device + metric_50_95 = ObjectDetectionMAP(device=metric_device) metric_50 = ObjectDetectionMAP(iou_thresholds=[0.5], device=metric_device) metric_75 = ObjectDetectionMAP(iou_thresholds=[0.75], device=metric_device) - metric_50_95 = ObjectDetectionMAP(device=metric_device) + metric_S = ObjectDetectionMAP(device=metric_device, area_range="small") + metric_M = ObjectDetectionMAP(device=metric_device, area_range="medium") + metric_L = ObjectDetectionMAP(device=metric_device, area_range="large") + + metrics = [metric_50_95, metric_50, metric_75, metric_S, metric_M, metric_L] y_pred_rank = sample.data[0][rank_samples_range] y_rank = sample.data[1][rank_samples_range] - metric_50.update((y_pred_rank, y_rank)) - metric_75.update((y_pred_rank, y_rank)) - metric_50_95.update((y_pred_rank, y_rank)) - - res_50 = metric_50.compute() - res_75 = metric_75.compute() - res_50_95 = metric_50_95.compute() - - pycoco_res_50_95, pycoco_res_50, pycoco_res_75 = sample.mAP - - assert np.allclose(res_50_95, pycoco_res_50_95) - assert np.allclose(res_50, pycoco_res_50) - assert np.allclose(res_75, pycoco_res_75) + for metric in metrics: + metric.update((y_pred_rank, y_rank)) - res_50_recompute = metric_50.compute() - res_75_recompute = metric_75.compute() - res_50_95_recompute = metric_50_95.compute() + ignite_res = [metric.compute() for metric in metrics] + assert all([np.allclose(re, pycoco_res) for re, pycoco_res in zip(ignite_res, 
sample.mAP)]) - assert res_50_recompute == res_50 - assert res_75_recompute == res_75 - assert res_50_95_recompute == res_50_95 + ignite_res_recompute = [metric.compute() for metric in metrics] + assert all([r1 == r2 for r1, r2 in zip(ignite_res, ignite_res_recompute)]) From ce48583f814cdcf42ba48af4fe579b964ff7272e Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Fri, 28 Jun 2024 16:52:19 +0330 Subject: [PATCH 17/41] some improvements --- ignite/metrics/__init__.py | 2 +- ignite/metrics/mean_average_precision.py | 162 +++++--------- ignite/metrics/vision/__init__.py | 2 +- ...ect_detection_average_precision_recall.py} | 208 ++++++++---------- .../vision/test_object_detection_map.py | 2 +- 5 files changed, 155 insertions(+), 221 deletions(-) rename ignite/metrics/vision/{object_detection_map.py => object_detection_average_precision_recall.py} (75%) diff --git a/ignite/metrics/__init__.py b/ignite/metrics/__init__.py index d5d2bd56078..4925a558dd1 100644 --- a/ignite/metrics/__init__.py +++ b/ignite/metrics/__init__.py @@ -24,7 +24,7 @@ from ignite.metrics.running_average import RunningAverage from ignite.metrics.ssim import SSIM from ignite.metrics.top_k_categorical_accuracy import TopKCategoricalAccuracy -from ignite.metrics.vision.object_detection_map import ObjectDetectionMAP +from ignite.metrics.vision.object_detection_average_precision_recall import ObjectDetectionMAP __all__ = [ "Metric", diff --git a/ignite/metrics/mean_average_precision.py b/ignite/metrics/mean_average_precision.py index e98666ba8d5..caa633e41b6 100644 --- a/ignite/metrics/mean_average_precision.py +++ b/ignite/metrics/mean_average_precision.py @@ -7,23 +7,24 @@ import ignite.distributed as idist from ignite.distributed.utils import all_gather_tensors_with_shapes from ignite.metrics.metric import reinit__is_reduced -from ignite.metrics.recall import _BasePrecisionRecall +from ignite.metrics.precision import _BaseClassification +from ignite.metrics import Metric from ignite.utils import to_onehot -class _BaseMeanAveragePrecision(_BasePrecisionRecall): +class _BaseAveragePrecision: def __init__( self, rec_thresholds: Optional[Union[Sequence[float], torch.Tensor]] = None, - average: Optional[Literal["precision", "max-precision"]] = "precision", - class_mean: Optional[Literal["micro", "macro", "weighted", "with_other_dims"]] = "macro", - is_multilabel: bool = False, - output_transform: Callable = lambda x: x, - device: Union[str, torch.device] = torch.device("cpu"), + class_mean: Optional[Literal["micro", "macro", "weighted"]] = "macro", ) -> None: - r"""Base class for Mean Average Precision in classification and detection tasks. + r"""Base class for Average Precision & Recall in classification and detection tasks. + + This class contains the methods for setting up the thresholds and computing AP & AR. + + # Average precision is computed by averaging precision over increasing levels of recall thresholds as following: + - Mean average precision is computed by taking the mean of the average precision over different classes and possibly some additional dimensions in the detection task. ``class_mean`` determines how to take this mean. In the detection tasks, it's possible to take the mean in other respects as well e.g. IoU threshold in an object detection task. @@ -32,22 +33,6 @@ def __init__( rec_thresholds: recall thresholds (sensivity levels) to be considered for computing Mean Average Precision. It could be a 1-dim tensor or a sequence of floats. Its values should be between 0 and 1 and don't need to be sorted. 
If missing, thresholds are considered automatically using the data. - average: one of values "precision" or "max-precision". In the former case, the precision at a - recall threshold is used for that threshold: - - .. math:: - \text{Average Precision} = \sum_{k=1}^{\#rec\_thresholds} (r_k - r_{k-1}) P_k - - :math:`r` stands for recall thresholds and :math:`P` for precision values. :math:`r_0` is set to zero. - - In the latter case, the maximum precision across thresholds greater or equal a recall threshold is - considered as the summation operand; In other words, the precision peek across lower or equal - sensivity levels is used for a recall threshold: - - .. math:: - \text{Average Precision} = \sum_{k=1}^{\#rec\_thresholds} (r_k - r_{k-1}) max(P_{k:}) - - Default is "precision". class_mean: how to compute mean of the average precision across classes or incorporate class dimension into computing precision. It's ignored in binary classification. Available options are @@ -78,38 +63,18 @@ def __init__( 'macro' computes macro precision which is unweighted mean of AP computed across classes/labels. Default. - 'with_other_dims' - Mean over class dimension is taken with additional mean dimensions all at once, despite macro and - weighted in which mean over additional dimensions is taken beforehand. Only available in detection. - Note: Please note that classes with no ground truth are not considered into the mean in detection. - - is_multilabel: Used in classification task and determines if the data - is multilabel or not. Default False. - output_transform: a callable that is used to transform the - :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the - form expected by the metric. This can be useful if, for example, you have a multi-output model and - you want to compute the metric with respect to one of the outputs. This metric requires the output - as ``(y_pred, y)``. - device: specifies which device updates are accumulated on. Setting the - metric's device to be the same as your ``update`` arguments ensures the ``update`` method is - non-blocking. By default, CPU. """ if rec_thresholds is not None: self.rec_thresholds: Optional[torch.Tensor] = self._setup_thresholds(rec_thresholds, "rec_thresholds") else: self.rec_thresholds = None - if average not in ("precision", "max-precision"): - raise ValueError(f"Wrong `average` parameter, given {average}") - self.average = average - if class_mean is not None and class_mean not in ("micro", "macro", "weighted", "with_other_dims"): raise ValueError(f"Wrong `class_mean` parameter, given {class_mean}") self.class_mean = class_mean - super().__init__(output_transform=output_transform, is_multilabel=is_multilabel, device=device) def _setup_thresholds(self, thresholds: Union[Sequence[float], torch.Tensor], threshold_type: str) -> torch.Tensor: if isinstance(thresholds, Sequence): @@ -134,44 +99,58 @@ def _compute_average_precision(self, recall: torch.Tensor, precision: torch.Tens """Measuring average precision which is the common operation among different settings of the metric. Args: - recall: n-dimensional tensor whose last dimension is the dimension of the samples. Should be ordered in - ascending order in its last dimension. + recall: n-dimensional tensor whose last dimension represents confidence thresholds as much as #samples. + Should be ordered in ascending order in its last dimension. precision: like ``recall`` in the shape. 
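A toy numeric sketch of the fixed-threshold sampling implemented just below (recall/precision values invented): precision is read off at each requested recall level, and levels beyond the highest achieved recall contribute zero::

    import torch

    # Precision/recall measured at decreasing score thresholds (recall ascends).
    recall = torch.tensor([0.25, 0.50, 0.75, 1.00], dtype=torch.double)
    precision = torch.tensor([1.00, 0.66, 0.60, 0.57], dtype=torch.double)

    rec_thresholds = torch.linspace(0, 1, 11, dtype=torch.double)
    idx = torch.searchsorted(recall, rec_thresholds)
    valid = idx < recall.numel()  # thresholds above the best recall get zero precision
    sampled = torch.where(valid, precision[idx.clamp(max=recall.numel() - 1)],
                          torch.tensor(0.0, dtype=torch.double))
    ap = (rec_thresholds.diff(prepend=torch.zeros(1, dtype=torch.double)) * sampled).sum()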
Returns: average_precision: (n-1)-dimensional tensor containing the average precision for mean dimensions. """ - precision_integrand = ( - precision.flip(-1).cummax(dim=-1).values.flip(-1) if self.average == "max-precision" else precision - ) if self.rec_thresholds is not None: rec_thresholds = self.rec_thresholds.repeat((*recall.shape[:-1], 1)) rec_thresh_indices = torch.searchsorted(recall, rec_thresholds) - precision_integrand = precision_integrand.take_along_dim( + precision = precision.take_along_dim( rec_thresh_indices.where(rec_thresh_indices != recall.size(-1), 0), dim=-1 ).where(rec_thresh_indices != recall.size(-1), 0) recall = rec_thresholds recall_differential = recall.diff( dim=-1, prepend=torch.zeros((*recall.shape[:-1], 1), device=self._device, dtype=torch.double) ) - return torch.sum(recall_differential * precision_integrand, dim=-1) + return torch.sum(recall_differential * precision, dim=-1) -class MeanAveragePrecision(_BaseMeanAveragePrecision): +def _cat_and_agg_tensors(tensors: List[torch.Tensor], tensor_shape: Tuple[int], num_preds: List[int], dtype: torch.dtype, device: Union[str, torch.device]) -> torch.Tensor: + tensor = torch.cat(tensors, dim=-1) if tensors else torch.empty((*tensor_shape, 0), dtype=dtype, device=device) + shape_across_ranks = [ + (*tensor_shape, num_pred_in_rank) for num_pred_in_rank in num_preds + ] + return torch.cat( + all_gather_tensors_with_shapes( + tensor, + shape_across_ranks, + ), + dim=-1, + ) + + +class MeanAveragePrecision(_BaseClassification, _BaseAveragePrecision): + _scores: List[torch.Tensor] _P: List[torch.Tensor] def __init__( self, rec_thresholds: Optional[Union[Sequence[float], torch.Tensor]] = None, - average: Optional["Literal['precision', 'max-precision']"] = "precision", class_mean: Optional["Literal['micro', 'macro', 'weighted']"] = "macro", is_multilabel: bool = False, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu"), ) -> None: r"""Calculate the mean average precision metric i.e. mean of the averaged-over-recall precision for - classification task. + classification task: + + .. math:: + \text{Average Precision} = \sum_{k=1}^{\#rec\_thresholds} (r_k - r_{k-1}) P_k Mean average precision attempts to give a measure of detector or classifier precision at various sensivity levels a.k.a recall thresholds. This is done by summing precisions at different recall @@ -190,27 +169,11 @@ def __init__( rec_thresholds: recall thresholds (sensivity levels) to be considered for computing Mean Average Precision. It could be a 1-dim tensor or a sequence of floats. Its values should be between 0 and 1 and don't need to be sorted. If missing, thresholds are considered automatically using the data. - average: one of values "precision" or "max-precision". In the former case, the precision at a - recall threshold is used for that threshold: - - .. math:: - \text{Average Precision} = \sum_{k=1}^{\#rec\_thresholds} (r_k - r_{k-1}) P_k - - :math:`r` stands for recall thresholds and :math:`P` for precision values. :math:`r_0` is set to zero. - - In the latter case, the maximum precision across thresholds greater or equal a recall threshold is - considered as the summation operand; In other words, the precision peek across lower or equal - sensivity levels is used for a recall threshold: - - .. math:: - \text{Average Precision} = \sum_{k=1}^{\#rec\_thresholds} (r_k - r_{k-1}) max(P_{k:}) - - Default is "precision". 
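On the module-level gathering helper introduced a bit above in this diff: a serial-run sketch of ``_cat_and_agg_tensors`` (shapes and values invented; with a single process the cross-rank gather is effectively a no-op, so the call simply concatenates the per-update tensors)::

    import torch
    from ignite.metrics.mean_average_precision import _cat_and_agg_tensors

    # Two updates' worth of per-label 0/1 targets for 3 labels (multilabel layout).
    P_updates = [torch.randint(0, 2, (3, 4), dtype=torch.uint8),
                 torch.randint(0, 2, (3, 2), dtype=torch.uint8)]
    num_preds_per_rank = [6]  # one rank holding 4 + 2 samples
    P = _cat_and_agg_tensors(P_updates, (3,), num_preds_per_rank, torch.uint8, "cpu")
    P.shape  # torch.Size([3, 6])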
class_mean: how to compute mean of the average precision across classes or incorporate class dimension into computing precision. It's ignored in binary classification. Available options are None - An 1-dimensional tensor of mean (taken across additional mean dimensions) average precision per class + A 1-dimensional tensor of mean (taken across additional mean dimensions) average precision per class is returned. If there's no ground truth sample for a class, ``0`` is returned for that. 'micro' @@ -247,24 +210,19 @@ def __init__( non-blocking. By default, CPU. """ - super().__init__( - rec_thresholds=rec_thresholds, - average=average, - class_mean=class_mean, + super(MeanAveragePrecision, self).__init__( output_transform=output_transform, is_multilabel=is_multilabel, device=device, ) - - if self.class_mean == "with_other_dims": - raise ValueError("class_mean 'with_other_dims' is not compatible with this class.") + super(Metric, self).__init__(rec_thresholds=rec_thresholds, class_mean=class_mean) @reinit__is_reduced def reset(self) -> None: """ Reset method of the metric """ - super(_BasePrecisionRecall, self).reset() + super().reset() self._scores = [] self._P = [] @@ -275,7 +233,7 @@ def _check_binary_multilabel_cases(self, output: Sequence[torch.Tensor]) -> None raise ValueError("For binary cases, y must be comprised of 0's and 1's.") def _check_type(self, output: Sequence[torch.Tensor]) -> None: - super(_BasePrecisionRecall, self)._check_type(output) + super()._check_type(output) y_pred, y = output if y_pred.dtype in (torch.int, torch.long): raise TypeError(f"`y_pred` should be a float tensor, given {y_pred.dtype}") @@ -393,33 +351,27 @@ def compute(self) -> Union[torch.Tensor, float]: raise RuntimeError("Metric could not be computed without any update method call") num_classes = self._num_classes - rank_P = ( - torch.cat(self._P, dim=-1) - if self._P - else ( - torch.empty((num_classes, 0), dtype=torch.uint8, device=self._device) - if self._type == "multilabel" - else torch.tensor( - [], dtype=torch.long if self._type == "multiclass" else torch.uint8, device=self._device - ) - ) + num_samples = torch.tensor( + [sum([p.shape[-1] for p in self._P]) if self._P else 0], + device=self._device, ) - rank_P_shapes = cast(torch.Tensor, idist.all_gather(torch.tensor(rank_P.shape))).view(-1, len(rank_P.shape)) - P = torch.cat(all_gather_tensors_with_shapes(rank_P, rank_P_shapes.tolist()), dim=-1) - - rank_scores = ( - torch.cat(self._scores, dim=-1) - if self._scores - else ( - torch.tensor([], device=self._device) - if self._type == "binary" - else torch.empty((num_classes, 0), dtype=torch.double, device=self._device) - ) + num_samples_across_ranks = cast(torch.Tensor, idist.all_gather(num_samples)).tolist() + + P = _cat_and_agg_tensors( + self._P, + (num_classes,) if self._type == "multiabel" else (), + num_samples_across_ranks, + torch.long if self._type == "multiclass" else torch.uint8, + self._device ) - rank_scores_shapes = cast(torch.Tensor, idist.all_gather(torch.tensor(rank_scores.shape))).view( - -1, len(rank_scores.shape) + + scores = _cat_and_agg_tensors( + self._scores, + (num_classes,) if self._type != "binary" else (), + num_samples_across_ranks, + torch.double, + self._device ) - scores = torch.cat(all_gather_tensors_with_shapes(rank_scores, rank_scores_shapes.tolist()), dim=-1) if self._type == "multiclass": P = to_onehot(P, num_classes=num_classes).T diff --git a/ignite/metrics/vision/__init__.py b/ignite/metrics/vision/__init__.py index f351d5b339f..60463a46aff 100644 --- 
a/ignite/metrics/vision/__init__.py +++ b/ignite/metrics/vision/__init__.py @@ -1,3 +1,3 @@ -from ignite.metrics.vision.object_detection_map import ObjectDetectionMAP +from ignite.metrics.vision.object_detection_average_precision_recall import ObjectDetectionMAP __all__ = ["ObjectDetectionMAP"] diff --git a/ignite/metrics/vision/object_detection_map.py b/ignite/metrics/vision/object_detection_average_precision_recall.py similarity index 75% rename from ignite/metrics/vision/object_detection_map.py rename to ignite/metrics/vision/object_detection_average_precision_recall.py index 44fdb1a38db..53ba9ad42a8 100644 --- a/ignite/metrics/vision/object_detection_map.py +++ b/ignite/metrics/vision/object_detection_average_precision_recall.py @@ -1,14 +1,13 @@ from collections import defaultdict -from typing import Callable, cast, Dict, List, Optional, Sequence, Tuple, Union +from typing import Callable, cast, Dict, List, Optional, Sequence, Tuple, Union, override import torch from typing_extensions import Literal import ignite.distributed as idist -from ignite.distributed.utils import all_gather_tensors_with_shapes -from ignite.metrics.mean_average_precision import _BaseMeanAveragePrecision +from ignite.metrics.mean_average_precision import _BaseAveragePrecision, _cat_and_agg_tensors -from ignite.metrics.metric import reinit__is_reduced +from ignite.metrics.metric import Metric, reinit__is_reduced from ignite.metrics.recall import _BasePrecisionRecall @@ -43,11 +42,13 @@ def tensor_list_to_dict_list( return cast(Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, torch.Tensor]]], (y_pred, y)) -class ObjectDetectionMAP(_BaseMeanAveragePrecision): +class ObjectDetectionAvgPrecisionRecall(Metric, _BaseAveragePrecision): + _tp: Dict[int, List[torch.Tensor]] _fp: Dict[int, List[torch.Tensor]] _scores: Dict[int, List[torch.Tensor]] _P: Dict[int, int] + _num_classes: int def __init__( self, @@ -58,13 +59,22 @@ def __init__( output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu"), ) -> None: - """Calculate the mean average precision for evaluating an object detector. + """Calculate the mean average precision & recall for evaluating an object detector. + + In average precision, the maximum precision across thresholds greater or equal a recall threshold is + considered as the summation operand; In other words, the precision peek across lower or equal + sensivity levels is used for a recall threshold: + + .. math:: + \text{Average Precision} = \sum_{k=1}^{\#rec\_thresholds} (r_k - r_{k-1}) max(P_{k:}) Args: - iou_thresholds: sequence of IoU thresholds to be considered for computing mean average precision. + iou_thresholds: sequence of IoU thresholds to be considered for computing mean average precision & recall. Values should be between 0 and 1. If not given, COCO's default (.5, .55, ..., .95) would be used. rec_thresholds: sequence of recall thresholds to be considered for computing mean average precision. Values should be between 0 and 1. If not given, COCO's default (.0, .01, .02, ..., 1.) would be used. + max_detections_per_image: Max number of detections in each image to consider for evaluation. The most + confident ones are selected. output_transform: a callable that is used to transform the :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the form expected by the metric. 
An already provided example is :func:`~ignite.metrics.vision.object_detection_map.tensor_list_to_dict_list` @@ -99,21 +109,17 @@ def box_iou(pred_boxes: torch.Tensor, gt_boxes: torch.Tensor, iscrowd: torch.Boo self.area_range = area_range self.max_detections_per_image = max_detections_per_image - super().__init__( - rec_thresholds=rec_thresholds, - average="max-precision", - class_mean="with_other_dims", + super(ObjectDetectionAvgPrecisionRecall, self).__init__( output_transform=output_transform, device=device, ) + super(Metric, self).__init__( + rec_thresholds=rec_thresholds, + class_mean=None, + ) @reinit__is_reduced def reset(self) -> None: - """ - Reset method of the metric - """ - super(_BasePrecisionRecall, self).reset() - self._tp = defaultdict(lambda: []) self._fp = defaultdict(lambda: []) self._scores = defaultdict(lambda: []) @@ -166,10 +172,10 @@ def _compute_recall_and_precision( ) -> Tuple[torch.Tensor, torch.Tensor]: r"""Measuring recall & precision - This method is overriden since in the pycocotools reference implementation, when there are predictions with the - same scores, they're considered associated with different thresholds in the course of measuring recall - values, although it's not logically correct as those predictions are really associated with a single threshold, - thus outputing a single recall value. + This method is different from that of MeanAveragePrecision since in the pycocotools reference implementation, + when there are predictions with the same scores, they're considered associated with different thresholds in the + course of measuring recall values, although it's not logically correct as those predictions are really + associated with a single threshold, thus outputing a single recall value. Shape of function inputs and return values follow the table below. N\ :sub:`pred` is the number of detections or predictions. ``...`` stands for the possible additional dimensions. Finally, \#unique scores represents @@ -201,6 +207,7 @@ def _compute_recall_and_precision( return recall, precision + @override def _compute_average_precision(self, recall: torch.Tensor, precision: torch.Tensor) -> torch.Tensor: """Measuring average precision. This method is overriden since :math:`1/#recall_thresholds` is used instead of :math:`r_k - r_{k-1}` @@ -214,9 +221,7 @@ def _compute_average_precision(self, recall: torch.Tensor, precision: torch.Tens Returns: average_precision: (n-1)-dimensional tensor containing the average precision for mean dimensions. """ - precision_integrand = ( - precision.flip(-1).cummax(dim=-1).values.flip(-1) if self.average == "max-precision" else precision - ) + precision_integrand = precision.flip(-1).cummax(dim=-1).values.flip(-1) rec_thresholds = cast(torch.Tensor, self.rec_thresholds).repeat((*recall.shape[:-1], 1)) rec_thresh_indices = torch.searchsorted(recall, rec_thresholds) precision_integrand = precision_integrand.take_along_dim( @@ -228,7 +233,7 @@ def _do_matching( self, pred: Dict[str, torch.Tensor], target: Dict[str, torch.Tensor] ) -> Tuple[Dict[int, torch.Tensor], Dict[int, torch.Tensor], Dict[int, int], Dict[int, torch.Tensor]]: r""" - Matching logic of object detection mAP. + Matching logic of object detection mAP, according to COCO reference implementation. The method returns a quadrople of dictionaries containing TP, FP, P (actual positive) counts and scores for each class respectively. Please note that class numbers start from zero. 
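A worked illustration of the `_compute_average_precision` override shown earlier in this hunk: precision is first replaced by its running maximum over equal-or-higher recalls (the flip/cummax/flip envelope), then sampled at fixed recall thresholds and averaged with a uniform 1/#rec_thresholds weight. The recall/precision numbers below are made up for the sketch and are not part of the patch:

import torch

# toy monotone recall curve and its precision values, already sorted by recall
recall = torch.tensor([0.25, 0.50, 0.75, 1.00], dtype=torch.double)
precision = torch.tensor([1.00, 0.60, 0.70, 0.50], dtype=torch.double)
rec_thresholds = torch.linspace(0, 1, 101, dtype=torch.double)

# precision envelope: each point becomes the best precision at an equal or higher recall
envelope = precision.flip(-1).cummax(dim=-1).values.flip(-1)

# sample the envelope at every recall threshold; thresholds above the maximal recall contribute 0
indices = torch.searchsorted(recall, rec_thresholds)
valid = indices != recall.size(-1)
sampled = torch.where(valid, envelope[indices.clamp(max=recall.size(-1) - 1)], torch.zeros_like(rec_thresholds))

ap = sampled.sum() / len(rec_thresholds)  # uniform weight instead of r_k - r_{k-1}
print(float(ap))

With these toy numbers the envelope is [1.0, 0.7, 0.7, 0.5], so a precision dip between two recall levels never lowers the score.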
@@ -268,6 +273,7 @@ def _do_matching( pred_scores = pred["scores"][best_detections_index] pred_labels = pred["labels"][best_detections_index] pred_boxes = pred["bbox"][best_detections_index] + pred_match_area_range = self._match_area_range(pred_boxes) categories = list(set(labels.int().tolist() + pred_labels.int().tolist())) @@ -276,6 +282,23 @@ def _do_matching( P: Dict[int, int] = {} scores: Dict[int, torch.Tensor] = {} + ious = self.box_iou(pred_boxes, gt_boxes, cast(torch.BoolTensor, category_is_crowd)) + + NO_MATCH = -3 + ious[:, gt_ignore] -= 2 + category_no_match = labels.expand(len(pred_labels), -1) != labels.view(-1, 1) + ious[category_no_match] = NO_MATCH + ious.unsqueeze(-1).repeat((1, 1, len(self.iou_thresholds))) + ious[ious < self.iou_thresholds] = NO_MATCH + for i in range(len(pred_labels)): + # Flip is done to give priority to the last item with maximal value, as torch.max selects the first one. + match_gts = ious[i].flip(0).max(0) + match_gts_indices = ious.size(1) -1 - match_gts.indices + for t in range(len(self.iou_thresholds)): + if match_gts.values[t] != NO_MATCH and not gt_is_crowd[match_gts_indices[t]]: + ious[:, match_gts_indices[t], t] = NO_MATCH + ious[i, match_gts_indices[t], t] = match_gts.values[t] + for category in categories: category_index_gt = labels == category num_category_gt = category_index_gt.sum() @@ -283,59 +306,15 @@ def _do_matching( category_gt_ignore = gt_ignore[category_index_gt] if num_category_gt: # what if P[c] becomes 0 ? P[category] = num_category_gt - category_gt_ignore.sum() - + category_index_dt = pred_labels == category if not category_index_dt.any(): continue scores[category] = pred_scores[category_index_dt] - pred_match_area_range = self._match_area_range(pred_boxes[category_index_dt]) - num_category_dt = category_index_dt.sum().item() - category_tp = torch.zeros( - (len(self.iou_thresholds), num_category_dt), dtype=torch.uint8, device=self._device - ) - category_fp = torch.zeros( - (len(self.iou_thresholds), num_category_dt), dtype=torch.uint8, device=self._device - ) - if num_category_gt: - category_iou = self.box_iou( - pred_boxes[category_index_dt], - gt_boxes[category_index_gt], - cast(torch.BoolTensor, category_is_crowd), - ) - category_maximum_iou = category_iou.max() - for thres_idx, iou_thres in enumerate(self.iou_thresholds): - if iou_thres <= category_maximum_iou: - matched_gt_indices = set() - for pred_idx in range(num_category_dt): - match_iou, match_idx = min(iou_thres, 1 - 1e-10), -1 - for gt_idx in range(num_category_gt): - if (category_iou[pred_idx, gt_idx] < iou_thres) or ( - gt_idx in matched_gt_indices and ~category_is_crowd[gt_idx] - ): - continue - if ( - match_idx == -1 - or (category_gt_ignore[match_idx] & ~category_gt_ignore[gt_idx]) - or ( - (category_gt_ignore[match_idx] | ~category_gt_ignore[gt_idx]) - and category_iou[pred_idx][gt_idx] >= match_iou - ) - ): - match_iou = category_iou[pred_idx][gt_idx] - match_idx = gt_idx - if match_idx != -1: - matched_gt_indices.add(match_idx) - category_tp[thres_idx][pred_idx] = ~category_gt_ignore[match_idx] - else: - category_fp[thres_idx][pred_idx] = pred_match_area_range[pred_idx] - else: - category_fp[thres_idx] = pred_match_area_range - else: - category_fp[:, :] = pred_match_area_range - - tp[category] = category_tp - fp[category] = category_fp + category_max_ious = ious[category_index_dt].max(1).values + tp[category] = (category_max_ious >= 0).T.to(dtype=torch.uint8, device=self._device) + fp[category] = ((category_max_ious == NO_MATCH).T and 
pred_match_area_range[category_index_dt]).to(dtype=torch.uint8, device=self._device) return tp, fp, P, scores @@ -385,10 +364,7 @@ def update(self, output: Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, tor if classes: self._num_classes = max(max(classes) + 1, self._num_classes) - def compute(self) -> float: - """ - Compute method of the metric - """ + def _compute(self) -> torch.Tensor: num_classes = int(idist.all_reduce(self._num_classes or 0, "MAX")) if num_classes < 1: return 0.0 @@ -423,8 +399,8 @@ def compute(self) -> float: ), ).tolist() - average_precisions = -torch.ones( - (num_classes, *mean_dimensions_shape), + average_precisions_recalls = -torch.ones( + (2, num_classes, *mean_dimensions_shape), device=self._device, dtype=torch.double, ) @@ -432,41 +408,47 @@ def compute(self) -> float: if P[cls] == 0: continue - num_preds_across_ranks = num_preds_per_class_across_ranks[:, [cls]] - if num_preds_across_ranks.sum() == 0: - average_precisions[cls] = 0 + num_preds_across_ranks = num_preds_per_class_across_ranks[:, cls].tolist() + if sum(num_preds_across_ranks) == 0: + average_precisions_recalls[0, cls] = 0 continue - shape_across_ranks = [ - (*mean_dimensions_shape, num_pred_in_rank.item()) for num_pred_in_rank in num_preds_across_ranks - ] - TP = torch.cat( - all_gather_tensors_with_shapes( - torch.cat(self._tp[cls], dim=-1) - if self._tp[cls] - else torch.empty((*mean_dimensions_shape, 0), dtype=torch.uint8, device=self._device), - shape_across_ranks, - ), - dim=-1, - ) - FP = torch.cat( - all_gather_tensors_with_shapes( - torch.cat(self._fp[cls], dim=-1) - if self._fp[cls] - else torch.empty((*mean_dimensions_shape, 0), dtype=torch.uint8, device=self._device), - shape_across_ranks, - ), - dim=-1, - ) - scores = torch.cat( - all_gather_tensors_with_shapes( - torch.cat(self._scores[cls]) - if self._scores[cls] - else torch.tensor([], dtype=torch.double, device=self._device), - num_preds_across_ranks.tolist(), - ) - ) + TP = _cat_and_agg_tensors(self._tp[cls], mean_dimensions_shape, num_preds_across_ranks, torch.uint8, self._device) + FP = _cat_and_agg_tensors(self._fp[cls], mean_dimensions_shape, num_preds_across_ranks, torch.uint8, self._device) + scores = _cat_and_agg_tensors(self._scores[cls], (), num_preds_across_ranks, torch.double, self._device) + recall, precision = self._compute_recall_and_precision(TP, FP, scores, P[cls]) average_precision_for_cls_across_other_dims = self._compute_average_precision(recall, precision) - average_precisions[cls] = average_precision_for_cls_across_other_dims - - return average_precisions[average_precisions > -1].mean().item() + average_precisions_recalls[0, cls] = average_precision_for_cls_across_other_dims + + average_precisions_recalls[1, cls] = recall[..., -1] + + return average_precisions_recalls + + def compute(self) -> Tuple[float, float]: + average_precisions_recalls = self._compute() + ap = average_precisions_recalls[0][average_precisions_recalls[0] > -1].mean().item() + ar = average_precisions_recalls[1][average_precisions_recalls[1] > -1].mean().item() + + return ap, ar + + +# class ObjDetCommonAPandAR(Metric): +# """ +# Computes following common variants of average precision (AP) and recall (AR). 
+ +# ============== ====================== +# Metric variant Description +# ============== ====================== +# AP@.5...95 (..., N\ :sub:`pred`) +# AP@.5 (N\ :sub:`pred`,) +# AP@.75 () (A single float, +# AP-S greater than zero) +# AP-M +# AP-L (..., \#unique scores) +# AR-S +# AR-M +# AR-L (..., \#unique scores) +# ============== ====================== +# """ +# def __init__(self, output_transform: Callable = lambda x:x, device: Union[str, torch.device] = torch.device("cpu")): +# super().__init__(output_transform, device) \ No newline at end of file diff --git a/tests/ignite/metrics/vision/test_object_detection_map.py b/tests/ignite/metrics/vision/test_object_detection_map.py index 9abd9214b12..5b8c9ebea0c 100644 --- a/tests/ignite/metrics/vision/test_object_detection_map.py +++ b/tests/ignite/metrics/vision/test_object_detection_map.py @@ -18,7 +18,7 @@ import ignite.distributed as idist from ignite.engine import Engine from ignite.metrics import ObjectDetectionMAP -from ignite.metrics.vision.object_detection_map import tensor_list_to_dict_list +from ignite.metrics.vision.object_detection_average_precision_recall import tensor_list_to_dict_list from ignite.utils import manual_seed torch.set_printoptions(linewidth=200) From cf02dc0fd636b929ea50122d76b606329fcc4f53 Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Thu, 11 Jul 2024 10:43:50 +0330 Subject: [PATCH 18/41] Improvement in code --- ignite/metrics/__init__.py | 2 +- ignite/metrics/mean_average_precision.py | 21 ++- ignite/metrics/vision/__init__.py | 4 +- ...ject_detection_average_precision_recall.py | 131 +++++++----------- .../metrics/test_mean_average_precision.py | 6 - .../vision/test_object_detection_map.py | 65 +++++---- 6 files changed, 96 insertions(+), 133 deletions(-) diff --git a/ignite/metrics/__init__.py b/ignite/metrics/__init__.py index 4925a558dd1..e01094d5281 100644 --- a/ignite/metrics/__init__.py +++ b/ignite/metrics/__init__.py @@ -24,7 +24,7 @@ from ignite.metrics.running_average import RunningAverage from ignite.metrics.ssim import SSIM from ignite.metrics.top_k_categorical_accuracy import TopKCategoricalAccuracy -from ignite.metrics.vision.object_detection_average_precision_recall import ObjectDetectionMAP +from ignite.metrics.vision.object_detection_average_precision_recall import ObjectDetectionAvgPrecisionRecall __all__ = [ "Metric", diff --git a/ignite/metrics/mean_average_precision.py b/ignite/metrics/mean_average_precision.py index caa633e41b6..84e10329af2 100644 --- a/ignite/metrics/mean_average_precision.py +++ b/ignite/metrics/mean_average_precision.py @@ -8,7 +8,7 @@ from ignite.distributed.utils import all_gather_tensors_with_shapes from ignite.metrics.metric import reinit__is_reduced from ignite.metrics.precision import _BaseClassification -from ignite.metrics import Metric +from ignite.metrics.metric import Metric from ignite.utils import to_onehot @@ -119,10 +119,15 @@ def _compute_average_precision(self, recall: torch.Tensor, precision: torch.Tens return torch.sum(recall_differential * precision, dim=-1) -def _cat_and_agg_tensors(tensors: List[torch.Tensor], tensor_shape: Tuple[int], num_preds: List[int], dtype: torch.dtype, device: Union[str, torch.device]) -> torch.Tensor: - tensor = torch.cat(tensors, dim=-1) if tensors else torch.empty((*tensor_shape, 0), dtype=dtype, device=device) +def _cat_and_agg_tensors(tensors: List[torch.Tensor], tensor_shape_except_last_dim: Tuple[int], dtype: torch.dtype, device: Union[str, torch.device]) -> torch.Tensor: + num_preds = torch.tensor( + 
[sum([tensor.shape[-1] for tensor in tensors]) if tensors else 0], + device=device, + ) + num_preds = cast(torch.Tensor, idist.all_gather(num_preds)).tolist() + tensor = torch.cat(tensors, dim=-1) if tensors else torch.empty((*tensor_shape_except_last_dim, 0), dtype=dtype, device=device) shape_across_ranks = [ - (*tensor_shape, num_pred_in_rank) for num_pred_in_rank in num_preds + (*tensor_shape_except_last_dim, num_pred_in_rank) for num_pred_in_rank in num_preds ] return torch.cat( all_gather_tensors_with_shapes( @@ -351,16 +356,9 @@ def compute(self) -> Union[torch.Tensor, float]: raise RuntimeError("Metric could not be computed without any update method call") num_classes = self._num_classes - num_samples = torch.tensor( - [sum([p.shape[-1] for p in self._P]) if self._P else 0], - device=self._device, - ) - num_samples_across_ranks = cast(torch.Tensor, idist.all_gather(num_samples)).tolist() - P = _cat_and_agg_tensors( self._P, (num_classes,) if self._type == "multiabel" else (), - num_samples_across_ranks, torch.long if self._type == "multiclass" else torch.uint8, self._device ) @@ -368,7 +366,6 @@ def compute(self) -> Union[torch.Tensor, float]: scores = _cat_and_agg_tensors( self._scores, (num_classes,) if self._type != "binary" else (), - num_samples_across_ranks, torch.double, self._device ) diff --git a/ignite/metrics/vision/__init__.py b/ignite/metrics/vision/__init__.py index 60463a46aff..b5b5d0236bb 100644 --- a/ignite/metrics/vision/__init__.py +++ b/ignite/metrics/vision/__init__.py @@ -1,3 +1,3 @@ -from ignite.metrics.vision.object_detection_average_precision_recall import ObjectDetectionMAP +from ignite.metrics.vision.object_detection_average_precision_recall import ObjectDetectionAvgPrecisionRecall -__all__ = ["ObjectDetectionMAP"] +__all__ = ["ObjectDetectionAvgPrecisionRecall"] diff --git a/ignite/metrics/vision/object_detection_average_precision_recall.py b/ignite/metrics/vision/object_detection_average_precision_recall.py index 53ba9ad42a8..8a5abd106b1 100644 --- a/ignite/metrics/vision/object_detection_average_precision_recall.py +++ b/ignite/metrics/vision/object_detection_average_precision_recall.py @@ -44,9 +44,10 @@ def tensor_list_to_dict_list( class ObjectDetectionAvgPrecisionRecall(Metric, _BaseAveragePrecision): - _tp: Dict[int, List[torch.Tensor]] - _fp: Dict[int, List[torch.Tensor]] - _scores: Dict[int, List[torch.Tensor]] + _tps: List[torch.Tensor] + _fps: List[torch.Tensor] + _scores: List[torch.Tensor] + _pred_labels: List[torch.Tensor] _P: Dict[int, int] _num_classes: int @@ -120,9 +121,10 @@ def box_iou(pred_boxes: torch.Tensor, gt_boxes: torch.Tensor, iscrowd: torch.Boo @reinit__is_reduced def reset(self) -> None: - self._tp = defaultdict(lambda: []) - self._fp = defaultdict(lambda: []) - self._scores = defaultdict(lambda: []) + self._tps = [] + self._fps = [] + self._scores = [] + self._pred_labels = [] self._P = defaultdict(lambda: 0) self._num_classes: int = 0 @@ -231,7 +233,7 @@ def _compute_average_precision(self, recall: torch.Tensor, precision: torch.Tens def _do_matching( self, pred: Dict[str, torch.Tensor], target: Dict[str, torch.Tensor] - ) -> Tuple[Dict[int, torch.Tensor], Dict[int, torch.Tensor], Dict[int, int], Dict[int, torch.Tensor]]: + ) -> Tuple[Union[torch.Tensor,None], Union[torch.Tensor,None], Dict[int, int], torch.Tensor, torch.Tensor]: r""" Matching logic of object detection mAP, according to COCO reference implementation. 
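The `_cat_and_agg_tensors` helper reworked earlier in this commit now gathers the per-rank prediction counts itself and then relies on `all_gather_tensors_with_shapes`, which pads every rank's tensor to a common size before the collective and slices the originals back out afterwards. A rough single-process sketch of that pad-and-slice idea, simulating two ranks locally with plain tensors (no distributed backend involved):

import torch

rank_tensors = [torch.arange(3), torch.arange(5)]  # pretend these live on two different ranks
shapes = [t.shape for t in rank_tensors]
max_len = max(s[-1] for s in shapes)

# each rank pads its tensor up to the largest last dimension ...
padded = [torch.nn.functional.pad(t, (0, max_len - t.shape[-1])) for t in rank_tensors]
# ... an all_gather would then concatenate the equally sized tensors ...
gathered = torch.cat(padded)
# ... and every rank's original data is recovered by slicing with the known shapes
recovered = [gathered[r * max_len : r * max_len + s[-1]] for r, s in enumerate(shapes)]

assert all(torch.equal(a, b) for a, b in zip(recovered, rank_tensors))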
@@ -277,46 +279,38 @@ def _do_matching( categories = list(set(labels.int().tolist() + pred_labels.int().tolist())) - tp: Dict[int, torch.Tensor] = {} - fp: Dict[int, torch.Tensor] = {} P: Dict[int, int] = {} - scores: Dict[int, torch.Tensor] = {} - - ious = self.box_iou(pred_boxes, gt_boxes, cast(torch.BoolTensor, category_is_crowd)) - - NO_MATCH = -3 - ious[:, gt_ignore] -= 2 - category_no_match = labels.expand(len(pred_labels), -1) != labels.view(-1, 1) - ious[category_no_match] = NO_MATCH - ious.unsqueeze(-1).repeat((1, 1, len(self.iou_thresholds))) - ious[ious < self.iou_thresholds] = NO_MATCH - for i in range(len(pred_labels)): - # Flip is done to give priority to the last item with maximal value, as torch.max selects the first one. - match_gts = ious[i].flip(0).max(0) - match_gts_indices = ious.size(1) -1 - match_gts.indices - for t in range(len(self.iou_thresholds)): - if match_gts.values[t] != NO_MATCH and not gt_is_crowd[match_gts_indices[t]]: - ious[:, match_gts_indices[t], t] = NO_MATCH - ious[i, match_gts_indices[t], t] = match_gts.values[t] - for category in categories: category_index_gt = labels == category num_category_gt = category_index_gt.sum() - category_is_crowd = gt_is_crowd[category_index_gt] category_gt_ignore = gt_ignore[category_index_gt] if num_category_gt: # what if P[c] becomes 0 ? P[category] = num_category_gt - category_gt_ignore.sum() - - category_index_dt = pred_labels == category - if not category_index_dt.any(): - continue - - scores[category] = pred_scores[category_index_dt] - category_max_ious = ious[category_index_dt].max(1).values - tp[category] = (category_max_ious >= 0).T.to(dtype=torch.uint8, device=self._device) - fp[category] = ((category_max_ious == NO_MATCH).T and pred_match_area_range[category_index_dt]).to(dtype=torch.uint8, device=self._device) + + if len(pred_labels): + ious = self.box_iou(pred_boxes, gt_boxes, cast(torch.BoolTensor, gt_is_crowd)) + NO_MATCH = -3 + ious[:, gt_ignore] -= 2 + category_no_match = labels.expand(len(pred_labels), -1) != pred_labels.view(-1, 1) + ious[category_no_match] = NO_MATCH + ious.unsqueeze(-1).repeat((1, 1, len(self.iou_thresholds))) + ious[ious < self.iou_thresholds] = NO_MATCH + for i in range(len(pred_labels)): + # Flip is done to give priority to the last item with maximal value, as torch.max selects the first one. 
+ match_gts = ious[i].flip(0).max(0) + match_gts_indices = ious.size(1) -1 - match_gts.indices + for t in range(len(self.iou_thresholds)): + if match_gts.values[t] != NO_MATCH and not gt_is_crowd[match_gts_indices[t]]: + ious[:, match_gts_indices[t], t] = NO_MATCH + ious[i, match_gts_indices[t], t] = match_gts.values[t] + + max_ious = ious.max(1).values + tp = (max_ious >= 0).T.to(dtype=torch.uint8, device=self._device) + fp = ((max_ious == NO_MATCH).T and pred_match_area_range).to(dtype=torch.uint8, device=self._device) + else: + tp = fp = None - return tp, fp, P, scores + return tp, fp, P, pred_scores, pred_labels @reinit__is_reduced def update(self, output: Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, torch.Tensor]]]) -> None: @@ -353,18 +347,19 @@ def update(self, output: Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, tor """ self._check_matching_input(output) for y_pred, y in zip(*output): - tps, fps, ps, scores_dict = self._do_matching(y_pred, y) - for cls in tps: - self._tp[cls].append(tps[cls].to(device=self._device, dtype=torch.uint8)) - self._fp[cls].append(fps[cls].to(device=self._device, dtype=torch.uint8)) - self._scores[cls].append(scores_dict[cls].to(self._device)) + tp, fp, ps, scores, pred_labels = self._do_matching(y_pred, y) + if tp is not None: + self._tps.append(tp.to(device=self._device, dtype=torch.uint8)) + self._fps.append(fp.to(device=self._device, dtype=torch.uint8)) + self._scores.append(scores.to(self._device)) + self._pred_labels.append(pred_labels.to(device=self._device)) for cls in ps: self._P[cls] += ps[cls] - classes = tps.keys() | ps.keys() + classes = set(pred_labels.tolist()) | ps.keys() if classes: self._num_classes = max(max(classes) + 1, self._num_classes) - def _compute(self) -> torch.Tensor: + def _compute(self) -> Union[torch.Tensor, float]: num_classes = int(idist.all_reduce(self._num_classes or 0, "MAX")) if num_classes < 1: return 0.0 @@ -374,54 +369,32 @@ def _compute(self) -> torch.Tensor: idist.all_reduce(torch.tensor(list(map(self._P.__getitem__, range(num_classes))), device=self._device)), ) if P.sum() < 1: - return -1 + return -1. 
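The flip-then-max sequence in the matching loop above is only there to break ties toward the last ground-truth box with the maximal IoU, because `torch.max` reports the first argmax. A tiny standalone check of that index arithmetic (the IoU row below is illustrative):

import torch

ious_row = torch.tensor([0.3, 0.7, 0.7, 0.1])

first = ious_row.max(0)                              # picks index 1
flipped = ious_row.flip(0).max(0)
last_index = ious_row.size(0) - 1 - flipped.indices  # picks index 2 instead

print(first.indices.item(), last_index.item())  # 1 2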
- num_preds = torch.tensor( - [sum([tp.shape[-1] for tp in self._tp[cls]]) if self._tp[cls] else 0 for cls in range(num_classes)], - device=self._device, - ) - num_preds_per_class_across_ranks = torch.stack( - cast(torch.Tensor, idist.all_gather(num_preds)).split(split_size=num_classes) - ) - if num_preds_per_class_across_ranks.sum() == 0: - return 0.0 - a_nonempty_rank, a_nonempty_class = list(zip(*torch.where(num_preds_per_class_across_ranks != 0))).pop(0) - a_nonempty_rank = a_nonempty_rank.item() - a_nonempty_class = a_nonempty_class.item() - mean_dimensions_shape = cast( - torch.Tensor, - idist.broadcast( - torch.tensor(self._tp[a_nonempty_class][-1].shape[:-1], device=self._device) - if idist.get_rank() == a_nonempty_rank - else None, - a_nonempty_rank, - safe_mode=True, - ), - ).tolist() + pred_labels = _cat_and_agg_tensors(self._pred_labels, (), torch.long, self._device) + TP = _cat_and_agg_tensors(self._tps, (len(self.iou_thresholds),), torch.uint8, self._device) + FP = _cat_and_agg_tensors(self._fps, (len(self.iou_thresholds),), torch.uint8, self._device) + scores = _cat_and_agg_tensors(self._scores, (), torch.double, self._device) average_precisions_recalls = -torch.ones( - (2, num_classes, *mean_dimensions_shape), + (2, num_classes, len(self.iou_thresholds)), device=self._device, dtype=torch.double, ) for cls in range(num_classes): if P[cls] == 0: continue - - num_preds_across_ranks = num_preds_per_class_across_ranks[:, cls].tolist() - if sum(num_preds_across_ranks) == 0: - average_precisions_recalls[0, cls] = 0 + + cls_labels = pred_labels == cls + if sum(cls_labels) == 0: + average_precisions_recalls[0, cls] = 0. continue - TP = _cat_and_agg_tensors(self._tp[cls], mean_dimensions_shape, num_preds_across_ranks, torch.uint8, self._device) - FP = _cat_and_agg_tensors(self._fp[cls], mean_dimensions_shape, num_preds_across_ranks, torch.uint8, self._device) - scores = _cat_and_agg_tensors(self._scores[cls], (), num_preds_across_ranks, torch.double, self._device) - recall, precision = self._compute_recall_and_precision(TP, FP, scores, P[cls]) + recall, precision = self._compute_recall_and_precision(TP[..., cls_labels], FP[..., cls_labels], scores[cls_labels], P[cls]) average_precision_for_cls_across_other_dims = self._compute_average_precision(recall, precision) average_precisions_recalls[0, cls] = average_precision_for_cls_across_other_dims average_precisions_recalls[1, cls] = recall[..., -1] - return average_precisions_recalls def compute(self) -> Tuple[float, float]: diff --git a/tests/ignite/metrics/test_mean_average_precision.py b/tests/ignite/metrics/test_mean_average_precision.py index 3125aeb13e0..d66ac5bcd89 100644 --- a/tests/ignite/metrics/test_mean_average_precision.py +++ b/tests/ignite/metrics/test_mean_average_precision.py @@ -18,18 +18,12 @@ def test_wrong_input(): with pytest.raises(TypeError, match="rec_thresholds should be a sequence of floats or a tensor"): MeanAveragePrecision(rec_thresholds={0, 0.2, 0.4, 0.6, 0.8}) - with pytest.raises(ValueError, match="Wrong `average` parameter"): - MeanAveragePrecision(average=1) - with pytest.raises(ValueError, match="Wrong `class_mean` parameter"): MeanAveragePrecision(class_mean="samples") with pytest.raises(ValueError, match="rec_thresholds values should be between 0 and 1"): MeanAveragePrecision(rec_thresholds=(0.0, 0.5, 1.0, 1.5)) - with pytest.raises(ValueError, match="class_mean 'with_other_dims' is not compatible with this class"): - MeanAveragePrecision(class_mean="with_other_dims") - metric = 
MeanAveragePrecision() with pytest.raises(RuntimeError, match="Metric could not be computed without any update method call"): metric.compute() diff --git a/tests/ignite/metrics/vision/test_object_detection_map.py b/tests/ignite/metrics/vision/test_object_detection_map.py index 5b8c9ebea0c..bdf88974f95 100644 --- a/tests/ignite/metrics/vision/test_object_detection_map.py +++ b/tests/ignite/metrics/vision/test_object_detection_map.py @@ -17,7 +17,7 @@ import ignite.distributed as idist from ignite.engine import Engine -from ignite.metrics import ObjectDetectionMAP +from ignite.metrics import ObjectDetectionAvgPrecisionRecall from ignite.metrics.vision.object_detection_average_precision_recall import tensor_list_to_dict_list from ignite.utils import manual_seed @@ -533,14 +533,14 @@ def create_coco_api( def pycoco_mAP(predictions: List[Dict[str, torch.Tensor]], targets: List[Dict[str, torch.Tensor]]) -> np.array: """ - Returned values are AP@.5...95, AP@.5, AP@.75, AP-S, AP-M, AP-L + Returned values are AP@.5...95, AP@.5, AP@.75, AP-S, AP-M, AP-L, AR-1, AR-10, AR-100, AR-S, AR-M, AR-L """ coco_dt, coco_gt = create_coco_api(predictions, targets) eval = COCOeval(coco_gt, coco_dt, iouType="bbox") eval.evaluate() eval.accumulate() eval.summarize() - return eval.stats[:6] + return eval.stats Sample = namedtuple("Sample", ["data", "mAP", "length"]) @@ -623,7 +623,7 @@ def sample(request) -> Sample: def test_wrong_input(): - m = ObjectDetectionMAP() + m = ObjectDetectionAvgPrecisionRecall() with pytest.raises(ValueError, match="y_pred and y should have the same number of samples"): m.update(([{"bbox": None, "scores": None}], [])) @@ -640,15 +640,15 @@ def test_empty_data(): Note that PyCOCO returns -1 when threre's no ground truth data. """ - metric = ObjectDetectionMAP() + metric = ObjectDetectionAvgPrecisionRecall() metric.update( ( [{"bbox": torch.zeros((0, 4)), "scores": torch.zeros((0,)), "labels": torch.zeros((0))}], [{"bbox": torch.zeros((0, 4)), "iscrowd": torch.zeros((0,)), "labels": torch.zeros((0))}], ) ) - assert len(metric._tp) == 0 - assert len(metric._fp) == 0 + assert len(metric._tps) == 0 + assert len(metric._fps) == 0 assert len(metric._P) == 0 assert metric._num_classes == 0 assert metric.compute() == 0 @@ -666,7 +666,7 @@ def test_empty_data(): ) assert metric.compute() == -1 - metric = ObjectDetectionMAP() + metric = ObjectDetectionAvgPrecisionRecall() metric.update( ( [{"bbox": torch.zeros((0, 4)), "scores": torch.zeros((0,)), "labels": torch.zeros((0))}], @@ -679,13 +679,13 @@ def test_empty_data(): ], ) ) - assert len(metric._tp) == 0 - assert len(metric._fp) == 0 + assert len(metric._tps) == 0 + assert len(metric._fps) == 0 assert len(metric._P) == 1 and metric._P[1] == 1 assert metric._num_classes == 2 assert metric.compute() == 0 - metric = ObjectDetectionMAP() + metric = ObjectDetectionAvgPrecisionRecall() pred = { "bbox": torch.tensor([[0.0, 0.0, 100.0, 100.0]]), "scores": torch.tensor([0.9]), @@ -693,8 +693,7 @@ def test_empty_data(): } target = {"bbox": torch.zeros((0, 4)), "iscrowd": torch.zeros((0,)), "labels": torch.zeros((0))} metric.update(([pred], [target])) - assert (5 in metric._tp) and metric._tp[5][0].shape[1] == 1 - assert (5 in metric._fp) and metric._fp[5][0].shape[1] == 1 + assert len(metric._tps) == len(metric._fps) == 1 assert len(metric._P) == 0 assert metric._num_classes == 6 assert metric.compute() == pycoco_mAP([pred], [target])[0] @@ -703,11 +702,11 @@ def test_empty_data(): def test_no_torchvision(): with patch.dict(sys.modules, 
{"torchvision.ops.boxes": None}): with pytest.raises(ModuleNotFoundError, match=r"This metric requires torchvision to be installed."): - ObjectDetectionMAP() + ObjectDetectionAvgPrecisionRecall() def test_iou(sample): - m = ObjectDetectionMAP() + m = ObjectDetectionAvgPrecisionRecall() from pycocotools.mask import iou as pycoco_iou for pred, tgt in zip(*sample.data): @@ -728,7 +727,7 @@ def test_iou(sample): def test_iou_thresholding(): - metric = ObjectDetectionMAP(iou_thresholds=[0.0, 0.3, 0.5, 0.75]) + metric = ObjectDetectionAvgPrecisionRecall(iou_thresholds=[0.0, 0.3, 0.5, 0.75]) pred = { "bbox": torch.tensor([[0.0, 0.0, 100.0, 100.0]]), @@ -750,7 +749,7 @@ def test_iou_thresholding(): def test__do_matching_output(sample): - metric = ObjectDetectionMAP() + metric = ObjectDetectionAvgPrecisionRecall() for pred, target in zip(*sample.data): tps, fps, _, scores = metric._do_matching(pred, target) @@ -776,7 +775,7 @@ def test__do_matching_output(sample): assert metric_tp_or_fp[cls][-1].shape[:-1] == new_tp_or_fp[cls].shape[:-1] -class Dummy_mAP(ObjectDetectionMAP): +class Dummy_mAP(ObjectDetectionAvgPrecisionRecall): def _do_matching(self, tup1: Tuple, tup2: Tuple): tp, fp = tup1 p, score = tup2 @@ -819,7 +818,7 @@ def test_matching(): matched with a prediction in the sense that even if the crowd ground truth has a higher IOU, the non-crowd one gets matched if its IOU is viable. """ - metric = ObjectDetectionMAP(iou_thresholds=[0.2]) + metric = ObjectDetectionAvgPrecisionRecall(iou_thresholds=[0.2]) pred = { "bbox": torch.tensor([[0.0, 0.0, 100.0, 100.0], [0.0, 0.0, 100.0, 100.0]]), @@ -891,7 +890,7 @@ def test__compute_recall_and_precision(): # The case in which detector detects all gt objects but also produces some wrong predictions. scores = torch.rand((50,)) y_true = torch.randint(0, 2, (50,)) - m = ObjectDetectionMAP() + m = ObjectDetectionAvgPrecisionRecall() ignite_recall, ignite_precision = m._compute_recall_and_precision( y_true.bool(), ~(y_true.bool()), scores, y_true.sum() @@ -925,12 +924,12 @@ def test__compute_recall_and_precision(): def test_compute(sample): device = idist.device() - metric_50_95 = ObjectDetectionMAP(device=device) - metric_50 = ObjectDetectionMAP(iou_thresholds=[0.5], device=device) - metric_75 = ObjectDetectionMAP(iou_thresholds=[0.75], device=device) - metric_S = ObjectDetectionMAP(device=device, area_range="small") - metric_M = ObjectDetectionMAP(device=device, area_range="medium") - metric_L = ObjectDetectionMAP(device=device, area_range="large") + metric_50_95 = ObjectDetectionAvgPrecisionRecall(device=device) + metric_50 = ObjectDetectionAvgPrecisionRecall(iou_thresholds=[0.5], device=device) + metric_75 = ObjectDetectionAvgPrecisionRecall(iou_thresholds=[0.75], device=device) + metric_S = ObjectDetectionAvgPrecisionRecall(device=device, area_range="small") + metric_M = ObjectDetectionAvgPrecisionRecall(device=device, area_range="medium") + metric_L = ObjectDetectionAvgPrecisionRecall(device=device, area_range="large") metrics = [metric_50_95, metric_50, metric_75, metric_S, metric_M, metric_L] @@ -956,7 +955,7 @@ def update(engine, i): device = idist.device() metric_device = "cpu" if device.type == "xla" else device - metric_50_95 = ObjectDetectionMAP(device=metric_device) + metric_50_95 = ObjectDetectionAvgPrecisionRecall(device=metric_device) metric_50_95.attach(engine, name="mAP[50-95]") n_iter = ceil(sample.length / bs) @@ -1017,12 +1016,12 @@ def test_distrib_update_compute(distributed, sample): device = idist.device() metric_device = "cpu" if 
device.type == "xla" else device - metric_50_95 = ObjectDetectionMAP(device=metric_device) - metric_50 = ObjectDetectionMAP(iou_thresholds=[0.5], device=metric_device) - metric_75 = ObjectDetectionMAP(iou_thresholds=[0.75], device=metric_device) - metric_S = ObjectDetectionMAP(device=metric_device, area_range="small") - metric_M = ObjectDetectionMAP(device=metric_device, area_range="medium") - metric_L = ObjectDetectionMAP(device=metric_device, area_range="large") + metric_50_95 = ObjectDetectionAvgPrecisionRecall(device=metric_device) + metric_50 = ObjectDetectionAvgPrecisionRecall(iou_thresholds=[0.5], device=metric_device) + metric_75 = ObjectDetectionAvgPrecisionRecall(iou_thresholds=[0.75], device=metric_device) + metric_S = ObjectDetectionAvgPrecisionRecall(device=metric_device, area_range="small") + metric_M = ObjectDetectionAvgPrecisionRecall(device=metric_device, area_range="medium") + metric_L = ObjectDetectionAvgPrecisionRecall(device=metric_device, area_range="large") metrics = [metric_50_95, metric_50, metric_75, metric_S, metric_M, metric_L] From 1593dfbb6d37c86bc170fe38ed8bf3b892d34b21 Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Fri, 12 Jul 2024 14:25:29 +0330 Subject: [PATCH 19/41] Some improvements --- ignite/metrics/mean_average_precision.py | 2 +- ...ject_detection_average_precision_recall.py | 194 ++++++------------ .../vision/test_object_detection_map.py | 164 +++++++-------- 3 files changed, 147 insertions(+), 213 deletions(-) diff --git a/ignite/metrics/mean_average_precision.py b/ignite/metrics/mean_average_precision.py index 84e10329af2..35eb66e76ba 100644 --- a/ignite/metrics/mean_average_precision.py +++ b/ignite/metrics/mean_average_precision.py @@ -71,7 +71,7 @@ def __init__( else: self.rec_thresholds = None - if class_mean is not None and class_mean not in ("micro", "macro", "weighted", "with_other_dims"): + if class_mean is not None and class_mean not in ("micro", "macro", "weighted"): raise ValueError(f"Wrong `class_mean` parameter, given {class_mean}") self.class_mean = class_mean diff --git a/ignite/metrics/vision/object_detection_average_precision_recall.py b/ignite/metrics/vision/object_detection_average_precision_recall.py index 8a5abd106b1..e6abe776e8e 100644 --- a/ignite/metrics/vision/object_detection_average_precision_recall.py +++ b/ignite/metrics/vision/object_detection_average_precision_recall.py @@ -1,4 +1,3 @@ -from collections import defaultdict from typing import Callable, cast, Dict, List, Optional, Sequence, Tuple, Union, override import torch @@ -8,7 +7,6 @@ from ignite.metrics.mean_average_precision import _BaseAveragePrecision, _cat_and_agg_tensors from ignite.metrics.metric import Metric, reinit__is_reduced -from ignite.metrics.recall import _BasePrecisionRecall def tensor_list_to_dict_list( @@ -48,13 +46,14 @@ class ObjectDetectionAvgPrecisionRecall(Metric, _BaseAveragePrecision): _fps: List[torch.Tensor] _scores: List[torch.Tensor] _pred_labels: List[torch.Tensor] - _P: Dict[int, int] + _P: torch.Tensor _num_classes: int def __init__( self, iou_thresholds: Optional[Union[Sequence[float], torch.Tensor]] = None, rec_thresholds: Optional[Union[Sequence[float], torch.Tensor]] = None, + num_classes: Optional[int] = 91, max_detections_per_image: Optional[int] = 100, area_range: Optional[Literal["small", "medium", "large", "all"]] = "all", output_transform: Callable = lambda x: x, @@ -102,13 +101,14 @@ def box_iou(pred_boxes: torch.Tensor, gt_boxes: torch.Tensor, iscrowd: torch.Boo if iou_thresholds is None: iou_thresholds = 
torch.linspace(0.5, 0.95, 10, dtype=torch.double) - self.iou_thresholds = self._setup_thresholds(iou_thresholds, "iou_thresholds") + self._iou_thresholds = self._setup_thresholds(iou_thresholds, "iou_thresholds") if rec_thresholds is None: rec_thresholds = torch.linspace(0, 1, 101, device=device, dtype=torch.double) - self.area_range = area_range - self.max_detections_per_image = max_detections_per_image + self._num_classes = num_classes + self._area_range = area_range + self._max_detections_per_image = max_detections_per_image super(ObjectDetectionAvgPrecisionRecall, self).__init__( output_transform=output_transform, @@ -125,23 +125,22 @@ def reset(self) -> None: self._fps = [] self._scores = [] self._pred_labels = [] - self._P = defaultdict(lambda: 0) - self._num_classes: int = 0 + self._P = torch.zeros((self._num_classes,), device=self._device) def _match_area_range(self, bboxes: torch.Tensor) -> torch.Tensor: from torchvision.ops.boxes import box_area areas = box_area(bboxes) - if self.area_range == "all": + if self._area_range == "all": min_area = 0 max_area = 1e10 - elif self.area_range == "small": + elif self._area_range == "small": min_area = 0 max_area = 1024 - elif self.area_range == "medium": + elif self._area_range == "medium": min_area = 1024 max_area = 9216 - elif self.area_range == "large": + elif self._area_range == "large": min_area = 9216 max_area = 1e10 return torch.logical_and(areas >= min_area, areas <= max_area) @@ -225,93 +224,12 @@ def _compute_average_precision(self, recall: torch.Tensor, precision: torch.Tens """ precision_integrand = precision.flip(-1).cummax(dim=-1).values.flip(-1) rec_thresholds = cast(torch.Tensor, self.rec_thresholds).repeat((*recall.shape[:-1], 1)) - rec_thresh_indices = torch.searchsorted(recall, rec_thresholds) + rec_thresh_indices = torch.searchsorted(recall, rec_thresholds) if recall.size(-1) != 0 else torch.LongTensor([], device=self._device) precision_integrand = precision_integrand.take_along_dim( rec_thresh_indices.where(rec_thresh_indices != recall.size(-1), 0), dim=-1 ).where(rec_thresh_indices != recall.size(-1), 0) return torch.sum(precision_integrand, dim=-1) / len(cast(torch.Tensor, self.rec_thresholds)) - def _do_matching( - self, pred: Dict[str, torch.Tensor], target: Dict[str, torch.Tensor] - ) -> Tuple[Union[torch.Tensor,None], Union[torch.Tensor,None], Dict[int, int], torch.Tensor, torch.Tensor]: - r""" - Matching logic of object detection mAP, according to COCO reference implementation. - - The method returns a quadrople of dictionaries containing TP, FP, P (actual positive) counts and scores for - each class respectively. Please note that class numbers start from zero. - - Values in TP and FP are (m+1)-dimensional tensors of type ``uint8`` and shape - (D\ :sub:`1`, D\ :sub:`2`, ..., D\ :sub:`m`, n\ :sub:`cls`) in which D\ :sub:`i`\ 's are possible additional - dimensions (excluding the class dimension) mean of the average precision is taken over. n\ :sub:`cls` is the - number of predictions for class `cls` which is the same for TP and FP. - - Note: - TP and FP values are stored as uint8 tensors internally to avoid bool-to-uint8 copies before collective - operations, as PyTorch colective operations `do not `_ - support boolean tensors, at least on Gloo backend. - - P counts contains the number of ground truth samples for each class. Finally, the values in scores are 1-dim - tensors of shape (n\ :sub:`cls`,) containing score or confidence of the predictions (doesn't need to be in - [0,1]). 
If there is no prediction or ground truth for a class, it is absent from (TP, FP, scores) and P - dictionaries respectively. - - Args: - pred: First member of :meth:`update`'s input is given as this argument. - target: Second member of :meth:`update`'s input is given as this argument. - - Returns: - `(TP, FP, P, scores)` A quadrople of true positives, false positives, number of actual positives and scores. - """ - labels = target["labels"] - gt_boxes = target["bbox"] - gt_is_crowd = ( - target["iscrowd"].bool() if "iscrowd" in target else torch.zeros_like(target["labels"], dtype=torch.bool) - ) - gt_ignore = ~self._match_area_range(gt_boxes) | gt_is_crowd - - best_detections_index = torch.argsort(pred["scores"], stable=True, descending=True)[ - : self.max_detections_per_image - ] - pred_scores = pred["scores"][best_detections_index] - pred_labels = pred["labels"][best_detections_index] - pred_boxes = pred["bbox"][best_detections_index] - pred_match_area_range = self._match_area_range(pred_boxes) - - categories = list(set(labels.int().tolist() + pred_labels.int().tolist())) - - P: Dict[int, int] = {} - for category in categories: - category_index_gt = labels == category - num_category_gt = category_index_gt.sum() - category_gt_ignore = gt_ignore[category_index_gt] - if num_category_gt: # what if P[c] becomes 0 ? - P[category] = num_category_gt - category_gt_ignore.sum() - - if len(pred_labels): - ious = self.box_iou(pred_boxes, gt_boxes, cast(torch.BoolTensor, gt_is_crowd)) - NO_MATCH = -3 - ious[:, gt_ignore] -= 2 - category_no_match = labels.expand(len(pred_labels), -1) != pred_labels.view(-1, 1) - ious[category_no_match] = NO_MATCH - ious.unsqueeze(-1).repeat((1, 1, len(self.iou_thresholds))) - ious[ious < self.iou_thresholds] = NO_MATCH - for i in range(len(pred_labels)): - # Flip is done to give priority to the last item with maximal value, as torch.max selects the first one. - match_gts = ious[i].flip(0).max(0) - match_gts_indices = ious.size(1) -1 - match_gts.indices - for t in range(len(self.iou_thresholds)): - if match_gts.values[t] != NO_MATCH and not gt_is_crowd[match_gts_indices[t]]: - ious[:, match_gts_indices[t], t] = NO_MATCH - ious[i, match_gts_indices[t], t] = match_gts.values[t] - - max_ious = ious.max(1).values - tp = (max_ious >= 0).T.to(dtype=torch.uint8, device=self._device) - fp = ((max_ious == NO_MATCH).T and pred_match_area_range).to(dtype=torch.uint8, device=self._device) - else: - tp = fp = None - - return tp, fp, P, pred_scores, pred_labels - @reinit__is_reduced def update(self, output: Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, torch.Tensor]]]) -> None: r"""Metric update function using prediction and target. 
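The `_match_area_range` helper kept in the previous hunk maps plain box areas onto COCO's size buckets: up to 32**2 = 1024 counts as small, up to 96**2 = 9216 as medium, anything larger as large. A short sketch of that bucketing with made-up boxes (`box_area` expects `(x1, y1, x2, y2)` coordinates):

import torch
from torchvision.ops.boxes import box_area

boxes = torch.tensor(
    [
        [0.0, 0.0, 10.0, 10.0],    # area 100   -> small
        [0.0, 0.0, 50.0, 50.0],    # area 2500  -> medium
        [0.0, 0.0, 200.0, 200.0],  # area 40000 -> large
    ]
)
areas = box_area(boxes)
is_small = (areas >= 0) & (areas <= 1024)
is_medium = (areas >= 1024) & (areas <= 9216)
print(areas.tolist(), is_small.tolist(), is_medium.tolist())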
@@ -346,62 +264,86 @@ def update(self, output: Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, tor ========= ================= ================================================= """ self._check_matching_input(output) - for y_pred, y in zip(*output): - tp, fp, ps, scores, pred_labels = self._do_matching(y_pred, y) - if tp is not None: - self._tps.append(tp.to(device=self._device, dtype=torch.uint8)) - self._fps.append(fp.to(device=self._device, dtype=torch.uint8)) - self._scores.append(scores.to(self._device)) + for pred, target in zip(*output): + labels = target["labels"] + gt_boxes = target["bbox"] + gt_is_crowd = ( + target["iscrowd"].bool() if "iscrowd" in target else torch.zeros_like(labels, dtype=torch.bool) + ) + gt_ignore = ~self._match_area_range(gt_boxes) | gt_is_crowd + self._P += torch.bincount(labels[~gt_ignore], minlength=self._num_classes).to(device=self._device) + + best_detections_index = torch.argsort(pred["scores"], stable=True, descending=True)[ + : self._max_detections_per_image + ] + print(best_detections_index) + pred_scores = pred["scores"][best_detections_index] + pred_labels = pred["labels"][best_detections_index] + pred_boxes = pred["bbox"][best_detections_index] + # Matching logic of object detection mAP, according to COCO reference implementation. + if len(pred_labels): + if not len(labels): + tp = torch.zeros((len(self._iou_thresholds), len(pred_labels)), dtype=torch.uint8, device=self._device) + self._tps.append(tp) + self._fps.append(~tp & self._match_area_range(pred_boxes)) + else: + ious = self.box_iou(pred_boxes, gt_boxes, cast(torch.BoolTensor, gt_is_crowd)) + category_no_match = labels.expand(len(pred_labels), -1) != pred_labels.view(-1, 1) + NO_MATCH = -3 + ious[category_no_match] = NO_MATCH + ious = ious.unsqueeze(-1).repeat((1, 1, len(self._iou_thresholds))) + ious[ious < self._iou_thresholds] = NO_MATCH + IGNORANCE = -2 + ious[:, gt_ignore] += IGNORANCE + for i in range(len(pred_labels)): + # Flip is done to give priority to the last item with maximal value, as torch.max selects the first one. + match_gts = ious[i].flip(0).max(0) + match_gts_indices = ious.size(1) -1 - match_gts.indices + for t in range(len(self._iou_thresholds)): + if match_gts.values[t] > NO_MATCH and not gt_is_crowd[match_gts_indices[t]]: + ious[:, match_gts_indices[t], t] = NO_MATCH + ious[i, match_gts_indices[t], t] = match_gts.values[t] + + max_ious = ious.max(1).values + self._tps.append((max_ious >= 0).T.to(dtype=torch.uint8, device=self._device)) + self._fps.append(((max_ious <= NO_MATCH).T & self._match_area_range(pred_boxes)).to(dtype=torch.uint8, device=self._device)) + + self._scores.append(pred_scores.to(self._device)) self._pred_labels.append(pred_labels.to(device=self._device)) - for cls in ps: - self._P[cls] += ps[cls] - classes = set(pred_labels.tolist()) | ps.keys() - if classes: - self._num_classes = max(max(classes) + 1, self._num_classes) - - def _compute(self) -> Union[torch.Tensor, float]: - num_classes = int(idist.all_reduce(self._num_classes or 0, "MAX")) - if num_classes < 1: - return 0.0 - - P = cast( - torch.Tensor, - idist.all_reduce(torch.tensor(list(map(self._P.__getitem__, range(num_classes))), device=self._device)), - ) - if P.sum() < 1: - return -1. 
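The rewritten `update` above drops the per-class dictionaries and counts actual positives with a single `torch.bincount` over the non-ignored ground-truth labels. A quick illustration of that counting step with arbitrary label values:

import torch

num_classes = 7
labels = torch.tensor([1, 1, 3, 5])
gt_ignore = torch.tensor([False, True, False, False])  # crowd or out-of-area-range boxes are ignored

P = torch.bincount(labels[~gt_ignore], minlength=num_classes)
print(P.tolist())  # [0, 1, 0, 1, 0, 1, 0]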
+ def _compute(self) -> torch.Tensor: + P = idist.all_reduce(self._P) pred_labels = _cat_and_agg_tensors(self._pred_labels, (), torch.long, self._device) - TP = _cat_and_agg_tensors(self._tps, (len(self.iou_thresholds),), torch.uint8, self._device) - FP = _cat_and_agg_tensors(self._fps, (len(self.iou_thresholds),), torch.uint8, self._device) + TP = _cat_and_agg_tensors(self._tps, (len(self._iou_thresholds),), torch.uint8, self._device) + FP = _cat_and_agg_tensors(self._fps, (len(self._iou_thresholds),), torch.uint8, self._device) scores = _cat_and_agg_tensors(self._scores, (), torch.double, self._device) average_precisions_recalls = -torch.ones( - (2, num_classes, len(self.iou_thresholds)), + (2, self._num_classes, len(self._iou_thresholds)), device=self._device, dtype=torch.double, ) - for cls in range(num_classes): + for cls in range(self._num_classes): if P[cls] == 0: continue cls_labels = pred_labels == cls if sum(cls_labels) == 0: - average_precisions_recalls[0, cls] = 0. + average_precisions_recalls[:, cls] = 0. continue recall, precision = self._compute_recall_and_precision(TP[..., cls_labels], FP[..., cls_labels], scores[cls_labels], P[cls]) - average_precision_for_cls_across_other_dims = self._compute_average_precision(recall, precision) - average_precisions_recalls[0, cls] = average_precision_for_cls_across_other_dims - + average_precision_for_cls_per_iou_threshold = self._compute_average_precision(recall, precision) + average_precisions_recalls[0, cls] = average_precision_for_cls_per_iou_threshold average_precisions_recalls[1, cls] = recall[..., -1] return average_precisions_recalls def compute(self) -> Tuple[float, float]: average_precisions_recalls = self._compute() + if (average_precisions_recalls == -1).all(): + return -1., -1. ap = average_precisions_recalls[0][average_precisions_recalls[0] > -1].mean().item() ar = average_precisions_recalls[1][average_precisions_recalls[1] > -1].mean().item() - return ap, ar diff --git a/tests/ignite/metrics/vision/test_object_detection_map.py b/tests/ignite/metrics/vision/test_object_detection_map.py index bdf88974f95..c776ac1396c 100644 --- a/tests/ignite/metrics/vision/test_object_detection_map.py +++ b/tests/ignite/metrics/vision/test_object_detection_map.py @@ -1,4 +1,5 @@ import sys +import itertools from collections import namedtuple from math import ceil from typing import Any, Dict, List, Tuple @@ -406,14 +407,14 @@ def coco_val2017_sample() -> Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, ] return [{"bbox": p[:, :4].double(), "scores": p[:, 4].double(), "labels": p[:, 5]} for p in pred], [ - {"bbox": g[:, :4].double(), "labels": g[:, 4], "iscrowd": g[:, 5]} for g in gt + {"bbox": g[:, :4].double(), "labels": g[:, 4].long(), "iscrowd": g[:, 5]} for g in gt ] def random_sample() -> Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, torch.Tensor]]]: torch.manual_seed(12) - targets = [] - preds = [] + targets: List[torch.Tensor] = [] + preds: List[torch.Tensor] = [] for _ in range(30): # Generate some ground truth boxes n_gt_box = torch.randint(50, (1,)).item() @@ -468,7 +469,7 @@ def random_sample() -> Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, torch preds.append(torch.cat((perturbed_gt_boxes, additional_pred_boxes), dim=0)) return [{"bbox": p[:, :4], "scores": p[:, 4], "labels": p[:, 5]} for p in preds], [ - {"bbox": g[:, :4], "labels": g[:, 4], "iscrowd": g[:, 5]} for g in targets + {"bbox": g[:, :4], "labels": g[:, 4].long(), "iscrowd": g[:, 5]} for g in targets ] @@ -567,14 +568,14 @@ def sample(request) -> 
Sample: 0, ), "labels": torch.zeros( - 0, + 0, dtype=torch.long ), } elif request.param[1] == "with_an_empty_gt": data[1][0] = { "bbox": torch.zeros(0, 4), "labels": torch.zeros( - 0, + 0, dtype=torch.long ), "iscrowd": torch.zeros( 0, @@ -587,7 +588,7 @@ def sample(request) -> Sample: 0, ), "labels": torch.zeros( - 0, + 0, dtype=torch.long ), } data[0][1] = { @@ -596,13 +597,13 @@ def sample(request) -> Sample: 0, ), "labels": torch.zeros( - 0, + 0, dtype=torch.long ), } data[1][0] = { "bbox": torch.zeros(0, 4), "labels": torch.zeros( - 0, + 0, dtype=torch.long ), "iscrowd": torch.zeros( 0, @@ -611,7 +612,7 @@ def sample(request) -> Sample: data[1][2] = { "bbox": torch.zeros(0, 4), "labels": torch.zeros( - 0, + 0, dtype=torch.long ), "iscrowd": torch.zeros( 0, @@ -643,47 +644,44 @@ def test_empty_data(): metric = ObjectDetectionAvgPrecisionRecall() metric.update( ( - [{"bbox": torch.zeros((0, 4)), "scores": torch.zeros((0,)), "labels": torch.zeros((0))}], - [{"bbox": torch.zeros((0, 4)), "iscrowd": torch.zeros((0,)), "labels": torch.zeros((0))}], + [{"bbox": torch.zeros((0, 4)), "scores": torch.zeros((0,)), "labels": torch.zeros((0), dtype=torch.long)}], + [{"bbox": torch.zeros((0, 4)), "iscrowd": torch.zeros((0,)), "labels": torch.zeros((0), dtype=torch.long)}], ) ) assert len(metric._tps) == 0 assert len(metric._fps) == 0 - assert len(metric._P) == 0 - assert metric._num_classes == 0 - assert metric.compute() == 0 + assert metric.compute() == (-1, -1) metric.update( ( [ { "bbox": torch.tensor([[0.0, 0.0, 100.0, 100.0]]), "scores": torch.ones((1,)), - "labels": torch.ones((1,)), + "labels": torch.ones((1,), dtype=torch.long), } ], - [{"bbox": torch.zeros((0, 4)), "iscrowd": torch.zeros((0,)), "labels": torch.zeros((0))}], + [{"bbox": torch.zeros((0, 4)), "iscrowd": torch.zeros((0,)), "labels": torch.zeros((0), dtype=torch.long)}], ) ) - assert metric.compute() == -1 + assert metric.compute() == (-1, -1) metric = ObjectDetectionAvgPrecisionRecall() metric.update( ( - [{"bbox": torch.zeros((0, 4)), "scores": torch.zeros((0,)), "labels": torch.zeros((0))}], + [{"bbox": torch.zeros((0, 4)), "scores": torch.zeros((0,)), "labels": torch.zeros((0), dtype=torch.long)}], [ { "bbox": torch.tensor([[0.0, 0.0, 100.0, 100.0]]), "iscrowd": torch.zeros((1,)), - "labels": torch.ones((1,)), + "labels": torch.ones((1,), dtype=torch.long), } ], ) ) assert len(metric._tps) == 0 assert len(metric._fps) == 0 - assert len(metric._P) == 1 and metric._P[1] == 1 - assert metric._num_classes == 2 - assert metric.compute() == 0 + assert metric._P[1] == 1 + assert metric.compute() == (0, 0) metric = ObjectDetectionAvgPrecisionRecall() pred = { @@ -691,12 +689,11 @@ def test_empty_data(): "scores": torch.tensor([0.9]), "labels": torch.tensor([5]), } - target = {"bbox": torch.zeros((0, 4)), "iscrowd": torch.zeros((0,)), "labels": torch.zeros((0))} + target = {"bbox": torch.zeros((0, 4)), "iscrowd": torch.zeros((0,)), "labels": torch.zeros((0), dtype=torch.long)} metric.update(([pred], [target])) assert len(metric._tps) == len(metric._fps) == 1 - assert len(metric._P) == 0 - assert metric._num_classes == 6 - assert metric.compute() == pycoco_mAP([pred], [target])[0] + pycoco_result = pycoco_mAP([pred], [target]) + assert metric.compute() == (pycoco_result[0], pycoco_result[8]) def test_no_torchvision(): @@ -736,7 +733,7 @@ def test_iou_thresholding(): } gt = {"bbox": torch.tensor([[0.0, 0.0, 50.0, 100.0]]), "iscrowd": torch.zeros((1,)), "labels": torch.tensor([1])} metric.update(([pred], [gt])) - assert 
(metric._tp[1][0] == torch.tensor([[True], [True], [True], [False]])).all() + assert (metric._tps[0] == torch.tensor([[True], [True], [True], [False]])).all() pred = { "bbox": torch.tensor([[0.0, 0.0, 100.0, 100.0]]), @@ -745,34 +742,7 @@ def test_iou_thresholding(): } gt = {"bbox": torch.tensor([[100.0, 0.0, 200.0, 100.0]]), "iscrowd": torch.zeros((1,)), "labels": torch.tensor([1])} metric.update(([pred], [gt])) - assert (metric._tp[1][1] == torch.tensor([[True], [False], [False], [False]])).all() - - -def test__do_matching_output(sample): - metric = ObjectDetectionAvgPrecisionRecall() - - for pred, target in zip(*sample.data): - tps, fps, _, scores = metric._do_matching(pred, target) - - assert tps.keys() == fps.keys() == scores.keys() - - try: - cls = list(tps.keys()).pop() - except IndexError: # No prediction - pass - else: - assert tps[cls].dtype in (torch.bool, torch.uint8) - assert tps[cls].size(-1) == fps[cls].size(-1) == scores[cls].size(0) - - metric.update(([pred], [target])) - - for metric_tp_or_fp, new_tp_or_fp in [(metric._tp, tps), (metric._fp, fps)]: - try: - cls = (metric_tp_or_fp.keys() & new_tp_or_fp.keys()).pop() - except KeyError: - pass - else: - assert metric_tp_or_fp[cls][-1].shape[:-1] == new_tp_or_fp[cls].shape[:-1] + assert (metric._tps[1] == torch.tensor([[True], [False], [False], [False]])).all() class Dummy_mAP(ObjectDetectionAvgPrecisionRecall): @@ -807,16 +777,16 @@ def test_matching(): If there's equal confidence in two predictions, the dicision is first made for the one who comes earlier. 2. Each ground truth box is matched with at most one prediction. Crowd ground - truth is the exception. A prediction matched with a crowd gt would get ignored. - 3. Among many plausible ground truth boxes, a prediction is matched with the + truth is the exception. + 3. If a ground truth is crowd or out of area range, is set to be ignored. + 4. A prediction matched with a ignored gt would get ignored, in the sense that it becomes + neither tp nor fp. + 5. An unmatched prediction would get ignored if it's out of area range. So doesn't become fp due to rule 4. + 6. Among many plausible ground truth boxes, a prediction is matched with the one which has the highest mutual IOU. If two ground truth boxes have the same IOU with a prediction, the later one is matched. - 4. A prediction is matched with an out-of-area-range ground truth box only if there's no - plausible within-area-range ground truth box. In that case the prediction would get ignored. - 5. An unmatched prediction would get ignored if it's out of area range. - 6. A non-crowd ground truth has priority over a crowd ground truth in getting - matched with a prediction in the sense that even if the crowd ground truth - has a higher IOU, the non-crowd one gets matched if its IOU is viable. + 7. Non-ignored ground truths are given priority over the ignored ones when matching with a prediction + even if their IOU is lower. 
""" metric = ObjectDetectionAvgPrecisionRecall(iou_thresholds=[0.2]) @@ -831,29 +801,30 @@ def test_matching(): "labels": torch.tensor([1]), } metric.update(([pred], [gt])) - assert (metric._tp[1][0] == torch.tensor([[True, False]])).all() - assert (metric._fp[1][0] == torch.tensor([[False, True]])).all() - assert (metric._scores[1][0] == torch.tensor([[0.9, 0.8]])).all() + # Preds are sorted by their scores internally + assert (metric._tps[0] == torch.tensor([[True, False]])).all() + assert (metric._fps[0] == torch.tensor([[False, True]])).all() + assert (metric._scores[0] == torch.tensor([[0.9, 0.8]])).all() pred["scores"] = torch.tensor([0.9, 0.9]) metric.update(([pred], [gt])) - assert (metric._tp[1][1] == torch.tensor([[True, False]])).all() - assert (metric._fp[1][1] == torch.tensor([[False, True]])).all() + assert (metric._tps[1] == torch.tensor([[True, False]])).all() + assert (metric._fps[1] == torch.tensor([[False, True]])).all() gt["iscrowd"] = torch.tensor([1]) metric.update(([pred], [gt])) - assert (metric._tp[1][2] == torch.tensor([[False, False]])).all() - assert (metric._fp[1][2] == torch.tensor([[False, False]])).all() + assert (metric._tps[2] == torch.tensor([[False, False]])).all() + assert (metric._fps[2] == torch.tensor([[False, False]])).all() pred["bbox"] = torch.tensor([[0.0, 0.0, 100.0, 100.0], [100.0, 0.0, 200.0, 100.0]]) gt["bbox"] = torch.tensor([[0.0, 0.0, 25.0, 50.0], [50.0, 0.0, 150.0, 100.0]]) gt["iscrowd"] = torch.zeros((2,)) gt["labels"] = torch.tensor([1, 1]) metric.update(([pred], [gt])) - assert (metric._tp[1][3] == torch.tensor([[True, False]])).all() - assert (metric._fp[1][3] == torch.tensor([[False, True]])).all() + assert (metric._tps[3] == torch.tensor([[True, False]])).all() + assert (metric._fps[3] == torch.tensor([[False, True]])).all() - metric.area_range = "small" + metric._area_range = "small" pred["bbox"] = torch.tensor( [[0.0, 0.0, 100.0, 10.0], [0.0, 0.0, 100.0, 10.0], [0.0, 0.0, 100.0, 11.0], [0.0, 0.0, 100.0, 10.0]] ) @@ -861,8 +832,8 @@ def test_matching(): pred["labels"] = torch.tensor([1, 1, 1, 1]) gt["bbox"] = torch.tensor([[0.0, 0.0, 100.0, 11.0], [0.0, 0.0, 100.0, 5.0]]) metric.update(([pred], [gt])) - assert (metric._tp[1][4] == torch.tensor([[True, False, False, False]])).all() - assert (metric._fp[1][4] == torch.tensor([[False, False, False, True]])).all() + assert (metric._tps[4] == torch.tensor([[True, False, False, False]])).all() + assert (metric._fps[4] == torch.tensor([[False, False, False, True]])).all() def sklearn_precision_recall_curve_allowing_multiple_recalls_at_single_threshold(y_true, y_score): @@ -924,25 +895,46 @@ def test__compute_recall_and_precision(): def test_compute(sample): device = idist.device() - metric_50_95 = ObjectDetectionAvgPrecisionRecall(device=device) - metric_50 = ObjectDetectionAvgPrecisionRecall(iou_thresholds=[0.5], device=device) - metric_75 = ObjectDetectionAvgPrecisionRecall(iou_thresholds=[0.75], device=device) - metric_S = ObjectDetectionAvgPrecisionRecall(device=device, area_range="small") - metric_M = ObjectDetectionAvgPrecisionRecall(device=device, area_range="medium") - metric_L = ObjectDetectionAvgPrecisionRecall(device=device, area_range="large") - metrics = [metric_50_95, metric_50, metric_75, metric_S, metric_M, metric_L] + # AP@.5...95, AP@.5, AP@.75, AP-S, AP-M, AP-L, AR-1, AR-10, AR-100, AR-S, AR-M, AR-L + # ap_50_95_ar_100 = ObjectDetectionAvgPrecisionRecall(device=device) + # ap_50 = ObjectDetectionAvgPrecisionRecall(iou_thresholds=[0.5], device=device) + # ap_75 = 
ObjectDetectionAvgPrecisionRecall(iou_thresholds=[0.75], device=device) + # ap_ar_S = ObjectDetectionAvgPrecisionRecall(device=device, area_range="small") + # ap_ar_M = ObjectDetectionAvgPrecisionRecall(device=device, area_range="medium") + # ap_ar_L = ObjectDetectionAvgPrecisionRecall(device=device, area_range="large") + # ar_1 = ObjectDetectionAvgPrecisionRecall(device=device, max_detections_per_image=1) + ar_10 = ObjectDetectionAvgPrecisionRecall(device=device, max_detections_per_image=10) + + # metrics = [ap_50_95_ar_100, ap_50, ap_75, ap_ar_S, ap_ar_M, ap_ar_L, ar_1, ar_10] + metrics = [ar_10] for metric in metrics: metric.update(sample.data) - + + ignite_res = [metric.compute() for metric in metrics] - assert all([np.allclose(re, pycoco_res) for re, pycoco_res in zip(ignite_res, sample.mAP)]) - ignite_res_recompute = [metric.compute() for metric in metrics] assert all([r1 == r2 for r1, r2 in zip(ignite_res, ignite_res_recompute)]) + # AP_50_95, AR_100 = ignite_res[0] + # AP_50 = ignite_res[1][0] + # AP_75 = ignite_res[2][0] + # AP_S, AR_S = ignite_res[3] + # AP_M, AR_M = ignite_res[4] + # AP_L, AR_L = ignite_res[5] + # AR_1 = ignite_res[6][1]### + AR_10 = ignite_res[0][1] ### + # for r in [AP_50_95, AP_50, AP_75, AP_S, AP_M, AP_L, AR_1, AR_10, AR_100, AR_S, AR_M, AR_L]: + for r in [AR_10]: + print(r) + print("----------") + for r in sample.mAP: + print(r) + # assert np.allclose([AP_50_95, AP_50, AP_75, AP_S, AP_M, AP_L, AR_1, AR_10, AR_100, AR_S, AR_M, AR_L], sample.mAP) + assert np.allclose(AR_10, sample.mAP[-5]) + def test_integration(sample): bs = 3 From e425e12e16e1b658765ce0cafdca829e87775b4d Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Tue, 16 Jul 2024 15:43:30 +0330 Subject: [PATCH 20/41] Fix a bug; Some improvements; Improve docs --- ignite/metrics/mean_average_precision.py | 155 +++++++++--------- ...ject_detection_average_precision_recall.py | 150 +++++++++-------- .../metrics/test_mean_average_precision.py | 4 +- .../vision/test_object_detection_map.py | 135 ++++++--------- 4 files changed, 206 insertions(+), 238 deletions(-) diff --git a/ignite/metrics/mean_average_precision.py b/ignite/metrics/mean_average_precision.py index 35eb66e76ba..be167680703 100644 --- a/ignite/metrics/mean_average_precision.py +++ b/ignite/metrics/mean_average_precision.py @@ -6,9 +6,8 @@ import ignite.distributed as idist from ignite.distributed.utils import all_gather_tensors_with_shapes -from ignite.metrics.metric import reinit__is_reduced +from ignite.metrics.metric import Metric, reinit__is_reduced from ignite.metrics.precision import _BaseClassification -from ignite.metrics.metric import Metric from ignite.utils import to_onehot @@ -18,17 +17,10 @@ def __init__( rec_thresholds: Optional[Union[Sequence[float], torch.Tensor]] = None, class_mean: Optional[Literal["micro", "macro", "weighted"]] = "macro", ) -> None: - r"""Base class for Average Precision & Recall in classification and detection tasks. + r"""Base class for Average Precision metric. This class contains the methods for setting up the thresholds and computing AP & AR. - # Average precision is computed by averaging precision over increasing levels of recall thresholds as following: - - - and possibly some additional dimensions in the detection task. ``class_mean`` determines how to take this mean. - In the detection tasks, it's possible to take the mean in other respects as well e.g. IoU threshold in an - object detection task. 
- Args: rec_thresholds: recall thresholds (sensivity levels) to be considered for computing Mean Average Precision. It could be a 1-dim tensor or a sequence of floats. Its values should be between 0 and 1 and don't need @@ -40,7 +32,7 @@ def __init__( An 1-dimensional tensor of mean (taken across additional mean dimensions) average precision per class is returned. If there's no ground truth sample for a class, ``0`` is returned for that. - micro + 'micro' Precision is computed counting stats of classes/labels altogether. This option incorporates class in the very precision measurement. @@ -53,7 +45,7 @@ def __init__( For multiclass inputs, this is equivalent with mean average accuracy. - weighted + 'weighted' like macro but considers class/label imbalance. For multiclass input, it computes AP for each class then returns mean of them weighted by support of classes (number of actual samples in each class). For multilabel input, @@ -62,9 +54,6 @@ def __init__( 'macro' computes macro precision which is unweighted mean of AP computed across classes/labels. Default. - - Note: - Please note that classes with no ground truth are not considered into the mean in detection. """ if rec_thresholds is not None: self.rec_thresholds: Optional[torch.Tensor] = self._setup_thresholds(rec_thresholds, "rec_thresholds") @@ -75,7 +64,6 @@ def __init__( raise ValueError(f"Wrong `class_mean` parameter, given {class_mean}") self.class_mean = class_mean - def _setup_thresholds(self, thresholds: Union[Sequence[float], torch.Tensor], threshold_type: str) -> torch.Tensor: if isinstance(thresholds, Sequence): thresholds = torch.tensor(thresholds) @@ -96,7 +84,7 @@ def _setup_thresholds(self, thresholds: Union[Sequence[float], torch.Tensor], th return cast(torch.Tensor, thresholds) def _compute_average_precision(self, recall: torch.Tensor, precision: torch.Tensor) -> torch.Tensor: - """Measuring average precision which is the common operation among different settings of the metric. + """Measuring average precision. Args: recall: n-dimensional tensor whose last dimension represents confidence thresholds as much as #samples. @@ -104,7 +92,7 @@ def _compute_average_precision(self, recall: torch.Tensor, precision: torch.Tens precision: like ``recall`` in the shape. Returns: - average_precision: (n-1)-dimensional tensor containing the average precision for mean dimensions. + average_precision: (n-1)-dimensional tensor containing the average precisions. """ if self.rec_thresholds is not None: rec_thresholds = self.rec_thresholds.repeat((*recall.shape[:-1], 1)) @@ -119,16 +107,28 @@ def _compute_average_precision(self, recall: torch.Tensor, precision: torch.Tens return torch.sum(recall_differential * precision, dim=-1) -def _cat_and_agg_tensors(tensors: List[torch.Tensor], tensor_shape_except_last_dim: Tuple[int], dtype: torch.dtype, device: Union[str, torch.device]) -> torch.Tensor: +def _cat_and_agg_tensors( + tensors: List[torch.Tensor], + tensor_shape_except_last_dim: Tuple[int], + dtype: torch.dtype, + device: Union[str, torch.device], +) -> torch.Tensor: + """ + Concatenate tensors in ``tensors`` at their last dimension and gather all tensors from across all processes. + All tensors should have the same shape (denoted by ``tensor_shape_except_last_dim``) except at their + last dimension. 
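+
+    For illustration (hypothetical shapes, assuming two processes): if rank 0 ends up with a
+    concatenated tensor of shape (C, 5) and rank 1 with one of shape (C, 3), every rank receives
+    the gathered tensor of shape (C, 8). Internally the tensors are padded to a common shape,
+    all-gathered via ``all_gather_tensors_with_shapes`` and then stripped of the padding.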
+ """ num_preds = torch.tensor( [sum([tensor.shape[-1] for tensor in tensors]) if tensors else 0], device=device, ) num_preds = cast(torch.Tensor, idist.all_gather(num_preds)).tolist() - tensor = torch.cat(tensors, dim=-1) if tensors else torch.empty((*tensor_shape_except_last_dim, 0), dtype=dtype, device=device) - shape_across_ranks = [ - (*tensor_shape_except_last_dim, num_pred_in_rank) for num_pred_in_rank in num_preds - ] + tensor = ( + torch.cat(tensors, dim=-1) + if tensors + else torch.empty((*tensor_shape_except_last_dim, 0), dtype=dtype, device=device) + ) + shape_across_ranks = [(*tensor_shape_except_last_dim, num_pred_in_rank) for num_pred_in_rank in num_preds] return torch.cat( all_gather_tensors_with_shapes( tensor, @@ -139,9 +139,8 @@ def _cat_and_agg_tensors(tensors: List[torch.Tensor], tensor_shape_except_last_d class MeanAveragePrecision(_BaseClassification, _BaseAveragePrecision): - - _scores: List[torch.Tensor] - _P: List[torch.Tensor] + _y_pred: List[torch.Tensor] + _y_true: List[torch.Tensor] def __init__( self, @@ -162,10 +161,8 @@ def __init__( thresholds weighted by the change in recall, as if the area under precision-recall curve is being computed. Mean average precision is then computed by taking the mean of this average precision over different classes. - For detection tasks, user should use downstream metrics like - :class:`~ignite.metrics.vision.object_detection_map.ObjectDetectionMAP`. For classification, all the binary, - multiclass and multilabel data are supported. In the latter case, ``classification_is_multilabel`` should be - set to true. + All the binary, multiclass and multilabel data are supported. In the latter case, + ``is_multilabel`` should be set to true. `mean` in the mean average precision accounts for mean of the average precision across classes. ``class_mean`` determines how to take this mean. @@ -228,8 +225,8 @@ def reset(self) -> None: Reset method of the metric """ super().reset() - self._scores = [] - self._P = [] + self._y_pred = [] + self._y_true = [] def _check_binary_multilabel_cases(self, output: Sequence[torch.Tensor]) -> None: # Ignore the check in `_BaseClassification` since `y_pred` consists of probabilities here. @@ -246,7 +243,8 @@ def _check_type(self, output: Sequence[torch.Tensor]) -> None: warnings.warn("`y` should be of dtype long when entry type is multiclass", RuntimeWarning) def _prepare_output(self, output: Sequence[torch.Tensor]) -> Sequence[torch.Tensor]: - """Prepares and returns scores and P tensor. Input and output shapes of the method is as follows. + """Prepares and returns ``y_pred`` and ``y`` tensors. Input and output shapes of the method is as follows. + ``C`` and ``L`` denote the number of classes and labels in multiclass and multilabel inputs respectively. ========== =========== ============ ``y_pred`` @@ -254,7 +252,7 @@ def _prepare_output(self, output: Sequence[torch.Tensor]) -> Sequence[torch.Tens Data type Input shape Output shape ========== =========== ============ Binary (N, ...) (1, N * ...) - Multilabel (N, C, ...) (C, N * ...) + Multilabel (N, L, ...) (L, N * ...) Multiclass (N, C, ...) (C, N * ...) ========== =========== ============ @@ -264,7 +262,7 @@ def _prepare_output(self, output: Sequence[torch.Tensor]) -> Sequence[torch.Tens Data type Input shape Output shape ========== =========== ============ Binary (N, ...) (1, N * ...) - Multilabel (N, C, ...) (C, N * ...) + Multilabel (N, L, ...) (L, N * ...) Multiclass (N, ...) (N * ...) 
========== =========== ============ """ @@ -272,11 +270,11 @@ def _prepare_output(self, output: Sequence[torch.Tensor]) -> Sequence[torch.Tens if self._type == "multilabel": num_classes = y_pred.size(1) - scores = torch.transpose(y_pred, 1, 0).reshape(num_classes, -1) - P = torch.transpose(y, 1, 0).reshape(num_classes, -1) + yp = torch.transpose(y_pred, 1, 0).reshape(num_classes, -1) + yt = torch.transpose(y, 1, 0).reshape(num_classes, -1) elif self._type == "binary": - P = y.view(1, -1) - scores = y_pred.view(1, -1) + yp = y_pred.view(1, -1) + yt = y.view(1, -1) else: # Multiclass num_classes = y_pred.size(1) if y.max() + 1 > num_classes: @@ -284,10 +282,10 @@ def _prepare_output(self, output: Sequence[torch.Tensor]) -> Sequence[torch.Tens f"y_pred contains fewer classes than y. Number of classes in prediction is {num_classes}" f" and an element in y has invalid class = {y.max().item() + 1}." ) - P = y.view(-1) - scores = torch.transpose(y_pred, 1, 0).reshape(num_classes, -1) + yt = y.view(-1) + yp = torch.transpose(y_pred, 1, 0).reshape(num_classes, -1) - return scores, P + return yp, yt @reinit__is_reduced def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None: @@ -302,47 +300,47 @@ def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None: """ self._check_shape(output) self._check_type(output) - scores, P = self._prepare_output(output) - self._scores.append(scores.to(self._device)) - self._P.append(P.to(self._device, dtype=torch.uint8 if self._type != "multiclass" else torch.long)) + yp, yt = self._prepare_output(output) + self._y_pred.append(yp.to(self._device)) + self._y_true.append(yt.to(self._device, dtype=torch.uint8 if self._type != "multiclass" else torch.long)) def _compute_recall_and_precision( - self, TP: torch.Tensor, scores: torch.Tensor, P: torch.Tensor + self, y_true: torch.Tensor, y_pred: torch.Tensor, y_true_positive_count: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: r"""Measuring recall & precision. - Shape of function inputs and return values follow the table below. C is the number of classes, 1 for binary - data. N is the number of samples. Finally, \#unique scores represents number of unique scores in ``scores`` - which is actually the number of thresholds. + Shape of function inputs and return values follow the table below. + N is the number of samples. \#unique scores represents number of + unique scores in ``scores`` which is actually the number of thresholds. 
- =================== ======================================= - **Object** **Shape** - =================== ======================================= - TP (N,) - scores (N,) - P () (A single float) - recall (\#unique scores,) - precision (\#unique scores,) - =================== ======================================= + ===================== ======================================= + **Object** **Shape** + ===================== ======================================= + y_true (N,) + y_pred (N,) + y_true_positive_count () (A single float) + recall (\#unique scores,) + precision (\#unique scores,) + ===================== ======================================= Returns: `(recall, precision)` """ - indices = torch.argsort(scores, dim=-1, stable=True, descending=True) - tp_summation = TP[..., indices].cumsum(dim=-1).double() + indices = torch.argsort(y_pred, stable=True, descending=True) + tp_summation = y_true[indices].cumsum(dim=0).double() # Adopted from Scikit-learn's implementation unique_scores_indices = torch.nonzero( - scores.take_along_dim(indices).diff(append=(scores.max() + 1).unsqueeze(dim=0)), as_tuple=True + y_pred[indices].diff(append=(y_pred.max() + 1).unsqueeze(dim=0)), as_tuple=True )[0] tp_summation = tp_summation[..., unique_scores_indices] fp_summation = (unique_scores_indices + 1) - tp_summation - if P == 0: + if y_true_positive_count == 0: # To be aligned with Scikit-Learn recall = torch.ones_like(tp_summation, device=self._device, dtype=torch.float) else: - recall = tp_summation / P + recall = tp_summation / y_true_positive_count predicted_positive = tp_summation + fp_summation precision = tp_summation / torch.where(predicted_positive == 0, 1, predicted_positive) @@ -356,35 +354,30 @@ def compute(self) -> Union[torch.Tensor, float]: raise RuntimeError("Metric could not be computed without any update method call") num_classes = self._num_classes - P = _cat_and_agg_tensors( - self._P, - (num_classes,) if self._type == "multiabel" else (), + y_true = _cat_and_agg_tensors( + self._y_true, + () if self._type == "multiclass" else (num_classes,), torch.long if self._type == "multiclass" else torch.uint8, - self._device + self._device, ) - scores = _cat_and_agg_tensors( - self._scores, - (num_classes,) if self._type != "binary" else (), - torch.double, - self._device - ) + y_pred = _cat_and_agg_tensors(self._y_pred, (num_classes,), torch.double, self._device) if self._type == "multiclass": - P = to_onehot(P, num_classes=num_classes).T + y_true = to_onehot(y_true, num_classes=num_classes).T if self.class_mean == "micro": - P = P.reshape(1, -1) - scores = scores.view(1, -1) - P_count = P.sum(dim=-1) - average_precisions = torch.zeros_like(P_count, device=self._device, dtype=torch.double) - for cls in range(len(P_count)): - recall, precision = self._compute_recall_and_precision(P[cls], scores[cls], P_count[cls]) + y_true = y_true.reshape(1, -1) + y_pred = y_pred.view(1, -1) + y_true_positive_count = y_true.sum(dim=-1) + average_precisions = torch.zeros_like(y_true_positive_count, device=self._device, dtype=torch.double) + for cls in range(y_true_positive_count.size(0)): + recall, precision = self._compute_recall_and_precision(y_true[cls], y_pred[cls], y_true_positive_count[cls]) average_precisions[cls] = self._compute_average_precision(recall, precision) if self._type == "binary": return average_precisions.item() if self.class_mean is None: return average_precisions elif self.class_mean == "weighted": - return torch.sum(P_count * average_precisions) / P_count.sum() + return 
torch.sum(y_true_positive_count * average_precisions) / y_true_positive_count.sum() else: return average_precisions.mean() diff --git a/ignite/metrics/vision/object_detection_average_precision_recall.py b/ignite/metrics/vision/object_detection_average_precision_recall.py index e6abe776e8e..0c4a75223cd 100644 --- a/ignite/metrics/vision/object_detection_average_precision_recall.py +++ b/ignite/metrics/vision/object_detection_average_precision_recall.py @@ -1,12 +1,11 @@ -from typing import Callable, cast, Dict, List, Optional, Sequence, Tuple, Union, override +from typing import Callable, cast, Dict, List, Optional, override, Sequence, Tuple, Union import torch from typing_extensions import Literal -import ignite.distributed as idist from ignite.metrics.mean_average_precision import _BaseAveragePrecision, _cat_and_agg_tensors -from ignite.metrics.metric import Metric, reinit__is_reduced +from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce def tensor_list_to_dict_list( @@ -41,12 +40,11 @@ def tensor_list_to_dict_list( class ObjectDetectionAvgPrecisionRecall(Metric, _BaseAveragePrecision): - _tps: List[torch.Tensor] _fps: List[torch.Tensor] _scores: List[torch.Tensor] - _pred_labels: List[torch.Tensor] - _P: torch.Tensor + _y_pred_labels: List[torch.Tensor] + _y_true_count: torch.Tensor _num_classes: int def __init__( @@ -54,30 +52,37 @@ def __init__( iou_thresholds: Optional[Union[Sequence[float], torch.Tensor]] = None, rec_thresholds: Optional[Union[Sequence[float], torch.Tensor]] = None, num_classes: Optional[int] = 91, - max_detections_per_image: Optional[int] = 100, + max_detections_per_image_per_class: Optional[int] = 100, area_range: Optional[Literal["small", "medium", "large", "all"]] = "all", output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu"), ) -> None: - """Calculate the mean average precision & recall for evaluating an object detector. + r"""Calculate mean average precision & recall for evaluating an object detector in the COCO way. + - In average precision, the maximum precision across thresholds greater or equal a recall threshold is - considered as the summation operand; In other words, the precision peek across lower or equal + Average precision is computed by averaging precision over increasing levels of recall thresholds. + In COCO, the maximum precision across thresholds greater or equal a recall threshold is + considered as the average summation operand; In other words, the precision peek across lower or equal sensivity levels is used for a recall threshold: .. math:: \text{Average Precision} = \sum_{k=1}^{\#rec\_thresholds} (r_k - r_{k-1}) max(P_{k:}) + Average recall is the detector's maximum recall, considering all matched detections as TP, + averaged over classes. + Args: iou_thresholds: sequence of IoU thresholds to be considered for computing mean average precision & recall. Values should be between 0 and 1. If not given, COCO's default (.5, .55, ..., .95) would be used. rec_thresholds: sequence of recall thresholds to be considered for computing mean average precision. Values should be between 0 and 1. If not given, COCO's default (.0, .01, .02, ..., 1.) would be used. - max_detections_per_image: Max number of detections in each image to consider for evaluation. The most - confident ones are selected. + num_classes: number of categories. Default is 91, that of the COCO. + area_range: area range which only objects therein are considered in evaluation. By default, 'all'. 
+ max_detections_per_image_per_class: maximum number of detections per class in each image to consider + for evaluation. The most confident ones are selected. output_transform: a callable that is used to transform the :class:`~ignite.engine.engine.Engine`'s - ``process_function``'s output into the form expected by the metric. An already - provided example is :func:`~ignite.metrics.vision.object_detection_map.tensor_list_to_dict_list` + ``process_function``'s output into the form expected by the metric. An already provided example + is :func:`~ignite.metrics.vision.object_detection_average_precision_recall.tensor_list_to_dict_list` which accepts `y_pred` and `y` as lists of tensors and transforms them to the expected format. Default is the identity function. device: specifies which device updates are accumulated on. Setting the @@ -108,7 +113,7 @@ def box_iou(pred_boxes: torch.Tensor, gt_boxes: torch.Tensor, iscrowd: torch.Boo self._num_classes = num_classes self._area_range = area_range - self._max_detections_per_image = max_detections_per_image + self._max_detections_per_image_per_class = max_detections_per_image_per_class super(ObjectDetectionAvgPrecisionRecall, self).__init__( output_transform=output_transform, @@ -124,8 +129,8 @@ def reset(self) -> None: self._tps = [] self._fps = [] self._scores = [] - self._pred_labels = [] - self._P = torch.zeros((self._num_classes,), device=self._device) + self._y_pred_labels = [] + self._y_true_count = torch.zeros((self._num_classes,), device=self._device) def _match_area_range(self, bboxes: torch.Tensor) -> torch.Tensor: from torchvision.ops.boxes import box_area @@ -169,7 +174,7 @@ def _check_matching_input( ) def _compute_recall_and_precision( - self, TP: torch.Tensor, FP: torch.Tensor, scores: torch.Tensor, P: torch.Tensor + self, TP: torch.Tensor, FP: torch.Tensor, scores: torch.Tensor, y_true_count: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: r"""Measuring recall & precision @@ -187,7 +192,7 @@ def _compute_recall_and_precision( ============== ====================== TP and FP (..., N\ :sub:`pred`) scores (N\ :sub:`pred`,) - P () (A single float, + y_true_count () (A single float, greater than zero) recall (..., \#unique scores) precision (..., \#unique scores) @@ -202,7 +207,7 @@ def _compute_recall_and_precision( fp = FP[..., indices] fp_summation = fp.cumsum(dim=-1).double() - recall = tp_summation / P + recall = tp_summation / y_true_count predicted_positive = tp_summation + fp_summation precision = tp_summation / torch.where(predicted_positive == 0, 1, predicted_positive) @@ -224,7 +229,11 @@ def _compute_average_precision(self, recall: torch.Tensor, precision: torch.Tens """ precision_integrand = precision.flip(-1).cummax(dim=-1).values.flip(-1) rec_thresholds = cast(torch.Tensor, self.rec_thresholds).repeat((*recall.shape[:-1], 1)) - rec_thresh_indices = torch.searchsorted(recall, rec_thresholds) if recall.size(-1) != 0 else torch.LongTensor([], device=self._device) + rec_thresh_indices = ( + torch.searchsorted(recall, rec_thresholds) + if recall.size(-1) != 0 + else torch.LongTensor([], device=self._device) + ) precision_integrand = precision_integrand.take_along_dim( rec_thresh_indices.where(rec_thresh_indices != recall.size(-1), 0), dim=-1 ).where(rec_thresh_indices != recall.size(-1), 0) @@ -232,7 +241,7 @@ def _compute_average_precision(self, recall: torch.Tensor, precision: torch.Tens @reinit__is_reduced def update(self, output: Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, torch.Tensor]]]) -> None: - r"""Metric 
update function using prediction and target. + r"""Metric update method using prediction and target. Args: output: a tuple, (y_pred, y), of two same-length lists, each one containing @@ -248,7 +257,8 @@ def update(self, output: Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, tor 'bbox' (N\ :sub:`det`, 4) Bounding boxes of form (x1, y1, x2, y2) containing top left and bottom right coordinates. 'scores' (N\ :sub:`det`,) Confidence score of detections. - 'labels' (N\ :sub:`det`,) Predicted category number of detections. + 'labels' (N\ :sub:`det`,) Predicted category number of detections in + `torch.long` dtype. ======== ================== ================================================= ========= ================= ================================================= @@ -258,7 +268,8 @@ def update(self, output: Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, tor ========= ================= ================================================= 'bbox' (N\ :sub:`gt`, 4) Bounding boxes of form (x1, y1, x2, y2) containing top left and bottom right coordinates. - 'labels' (N\ :sub:`gt`,) Category number of ground truths. + 'labels' (N\ :sub:`gt`,) Category number of ground truths in `torch.long` + dtype. 'iscrowd' (N\ :sub:`gt`,) Whether ground truth boxes are crowd ones or not. This key is optional. ========= ================= ================================================= @@ -271,19 +282,29 @@ def update(self, output: Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, tor target["iscrowd"].bool() if "iscrowd" in target else torch.zeros_like(labels, dtype=torch.bool) ) gt_ignore = ~self._match_area_range(gt_boxes) | gt_is_crowd - self._P += torch.bincount(labels[~gt_ignore], minlength=self._num_classes).to(device=self._device) - - best_detections_index = torch.argsort(pred["scores"], stable=True, descending=True)[ - : self._max_detections_per_image - ] - print(best_detections_index) - pred_scores = pred["scores"][best_detections_index] - pred_labels = pred["labels"][best_detections_index] - pred_boxes = pred["bbox"][best_detections_index] + self._y_true_count += torch.bincount(labels[~gt_ignore], minlength=self._num_classes).to( + device=self._device + ) + # Matching logic of object detection mAP, according to COCO reference implementation. - if len(pred_labels): + if len(pred["labels"]): + best_detections_index = torch.argsort(pred["scores"], stable=True, descending=True) + max_best_detections_index = torch.cat( + [ + best_detections_index[pred["labels"][best_detections_index] == c][ + : self._max_detections_per_image_per_class + ] + for c in range(self._num_classes) + ] + ) + pred_boxes = pred["bbox"][max_best_detections_index] + pred_labels = pred["labels"][max_best_detections_index] if not len(labels): - tp = torch.zeros((len(self._iou_thresholds), len(pred_labels)), dtype=torch.uint8, device=self._device) + tp = torch.zeros( + (len(self._iou_thresholds), len(max_best_detections_index)), + dtype=torch.uint8, + device=self._device, + ) self._tps.append(tp) self._fps.append(~tp & self._match_area_range(pred_boxes)) else: @@ -296,9 +317,10 @@ def update(self, output: Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, tor IGNORANCE = -2 ious[:, gt_ignore] += IGNORANCE for i in range(len(pred_labels)): - # Flip is done to give priority to the last item with maximal value, as torch.max selects the first one. + # Flip is done to give priority to the last item with maximal value, + # as torch.max selects the first one. 
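+                    # For prediction i, ious[i] holds its IoU with every (gt, IoU threshold) pair, adjusted by
+                    # the sentinels: ignored gts were shifted down by IGNORANCE above, and non-crowd gts already
+                    # claimed by a higher-scoring prediction get set to NO_MATCH in the loop body below, so the
+                    # max over gts prefers valid non-ignored gts, then ignored ones, and finally no match.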
match_gts = ious[i].flip(0).max(0) - match_gts_indices = ious.size(1) -1 - match_gts.indices + match_gts_indices = ious.size(1) - 1 - match_gts.indices for t in range(len(self._iou_thresholds)): if match_gts.values[t] > NO_MATCH and not gt_is_crowd[match_gts_indices[t]]: ious[:, match_gts_indices[t], t] = NO_MATCH @@ -306,14 +328,18 @@ def update(self, output: Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, tor max_ious = ious.max(1).values self._tps.append((max_ious >= 0).T.to(dtype=torch.uint8, device=self._device)) - self._fps.append(((max_ious <= NO_MATCH).T & self._match_area_range(pred_boxes)).to(dtype=torch.uint8, device=self._device)) - - self._scores.append(pred_scores.to(self._device)) - self._pred_labels.append(pred_labels.to(device=self._device)) + self._fps.append( + ((max_ious <= NO_MATCH).T & self._match_area_range(pred_boxes)).to( + dtype=torch.uint8, device=self._device + ) + ) + + self._scores.append(pred["scores"][max_best_detections_index].to(self._device)) + self._y_pred_labels.append(pred_labels.to(device=self._device)) + @sync_all_reduce("_y_true_count") def _compute(self) -> torch.Tensor: - P = idist.all_reduce(self._P) - pred_labels = _cat_and_agg_tensors(self._pred_labels, (), torch.long, self._device) + pred_labels = _cat_and_agg_tensors(self._y_pred_labels, (), torch.long, self._device) TP = _cat_and_agg_tensors(self._tps, (len(self._iou_thresholds),), torch.uint8, self._device) FP = _cat_and_agg_tensors(self._fps, (len(self._iou_thresholds),), torch.uint8, self._device) scores = _cat_and_agg_tensors(self._scores, (), torch.double, self._device) @@ -324,46 +350,26 @@ def _compute(self) -> torch.Tensor: dtype=torch.double, ) for cls in range(self._num_classes): - if P[cls] == 0: + if self._y_true_count[cls] == 0: continue - + cls_labels = pred_labels == cls if sum(cls_labels) == 0: - average_precisions_recalls[:, cls] = 0. + average_precisions_recalls[:, cls] = 0.0 continue - - recall, precision = self._compute_recall_and_precision(TP[..., cls_labels], FP[..., cls_labels], scores[cls_labels], P[cls]) + + recall, precision = self._compute_recall_and_precision( + TP[..., cls_labels], FP[..., cls_labels], scores[cls_labels], self._y_true_count[cls] + ) average_precision_for_cls_per_iou_threshold = self._compute_average_precision(recall, precision) average_precisions_recalls[0, cls] = average_precision_for_cls_per_iou_threshold average_precisions_recalls[1, cls] = recall[..., -1] return average_precisions_recalls - + def compute(self) -> Tuple[float, float]: average_precisions_recalls = self._compute() if (average_precisions_recalls == -1).all(): - return -1., -1. + return -1.0, -1.0 ap = average_precisions_recalls[0][average_precisions_recalls[0] > -1].mean().item() ar = average_precisions_recalls[1][average_precisions_recalls[1] > -1].mean().item() return ap, ar - - -# class ObjDetCommonAPandAR(Metric): -# """ -# Computes following common variants of average precision (AP) and recall (AR). 
- -# ============== ====================== -# Metric variant Description -# ============== ====================== -# AP@.5...95 (..., N\ :sub:`pred`) -# AP@.5 (N\ :sub:`pred`,) -# AP@.75 () (A single float, -# AP-S greater than zero) -# AP-M -# AP-L (..., \#unique scores) -# AR-S -# AR-M -# AR-L (..., \#unique scores) -# ============== ====================== -# """ -# def __init__(self, output_transform: Callable = lambda x:x, device: Union[str, torch.device] = torch.device("cpu")): -# super().__init__(output_transform, device) \ No newline at end of file diff --git a/tests/ignite/metrics/test_mean_average_precision.py b/tests/ignite/metrics/test_mean_average_precision.py index d66ac5bcd89..f24f33abb9d 100644 --- a/tests/ignite/metrics/test_mean_average_precision.py +++ b/tests/ignite/metrics/test_mean_average_precision.py @@ -61,9 +61,9 @@ def test__prepare_output(): def test_update(): metric = MeanAveragePrecision() - assert len(metric._scores) == len(metric._P) == 0 + assert len(metric._y_pred) == len(metric._y_true) == 0 metric.update((torch.rand((5, 4)), torch.randint(0, 2, (5, 4)).bool())) - assert len(metric._scores) == len(metric._P) == 1 + assert len(metric._y_pred) == len(metric._y_true) == 1 def test__compute_recall_and_precision(): diff --git a/tests/ignite/metrics/vision/test_object_detection_map.py b/tests/ignite/metrics/vision/test_object_detection_map.py index c776ac1396c..2e5784731c5 100644 --- a/tests/ignite/metrics/vision/test_object_detection_map.py +++ b/tests/ignite/metrics/vision/test_object_detection_map.py @@ -1,8 +1,7 @@ import sys -import itertools from collections import namedtuple from math import ceil -from typing import Any, Dict, List, Tuple +from typing import Dict, List, Tuple from unittest.mock import patch import numpy as np @@ -567,16 +566,12 @@ def sample(request) -> Sample: "scores": torch.zeros( 0, ), - "labels": torch.zeros( - 0, dtype=torch.long - ), + "labels": torch.zeros(0, dtype=torch.long), } elif request.param[1] == "with_an_empty_gt": data[1][0] = { "bbox": torch.zeros(0, 4), - "labels": torch.zeros( - 0, dtype=torch.long - ), + "labels": torch.zeros(0, dtype=torch.long), "iscrowd": torch.zeros( 0, ), @@ -587,33 +582,25 @@ def sample(request) -> Sample: "scores": torch.zeros( 0, ), - "labels": torch.zeros( - 0, dtype=torch.long - ), + "labels": torch.zeros(0, dtype=torch.long), } data[0][1] = { "bbox": torch.zeros(0, 4), "scores": torch.zeros( 0, ), - "labels": torch.zeros( - 0, dtype=torch.long - ), + "labels": torch.zeros(0, dtype=torch.long), } data[1][0] = { "bbox": torch.zeros(0, 4), - "labels": torch.zeros( - 0, dtype=torch.long - ), + "labels": torch.zeros(0, dtype=torch.long), "iscrowd": torch.zeros( 0, ), } data[1][2] = { "bbox": torch.zeros(0, 4), - "labels": torch.zeros( - 0, dtype=torch.long - ), + "labels": torch.zeros(0, dtype=torch.long), "iscrowd": torch.zeros( 0, ), @@ -745,31 +732,6 @@ def test_iou_thresholding(): assert (metric._tps[1] == torch.tensor([[True], [False], [False], [False]])).all() -class Dummy_mAP(ObjectDetectionAvgPrecisionRecall): - def _do_matching(self, tup1: Tuple, tup2: Tuple): - tp, fp = tup1 - p, score = tup2 - return tp, fp, p, score - - def _check_matching_input(self, output: Any): - pass - - -def test_update(): - metric = Dummy_mAP() - assert len(metric._tp) == len(metric._fp) == len(metric._scores) == len(metric._P) == metric._num_classes == 0 - - metric.update( - ([({1: torch.tensor([True])}, {1: torch.tensor([False])})], [({1: 1, 2: 1}, {1: torch.tensor([0.8])})]) - ) - assert len(metric._tp[1]) 
== len(metric._fp[1]) == len(metric._scores[1]) == 1 - assert len(metric._P) == 2 and metric._P[2] == 1 - assert metric._num_classes == 3 - - metric.update(([({}, {})], [({2: 2}, {})])) - assert metric._P[2] == 3 - - def test_matching(): """ PyCOCO matching rules: @@ -835,6 +797,12 @@ def test_matching(): assert (metric._tps[4] == torch.tensor([[True, False, False, False]])).all() assert (metric._fps[4] == torch.tensor([[False, False, False, True]])).all() + pred["scores"] = torch.tensor([0.9, 1.0, 0.9, 0.9]) + metric._max_detections_per_image = 1 + metric.update(([pred], [gt])) + assert (metric._tps[5] == torch.tensor([[True]])).all() + assert (metric._fps[5] == torch.tensor([[False]])).all() + def sklearn_precision_recall_curve_allowing_multiple_recalls_at_single_threshold(y_true, y_score): y_true = y_true == 1 @@ -897,43 +865,33 @@ def test_compute(sample): device = idist.device() # AP@.5...95, AP@.5, AP@.75, AP-S, AP-M, AP-L, AR-1, AR-10, AR-100, AR-S, AR-M, AR-L - # ap_50_95_ar_100 = ObjectDetectionAvgPrecisionRecall(device=device) - # ap_50 = ObjectDetectionAvgPrecisionRecall(iou_thresholds=[0.5], device=device) - # ap_75 = ObjectDetectionAvgPrecisionRecall(iou_thresholds=[0.75], device=device) - # ap_ar_S = ObjectDetectionAvgPrecisionRecall(device=device, area_range="small") - # ap_ar_M = ObjectDetectionAvgPrecisionRecall(device=device, area_range="medium") - # ap_ar_L = ObjectDetectionAvgPrecisionRecall(device=device, area_range="large") - # ar_1 = ObjectDetectionAvgPrecisionRecall(device=device, max_detections_per_image=1) + ap_50_95_ar_100 = ObjectDetectionAvgPrecisionRecall(device=device) + ap_50 = ObjectDetectionAvgPrecisionRecall(iou_thresholds=[0.5], device=device) + ap_75 = ObjectDetectionAvgPrecisionRecall(iou_thresholds=[0.75], device=device) + ap_ar_S = ObjectDetectionAvgPrecisionRecall(device=device, area_range="small") + ap_ar_M = ObjectDetectionAvgPrecisionRecall(device=device, area_range="medium") + ap_ar_L = ObjectDetectionAvgPrecisionRecall(device=device, area_range="large") + ar_1 = ObjectDetectionAvgPrecisionRecall(device=device, max_detections_per_image=1) ar_10 = ObjectDetectionAvgPrecisionRecall(device=device, max_detections_per_image=10) - - # metrics = [ap_50_95_ar_100, ap_50, ap_75, ap_ar_S, ap_ar_M, ap_ar_L, ar_1, ar_10] - metrics = [ar_10] + metrics = [ap_50_95_ar_100, ap_50, ap_75, ap_ar_S, ap_ar_M, ap_ar_L, ar_1, ar_10] for metric in metrics: metric.update(sample.data) - - + ignite_res = [metric.compute() for metric in metrics] ignite_res_recompute = [metric.compute() for metric in metrics] assert all([r1 == r2 for r1, r2 in zip(ignite_res, ignite_res_recompute)]) - # AP_50_95, AR_100 = ignite_res[0] - # AP_50 = ignite_res[1][0] - # AP_75 = ignite_res[2][0] - # AP_S, AR_S = ignite_res[3] - # AP_M, AR_M = ignite_res[4] - # AP_L, AR_L = ignite_res[5] - # AR_1 = ignite_res[6][1]### - AR_10 = ignite_res[0][1] ### - # for r in [AP_50_95, AP_50, AP_75, AP_S, AP_M, AP_L, AR_1, AR_10, AR_100, AR_S, AR_M, AR_L]: - for r in [AR_10]: - print(r) - print("----------") - for r in sample.mAP: - print(r) - # assert np.allclose([AP_50_95, AP_50, AP_75, AP_S, AP_M, AP_L, AR_1, AR_10, AR_100, AR_S, AR_M, AR_L], sample.mAP) - assert np.allclose(AR_10, sample.mAP[-5]) + AP_50_95, AR_100 = ignite_res[0] + AP_50 = ignite_res[1][0] + AP_75 = ignite_res[2][0] + AP_S, AR_S = ignite_res[3] + AP_M, AR_M = ignite_res[4] + AP_L, AR_L = ignite_res[5] + AR_1 = ignite_res[6][1] + AR_10 = ignite_res[7][1] + assert np.allclose([AP_50_95, AP_50, AP_75, AP_S, AP_M, AP_L, AR_1, AR_10, 
AR_100, AR_S, AR_M, AR_L], sample.mAP) def test_integration(sample): @@ -953,7 +911,7 @@ def update(engine, i): n_iter = ceil(sample.length / bs) engine.run(range(n_iter), max_epochs=1) - res_50_95 = engine.state.metrics["mAP[50-95]"] + res_50_95 = engine.state.metrics["mAP[50-95]"][0] pycoco_res_50_95 = sample.mAP[0] assert np.allclose(res_50_95, pycoco_res_50_95) @@ -1008,14 +966,17 @@ def test_distrib_update_compute(distributed, sample): device = idist.device() metric_device = "cpu" if device.type == "xla" else device - metric_50_95 = ObjectDetectionAvgPrecisionRecall(device=metric_device) - metric_50 = ObjectDetectionAvgPrecisionRecall(iou_thresholds=[0.5], device=metric_device) - metric_75 = ObjectDetectionAvgPrecisionRecall(iou_thresholds=[0.75], device=metric_device) - metric_S = ObjectDetectionAvgPrecisionRecall(device=metric_device, area_range="small") - metric_M = ObjectDetectionAvgPrecisionRecall(device=metric_device, area_range="medium") - metric_L = ObjectDetectionAvgPrecisionRecall(device=metric_device, area_range="large") + # AP@.5...95, AP@.5, AP@.75, AP-S, AP-M, AP-L, AR-1, AR-10, AR-100, AR-S, AR-M, AR-L + ap_50_95_ar_100 = ObjectDetectionAvgPrecisionRecall(device=metric_device) + ap_50 = ObjectDetectionAvgPrecisionRecall(iou_thresholds=[0.5], device=metric_device) + ap_75 = ObjectDetectionAvgPrecisionRecall(iou_thresholds=[0.75], device=metric_device) + ap_ar_S = ObjectDetectionAvgPrecisionRecall(device=metric_device, area_range="small") + ap_ar_M = ObjectDetectionAvgPrecisionRecall(device=metric_device, area_range="medium") + ap_ar_L = ObjectDetectionAvgPrecisionRecall(device=metric_device, area_range="large") + ar_1 = ObjectDetectionAvgPrecisionRecall(device=metric_device, max_detections_per_image=1) + ar_10 = ObjectDetectionAvgPrecisionRecall(device=metric_device, max_detections_per_image=10) - metrics = [metric_50_95, metric_50, metric_75, metric_S, metric_M, metric_L] + metrics = [ap_50_95_ar_100, ap_50, ap_75, ap_ar_S, ap_ar_M, ap_ar_L, ar_1, ar_10] y_pred_rank = sample.data[0][rank_samples_range] y_rank = sample.data[1][rank_samples_range] @@ -1023,7 +984,15 @@ def test_distrib_update_compute(distributed, sample): metric.update((y_pred_rank, y_rank)) ignite_res = [metric.compute() for metric in metrics] - assert all([np.allclose(re, pycoco_res) for re, pycoco_res in zip(ignite_res, sample.mAP)]) - ignite_res_recompute = [metric.compute() for metric in metrics] assert all([r1 == r2 for r1, r2 in zip(ignite_res, ignite_res_recompute)]) + + AP_50_95, AR_100 = ignite_res[0] + AP_50 = ignite_res[1][0] + AP_75 = ignite_res[2][0] + AP_S, AR_S = ignite_res[3] + AP_M, AR_M = ignite_res[4] + AP_L, AR_L = ignite_res[5] + AR_1 = ignite_res[6][1] + AR_10 = ignite_res[7][1] + assert np.allclose([AP_50_95, AP_50, AP_75, AP_S, AP_M, AP_L, AR_1, AR_10, AR_100, AR_S, AR_M, AR_L], sample.mAP) From bb15f0fe0828345ae9fee078a9f7bde63c222104 Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Tue, 16 Jul 2024 16:08:46 +0330 Subject: [PATCH 21/41] Fix metrics.rst --- docs/source/metrics.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index c01e9818797..2e0a79eefda 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -338,7 +338,7 @@ Complete list of metrics metric.Metric metrics_lambda.MetricsLambda MultiLabelConfusionMatrix - ObjectDetectionMAP + ObjectDetectionAvgPrecisionRecall precision.Precision PSNR recall.Recall From 6fcc97f774b6f90bcf4a390d184eab837793134a Mon Sep 17 00:00:00 2001 From: Sadra 
Barikbin Date: Tue, 16 Jul 2024 16:19:53 +0330 Subject: [PATCH 22/41] Remove @override which is for 3.12 --- .../vision/object_detection_average_precision_recall.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ignite/metrics/vision/object_detection_average_precision_recall.py b/ignite/metrics/vision/object_detection_average_precision_recall.py index 0c4a75223cd..700abae36f0 100644 --- a/ignite/metrics/vision/object_detection_average_precision_recall.py +++ b/ignite/metrics/vision/object_detection_average_precision_recall.py @@ -1,4 +1,4 @@ -from typing import Callable, cast, Dict, List, Optional, override, Sequence, Tuple, Union +from typing import Callable, cast, Dict, List, Optional, Sequence, Tuple, Union import torch from typing_extensions import Literal @@ -213,7 +213,6 @@ def _compute_recall_and_precision( return recall, precision - @override def _compute_average_precision(self, recall: torch.Tensor, precision: torch.Tensor) -> torch.Tensor: """Measuring average precision. This method is overriden since :math:`1/#recall_thresholds` is used instead of :math:`r_k - r_{k-1}` From 120c755c80afbf8c0717a850663d7b0de11f07e1 Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Tue, 16 Jul 2024 17:09:06 +0330 Subject: [PATCH 23/41] Fix mypy issues --- ignite/metrics/mean_average_precision.py | 10 +++++----- .../object_detection_average_precision_recall.py | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/ignite/metrics/mean_average_precision.py b/ignite/metrics/mean_average_precision.py index be167680703..7a813f15470 100644 --- a/ignite/metrics/mean_average_precision.py +++ b/ignite/metrics/mean_average_precision.py @@ -81,7 +81,7 @@ def _setup_thresholds(self, thresholds: Union[Sequence[float], torch.Tensor], th if min(thresholds) < 0 or max(thresholds) > 1: raise ValueError(f"{threshold_type} values should be between 0 and 1, given {thresholds}") - return cast(torch.Tensor, thresholds) + return thresholds def _compute_average_precision(self, recall: torch.Tensor, precision: torch.Tensor) -> torch.Tensor: """Measuring average precision. 
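For reference, a minimal standalone sketch of the summation that the base class's
_compute_average_precision performs (simplified to the 1-D case and ignoring the optional
``rec_thresholds`` interpolation; ``average_precision_1d`` is an illustrative name, not part
of the library):

    import torch

    def average_precision_1d(recall: torch.Tensor, precision: torch.Tensor) -> torch.Tensor:
        # AP = sum_k (r_k - r_{k-1}) * p_k, with a zero prepended so the first step counts fully
        recall_diff = recall.diff(prepend=torch.zeros(1, dtype=recall.dtype))
        return torch.sum(recall_diff * precision)

    # e.g. recall [0.5, 0.5, 1.0] and precision [1.0, 0.5, 0.75] give 0.5 * 1.0 + 0.0 * 0.5 + 0.5 * 0.75 = 0.875
    print(average_precision_1d(torch.tensor([0.5, 0.5, 1.0]), torch.tensor([1.0, 0.5, 0.75])))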
@@ -102,7 +102,7 @@ def _compute_average_precision(self, recall: torch.Tensor, precision: torch.Tens ).where(rec_thresh_indices != recall.size(-1), 0) recall = rec_thresholds recall_differential = recall.diff( - dim=-1, prepend=torch.zeros((*recall.shape[:-1], 1), device=self._device, dtype=torch.double) + dim=-1, prepend=torch.zeros((*recall.shape[:-1], 1), device=recall.device, dtype=torch.double) ) return torch.sum(recall_differential * precision, dim=-1) @@ -122,13 +122,13 @@ def _cat_and_agg_tensors( [sum([tensor.shape[-1] for tensor in tensors]) if tensors else 0], device=device, ) - num_preds = cast(torch.Tensor, idist.all_gather(num_preds)).tolist() + all_num_preds = cast(torch.Tensor, idist.all_gather(num_preds)).tolist() tensor = ( torch.cat(tensors, dim=-1) if tensors else torch.empty((*tensor_shape_except_last_dim, 0), dtype=dtype, device=device) ) - shape_across_ranks = [(*tensor_shape_except_last_dim, num_pred_in_rank) for num_pred_in_rank in num_preds] + shape_across_ranks = [(*tensor_shape_except_last_dim, num_pred_in_rank) for num_pred_in_rank in all_num_preds] return torch.cat( all_gather_tensors_with_shapes( tensor, @@ -356,7 +356,7 @@ def compute(self) -> Union[torch.Tensor, float]: y_true = _cat_and_agg_tensors( self._y_true, - () if self._type == "multiclass" else (num_classes,), + cast(Tuple[int], ()) if self._type == "multiclass" else (num_classes,), torch.long if self._type == "multiclass" else torch.uint8, self._device, ) diff --git a/ignite/metrics/vision/object_detection_average_precision_recall.py b/ignite/metrics/vision/object_detection_average_precision_recall.py index 700abae36f0..3856825d331 100644 --- a/ignite/metrics/vision/object_detection_average_precision_recall.py +++ b/ignite/metrics/vision/object_detection_average_precision_recall.py @@ -51,9 +51,9 @@ def __init__( self, iou_thresholds: Optional[Union[Sequence[float], torch.Tensor]] = None, rec_thresholds: Optional[Union[Sequence[float], torch.Tensor]] = None, - num_classes: Optional[int] = 91, - max_detections_per_image_per_class: Optional[int] = 100, - area_range: Optional[Literal["small", "medium", "large", "all"]] = "all", + num_classes: int = 91, + max_detections_per_image_per_class: int = 100, + area_range: Literal["small", "medium", "large", "all"] = "all", output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu"), ) -> None: @@ -338,10 +338,10 @@ def update(self, output: Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, tor @sync_all_reduce("_y_true_count") def _compute(self) -> torch.Tensor: - pred_labels = _cat_and_agg_tensors(self._y_pred_labels, (), torch.long, self._device) + pred_labels = _cat_and_agg_tensors(self._y_pred_labels, cast(Tuple[int], ()), torch.long, self._device) TP = _cat_and_agg_tensors(self._tps, (len(self._iou_thresholds),), torch.uint8, self._device) FP = _cat_and_agg_tensors(self._fps, (len(self._iou_thresholds),), torch.uint8, self._device) - scores = _cat_and_agg_tensors(self._scores, (), torch.double, self._device) + scores = _cat_and_agg_tensors(self._scores, cast(Tuple[int], ()), torch.double, self._device) average_precisions_recalls = -torch.ones( (2, self._num_classes, len(self._iou_thresholds)), From 7c26d0844d5144e5750107e8cb09c35115aa23b6 Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Tue, 16 Jul 2024 17:22:33 +0330 Subject: [PATCH 24/41] Fix two tests --- tests/ignite/metrics/vision/test_object_detection_map.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/tests/ignite/metrics/vision/test_object_detection_map.py b/tests/ignite/metrics/vision/test_object_detection_map.py index 2e5784731c5..4992e91ac17 100644 --- a/tests/ignite/metrics/vision/test_object_detection_map.py +++ b/tests/ignite/metrics/vision/test_object_detection_map.py @@ -667,7 +667,7 @@ def test_empty_data(): ) assert len(metric._tps) == 0 assert len(metric._fps) == 0 - assert metric._P[1] == 1 + assert metric._y_true_count[1] == 1 assert metric.compute() == (0, 0) metric = ObjectDetectionAvgPrecisionRecall() @@ -798,7 +798,7 @@ def test_matching(): assert (metric._fps[4] == torch.tensor([[False, False, False, True]])).all() pred["scores"] = torch.tensor([0.9, 1.0, 0.9, 0.9]) - metric._max_detections_per_image = 1 + metric._max_detections_per_image_per_class = 1 metric.update(([pred], [gt])) assert (metric._tps[5] == torch.tensor([[True]])).all() assert (metric._fps[5] == torch.tensor([[False]])).all() From c3c4a82ba837d31a9e4d2b9a95acb0158f104d3e Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Tue, 16 Jul 2024 17:44:51 +0330 Subject: [PATCH 25/41] Fix a typo in tests --- tests/ignite/metrics/vision/test_object_detection_map.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/ignite/metrics/vision/test_object_detection_map.py b/tests/ignite/metrics/vision/test_object_detection_map.py index 4992e91ac17..c6ca4d11250 100644 --- a/tests/ignite/metrics/vision/test_object_detection_map.py +++ b/tests/ignite/metrics/vision/test_object_detection_map.py @@ -871,8 +871,8 @@ def test_compute(sample): ap_ar_S = ObjectDetectionAvgPrecisionRecall(device=device, area_range="small") ap_ar_M = ObjectDetectionAvgPrecisionRecall(device=device, area_range="medium") ap_ar_L = ObjectDetectionAvgPrecisionRecall(device=device, area_range="large") - ar_1 = ObjectDetectionAvgPrecisionRecall(device=device, max_detections_per_image=1) - ar_10 = ObjectDetectionAvgPrecisionRecall(device=device, max_detections_per_image=10) + ar_1 = ObjectDetectionAvgPrecisionRecall(device=device, max_detections_per_image_per_class=1) + ar_10 = ObjectDetectionAvgPrecisionRecall(device=device, max_detections_per_image_per_class=10) metrics = [ap_50_95_ar_100, ap_50, ap_75, ap_ar_S, ap_ar_M, ap_ar_L, ar_1, ar_10] for metric in metrics: @@ -973,8 +973,8 @@ def test_distrib_update_compute(distributed, sample): ap_ar_S = ObjectDetectionAvgPrecisionRecall(device=metric_device, area_range="small") ap_ar_M = ObjectDetectionAvgPrecisionRecall(device=metric_device, area_range="medium") ap_ar_L = ObjectDetectionAvgPrecisionRecall(device=metric_device, area_range="large") - ar_1 = ObjectDetectionAvgPrecisionRecall(device=metric_device, max_detections_per_image=1) - ar_10 = ObjectDetectionAvgPrecisionRecall(device=metric_device, max_detections_per_image=10) + ar_1 = ObjectDetectionAvgPrecisionRecall(device=metric_device, max_detections_per_image_per_class=1) + ar_10 = ObjectDetectionAvgPrecisionRecall(device=metric_device, max_detections_per_image_per_class=10) metrics = [ap_50_95_ar_100, ap_50, ap_75, ap_ar_S, ap_ar_M, ap_ar_L, ar_1, ar_10] From 240593763c2acb4191ac2025ebf5ddbc70f4b4be Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Tue, 16 Jul 2024 22:30:11 +0330 Subject: [PATCH 26/41] Fix dist tests --- ignite/distributed/utils.py | 1 + tests/ignite/distributed/utils/__init__.py | 9 +++++---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/ignite/distributed/utils.py b/ignite/distributed/utils.py index e509a08c9de..6c1357203c6 100644 --- a/ignite/distributed/utils.py +++ 
b/ignite/distributed/utils.py @@ -356,6 +356,7 @@ def all_reduce( def all_gather_tensors_with_shapes( tensor: torch.Tensor, shapes: Sequence[Sequence[int]], group: Optional[Union[Any, List[int]]] = None ) -> List[torch.Tensor]: + """Gather tensors with different shapes but with the same number of dimensions from across processes.""" if _need_to_sync and isinstance(_model, _SerialModel): sync(temporary=True) diff --git a/tests/ignite/distributed/utils/__init__.py b/tests/ignite/distributed/utils/__init__.py index 95e17c7ab57..3d60b76a5bc 100644 --- a/tests/ignite/distributed/utils/__init__.py +++ b/tests/ignite/distributed/utils/__init__.py @@ -195,7 +195,7 @@ def _test_distrib_all_gather(device): true_res[i * 4 : (i + 1) * 4, ...] = torch.arange(100, device=device).reshape(4, 25) * (i + 1) assert (res == true_res).all() - ts = [torch.randn(tuple(torch.randint(1, 10, (3,))), device=device) for _ in range(ws)] + ts = [torch.randn((r + 1, r + 2, r + 3), device=device) for r in range(ws)] ts_gathered = all_gather_tensors_with_shapes(ts[rank], [list(t.shape) for t in ts]) for t, t_gathered in zip(ts, ts_gathered): assert (t == t_gathered).all() @@ -226,7 +226,7 @@ def _test_distrib_all_gather(device): def _test_distrib_all_gather_group(device): if idist.get_world_size() > 1: - ranks = list(range(idist.get_world_size() - 1, 0, -1)) # [0, 1, 2, 3] -> [3, 2, 1] + ranks = list(range(1, idist.get_world_size())) rank = idist.get_rank() bnd = idist.backend() @@ -253,8 +253,9 @@ def _test_distrib_all_gather_group(device): else: assert res == t - ts = [torch.randn(tuple(torch.randint(1, 10, (3,))), device=device) for _ in range(idist.get_world_size())] - ts_gathered = all_gather_tensors_with_shapes(ts[rank], [list(t.shape) for t in ts], ranks) + ts = [torch.randn((i + 1, i + 2, i + 3), device=device) for i in range(idist.get_world_size())] + shapes = [list(t.shape) for r, t in enumerate(ts) if r in ranks] + ts_gathered = all_gather_tensors_with_shapes(ts[rank], shapes, ranks) if rank in ranks: for i, r in enumerate(ranks): assert (ts[r] == ts_gathered[i]).all() From 356f618fbd60bf0f12fe0b85410ea22e3aa33f94 Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Wed, 4 Sep 2024 01:28:08 +0330 Subject: [PATCH 27/41] Add common obj. det. 
metrics --- docs/source/metrics.rst | 1 + ignite/metrics/__init__.py | 6 +- ignite/metrics/mean_average_precision.py | 2 +- ignite/metrics/metric_group.py | 2 +- ...ject_detection_average_precision_recall.py | 91 ++++++++++++++++++- .../vision/test_object_detection_map.py | 50 +++++++++- 6 files changed, 145 insertions(+), 7 deletions(-) diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 8fb3f641c5e..af60dd41be2 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -341,6 +341,7 @@ Complete list of metrics MultiLabelConfusionMatrix MutualInformation ObjectDetectionAvgPrecisionRecall + CommonObjDetectionMetrics precision.Precision PSNR recall.Recall diff --git a/ignite/metrics/__init__.py b/ignite/metrics/__init__.py index 943769e09aa..beac0b44e41 100644 --- a/ignite/metrics/__init__.py +++ b/ignite/metrics/__init__.py @@ -38,7 +38,10 @@ from ignite.metrics.running_average import RunningAverage from ignite.metrics.ssim import SSIM from ignite.metrics.top_k_categorical_accuracy import TopKCategoricalAccuracy -from ignite.metrics.vision.object_detection_average_precision_recall import ObjectDetectionAvgPrecisionRecall +from ignite.metrics.vision.object_detection_average_precision_recall import ( + CommonObjDetectionMetrics, + ObjectDetectionAvgPrecisionRecall, +) __all__ = [ "Metric", @@ -90,4 +93,5 @@ "ROC_AUC", "MeanAveragePrecision", "ObjectDetectionAvgPrecisionRecall", + "CommonObjDetectionMetrics", ] diff --git a/ignite/metrics/mean_average_precision.py b/ignite/metrics/mean_average_precision.py index 7a813f15470..71896512a73 100644 --- a/ignite/metrics/mean_average_precision.py +++ b/ignite/metrics/mean_average_precision.py @@ -66,7 +66,7 @@ def __init__( def _setup_thresholds(self, thresholds: Union[Sequence[float], torch.Tensor], threshold_type: str) -> torch.Tensor: if isinstance(thresholds, Sequence): - thresholds = torch.tensor(thresholds) + thresholds = torch.tensor(thresholds, dtype=torch.double) if isinstance(thresholds, torch.Tensor): if thresholds.ndim != 1: diff --git a/ignite/metrics/metric_group.py b/ignite/metrics/metric_group.py index 58a52f658ae..72398d464d4 100644 --- a/ignite/metrics/metric_group.py +++ b/ignite/metrics/metric_group.py @@ -36,7 +36,7 @@ class MetricGroup(Metric): state.metrics["eval_metrics"] """ - _state_dict_all_req_keys = ("metrics",) + _state_dict_all_req_keys: tuple[str, ...] 
= ("metrics",) def __init__(self, metrics: Dict[str, Metric], output_transform: Callable = lambda x: x): self.metrics = metrics diff --git a/ignite/metrics/vision/object_detection_average_precision_recall.py b/ignite/metrics/vision/object_detection_average_precision_recall.py index 3856825d331..0ae07e924a5 100644 --- a/ignite/metrics/vision/object_detection_average_precision_recall.py +++ b/ignite/metrics/vision/object_detection_average_precision_recall.py @@ -3,8 +3,9 @@ import torch from typing_extensions import Literal -from ignite.metrics.mean_average_precision import _BaseAveragePrecision, _cat_and_agg_tensors +from ignite.metrics import MetricGroup +from ignite.metrics.mean_average_precision import _BaseAveragePrecision, _cat_and_agg_tensors from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce @@ -372,3 +373,91 @@ def compute(self) -> Tuple[float, float]: ap = average_precisions_recalls[0][average_precisions_recalls[0] > -1].mean().item() ar = average_precisions_recalls[1][average_precisions_recalls[1] > -1].mean().item() return ap, ar + + +class CommonObjDetectionMetrics(MetricGroup): + """ + Common Object detection metrics. Included metrics are as follows: + + =============== ========================================== + **Metric name** **Description** + =============== ========================================== + AP@50..95 Average precision averaged over + .50 to.95 IOU thresholds + AR-100 Average recall with maximum 100 detections + AP@50 Average precision with IOU threshold=.50 + AP@75 Average precision with IOU threshold=.75 + AP-S Average precision over small objects + (< 32px * 32px) + AR-S Average recall over small objects + AP-M Average precision over medium objects + (S < . < 96px * 96px) + AR-M Average recall over medium objects + AP-L Average precision over large objects + (M < . 
< 1e5px * 1e5px) + AR-L Average recall over large objects + greater than zero) + AR-1 Average recall with maximum 1 detection + AR-10 Average recall with maximum 10 detections + =============== ========================================== + + """ + + _state_dict_all_req_keys = ("metrics", "ap_50_95") + + ap_50_95: ObjectDetectionAvgPrecisionRecall + + def __init__( + self, + output_transform: Callable = lambda x: x, + device: Union[str, torch.device] = torch.device("cpu"), + ): + self.ap_50_95 = ObjectDetectionAvgPrecisionRecall(device=device) + + super().__init__( + { + "S": ObjectDetectionAvgPrecisionRecall(device=device, area_range="small"), + "M": ObjectDetectionAvgPrecisionRecall(device=device, area_range="medium"), + "L": ObjectDetectionAvgPrecisionRecall(device=device, area_range="large"), + "1": ObjectDetectionAvgPrecisionRecall(device=device, max_detections_per_image_per_class=1), + "10": ObjectDetectionAvgPrecisionRecall(device=device, max_detections_per_image_per_class=10), + }, + output_transform, + ) + + def reset(self) -> None: + super().reset() + self.ap_50_95.reset() + + def update(self, output: Sequence[torch.Tensor]) -> None: + super().update(output) + self.ap_50_95.update(output) + + def compute(self) -> Dict[str, float]: + average_precisions_recalls = self.ap_50_95._compute() + + average_precisions_50 = average_precisions_recalls[0, :, 0] + average_precisions_75 = average_precisions_recalls[0, :, 5] + if (average_precisions_50 == -1).all(): + AP_50 = AP_75 = AP_50_95 = AR_100 = -1.0 + else: + AP_50 = average_precisions_50[average_precisions_50 > -1].mean().item() + AP_75 = average_precisions_75[average_precisions_75 > -1].mean().item() + AP_50_95 = average_precisions_recalls[0][average_precisions_recalls[0] > -1].mean().item() + AR_100 = average_precisions_recalls[1][average_precisions_recalls[1] > -1].mean().item() + + result = super().compute() + return { + "AP@50..95": AP_50_95, + "AR-100": AR_100, + "AP@50": AP_50, + "AP@75": AP_75, + "AP-S": result["S"][0], + "AR-S": result["S"][1], + "AP-M": result["M"][0], + "AR-M": result["M"][1], + "AP-L": result["L"][0], + "AR-L": result["L"][1], + "AR-1": result["1"][1], + "AR-10": result["10"][1], + } diff --git a/tests/ignite/metrics/vision/test_object_detection_map.py b/tests/ignite/metrics/vision/test_object_detection_map.py index c6ca4d11250..f39e3fdf743 100644 --- a/tests/ignite/metrics/vision/test_object_detection_map.py +++ b/tests/ignite/metrics/vision/test_object_detection_map.py @@ -17,7 +17,7 @@ import ignite.distributed as idist from ignite.engine import Engine -from ignite.metrics import ObjectDetectionAvgPrecisionRecall +from ignite.metrics import CommonObjDetectionMetrics, ObjectDetectionAvgPrecisionRecall from ignite.metrics.vision.object_detection_average_precision_recall import tensor_list_to_dict_list from ignite.utils import manual_seed @@ -891,7 +891,30 @@ def test_compute(sample): AP_L, AR_L = ignite_res[5] AR_1 = ignite_res[6][1] AR_10 = ignite_res[7][1] - assert np.allclose([AP_50_95, AP_50, AP_75, AP_S, AP_M, AP_L, AR_1, AR_10, AR_100, AR_S, AR_M, AR_L], sample.mAP) + all_res = [AP_50_95, AP_50, AP_75, AP_S, AP_M, AP_L, AR_1, AR_10, AR_100, AR_S, AR_M, AR_L] + print(all_res) + assert np.allclose(all_res, sample.mAP) + + common_metrics = CommonObjDetectionMetrics(device=device) + common_metrics.update(sample.data) + res = common_metrics.compute() + common_metrics_res = [ + res["AP@50..95"], + res["AP@50"], + res["AP@75"], + res["AP-S"], + res["AP-M"], + res["AP-L"], + res["AR-1"], + res["AR-10"], + 
res["AR-100"], + res["AR-S"], + res["AR-M"], + res["AR-L"], + ] + print(common_metrics_res) + assert all_res == common_metrics_res + assert np.allclose(common_metrics_res, sample.mAP) def test_integration(sample): @@ -995,4 +1018,25 @@ def test_distrib_update_compute(distributed, sample): AP_L, AR_L = ignite_res[5] AR_1 = ignite_res[6][1] AR_10 = ignite_res[7][1] - assert np.allclose([AP_50_95, AP_50, AP_75, AP_S, AP_M, AP_L, AR_1, AR_10, AR_100, AR_S, AR_M, AR_L], sample.mAP) + all_res = [AP_50_95, AP_50, AP_75, AP_S, AP_M, AP_L, AR_1, AR_10, AR_100, AR_S, AR_M, AR_L] + assert np.allclose(all_res, sample.mAP) + + common_metrics = CommonObjDetectionMetrics(device=device) + common_metrics.update((y_pred_rank, y_rank)) + res = common_metrics.compute() + common_metrics_res = [ + res["AP@50..95"], + res["AP@50"], + res["AP@75"], + res["AP-S"], + res["AP-M"], + res["AP-L"], + res["AR-1"], + res["AR-10"], + res["AR-100"], + res["AR-S"], + res["AR-M"], + res["AR-L"], + ] + assert all_res == common_metrics_res + assert np.allclose(common_metrics_res, sample.mAP) From cb6a328747ba3c203d914b156f36de0761416577 Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Wed, 4 Sep 2024 01:32:50 +0330 Subject: [PATCH 28/41] Change an annotation for the sake of M1 python3.8 --- ignite/metrics/metric_group.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ignite/metrics/metric_group.py b/ignite/metrics/metric_group.py index 72398d464d4..f6c21a604af 100644 --- a/ignite/metrics/metric_group.py +++ b/ignite/metrics/metric_group.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Sequence +from typing import Any, Callable, Dict, Sequence, Tuple import torch @@ -36,7 +36,7 @@ class MetricGroup(Metric): state.metrics["eval_metrics"] """ - _state_dict_all_req_keys: tuple[str, ...] = ("metrics",) + _state_dict_all_req_keys: Tuple[str, ...] 
= ("metrics",) def __init__(self, metrics: Dict[str, Metric], output_transform: Callable = lambda x: x): self.metrics = metrics From 248fe890d35d4c4c5efc398e2ad9f76d82d28258 Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Wed, 4 Sep 2024 02:54:24 +0330 Subject: [PATCH 29/41] Use if check on torch.double usages for MPS backend --- ignite/metrics/mean_average_precision.py | 12 ++++++---- ...ject_detection_average_precision_recall.py | 24 ++++++++++++++----- 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/ignite/metrics/mean_average_precision.py b/ignite/metrics/mean_average_precision.py index 71896512a73..03f84cc8c6c 100644 --- a/ignite/metrics/mean_average_precision.py +++ b/ignite/metrics/mean_average_precision.py @@ -102,7 +102,7 @@ def _compute_average_precision(self, recall: torch.Tensor, precision: torch.Tens ).where(rec_thresh_indices != recall.size(-1), 0) recall = rec_thresholds recall_differential = recall.diff( - dim=-1, prepend=torch.zeros((*recall.shape[:-1], 1), device=recall.device, dtype=torch.double) + dim=-1, prepend=torch.zeros((*recall.shape[:-1], 1), device=recall.device, dtype=recall.dtype) ) return torch.sum(recall_differential * precision, dim=-1) @@ -327,7 +327,9 @@ def _compute_recall_and_precision( `(recall, precision)` """ indices = torch.argsort(y_pred, stable=True, descending=True) - tp_summation = y_true[indices].cumsum(dim=0).double() + tp_summation = y_true[indices].cumsum(dim=0) + if tp_summation.device != torch.device("mps"): + tp_summation = tp_summation.double() # Adopted from Scikit-learn's implementation unique_scores_indices = torch.nonzero( @@ -360,8 +362,8 @@ def compute(self) -> Union[torch.Tensor, float]: torch.long if self._type == "multiclass" else torch.uint8, self._device, ) - - y_pred = _cat_and_agg_tensors(self._y_pred, (num_classes,), torch.double, self._device) + fp_precision = torch.double if self._device != torch.device("mps") else torch.float32 + y_pred = _cat_and_agg_tensors(self._y_pred, (num_classes,), fp_precision, self._device) if self._type == "multiclass": y_true = to_onehot(y_true, num_classes=num_classes).T @@ -369,7 +371,7 @@ def compute(self) -> Union[torch.Tensor, float]: y_true = y_true.reshape(1, -1) y_pred = y_pred.view(1, -1) y_true_positive_count = y_true.sum(dim=-1) - average_precisions = torch.zeros_like(y_true_positive_count, device=self._device, dtype=torch.double) + average_precisions = torch.zeros_like(y_true_positive_count, device=self._device, dtype=fp_precision) for cls in range(y_true_positive_count.size(0)): recall, precision = self._compute_recall_and_precision(y_true[cls], y_pred[cls], y_true_positive_count[cls]) average_precisions[cls] = self._compute_average_precision(recall, precision) diff --git a/ignite/metrics/vision/object_detection_average_precision_recall.py b/ignite/metrics/vision/object_detection_average_precision_recall.py index 0ae07e924a5..f6111af1d1e 100644 --- a/ignite/metrics/vision/object_detection_average_precision_recall.py +++ b/ignite/metrics/vision/object_detection_average_precision_recall.py @@ -104,13 +104,19 @@ def box_iou(pred_boxes: torch.Tensor, gt_boxes: torch.Tensor, iscrowd: torch.Boo except ImportError: raise ModuleNotFoundError("This metric requires torchvision to be installed.") + precision = torch.double if not torch.device(device) != torch.device("mps") else torch.float32 + if iou_thresholds is None: - iou_thresholds = torch.linspace(0.5, 0.95, 10, dtype=torch.double) + iou_thresholds = torch.linspace(0.5, 0.95, 10, device=device, dtype=precision) 
self._iou_thresholds = self._setup_thresholds(iou_thresholds, "iou_thresholds") + self._iou_thresholds = self._iou_thresholds.to(device=device, dtype=precision) if rec_thresholds is None: - rec_thresholds = torch.linspace(0, 1, 101, device=device, dtype=torch.double) + rec_thresholds = torch.linspace(0, 1, 101, device=device, dtype=precision) + + self._rec_thresholds = self._setup_thresholds(rec_thresholds, "rec_thresholds") + self._rec_thresholds = self._rec_thresholds.to(device=device, dtype=precision) self._num_classes = num_classes self._area_range = area_range @@ -204,9 +210,14 @@ def _compute_recall_and_precision( """ indices = torch.argsort(scores, dim=-1, stable=True, descending=True) tp = TP[..., indices] - tp_summation = tp.cumsum(dim=-1).double() + tp_summation = tp.cumsum(dim=-1) + if tp_summation.device != torch.device("mps"): + tp_summation = tp_summation.double() + fp = FP[..., indices] - fp_summation = fp.cumsum(dim=-1).double() + fp_summation = fp.cumsum(dim=-1) + if fp_summation.device != torch.device("mps"): + fp_summation = fp_summation.double() recall = tp_summation / y_true_count predicted_positive = tp_summation + fp_summation @@ -342,12 +353,13 @@ def _compute(self) -> torch.Tensor: pred_labels = _cat_and_agg_tensors(self._y_pred_labels, cast(Tuple[int], ()), torch.long, self._device) TP = _cat_and_agg_tensors(self._tps, (len(self._iou_thresholds),), torch.uint8, self._device) FP = _cat_and_agg_tensors(self._fps, (len(self._iou_thresholds),), torch.uint8, self._device) - scores = _cat_and_agg_tensors(self._scores, cast(Tuple[int], ()), torch.double, self._device) + fp_precision = torch.double if self._device != torch.device("mps") else torch.float32 + scores = _cat_and_agg_tensors(self._scores, cast(Tuple[int], ()), fp_precision, self._device) average_precisions_recalls = -torch.ones( (2, self._num_classes, len(self._iou_thresholds)), device=self._device, - dtype=torch.double, + dtype=fp_precision, ) for cls in range(self._num_classes): if self._y_true_count[cls] == 0: From 8bfb8028603e8503845c621e53adce218b7b81bb Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Wed, 4 Sep 2024 03:32:03 +0330 Subject: [PATCH 30/41] Fix a typo --- .../metrics/vision/object_detection_average_precision_recall.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ignite/metrics/vision/object_detection_average_precision_recall.py b/ignite/metrics/vision/object_detection_average_precision_recall.py index f6111af1d1e..6676de64c51 100644 --- a/ignite/metrics/vision/object_detection_average_precision_recall.py +++ b/ignite/metrics/vision/object_detection_average_precision_recall.py @@ -104,7 +104,7 @@ def box_iou(pred_boxes: torch.Tensor, gt_boxes: torch.Tensor, iscrowd: torch.Boo except ImportError: raise ModuleNotFoundError("This metric requires torchvision to be installed.") - precision = torch.double if not torch.device(device) != torch.device("mps") else torch.float32 + precision = torch.double if torch.device(device) != torch.device("mps") else torch.float32 if iou_thresholds is None: iou_thresholds = torch.linspace(0.5, 0.95, 10, device=device, dtype=precision) From 4038c2bbb103ef586d010b4306679978c25f4f0c Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Wed, 4 Sep 2024 04:02:25 +0330 Subject: [PATCH 31/41] Fix a bug related to tensors on same devices --- .../metrics/vision/object_detection_average_precision_recall.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ignite/metrics/vision/object_detection_average_precision_recall.py 
b/ignite/metrics/vision/object_detection_average_precision_recall.py index 6676de64c51..fc38cfd8d5f 100644 --- a/ignite/metrics/vision/object_detection_average_precision_recall.py +++ b/ignite/metrics/vision/object_detection_average_precision_recall.py @@ -317,7 +317,7 @@ def update(self, output: Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, tor device=self._device, ) self._tps.append(tp) - self._fps.append(~tp & self._match_area_range(pred_boxes)) + self._fps.append(~tp & self._match_area_range(pred_boxes).to(self._device)) else: ious = self.box_iou(pred_boxes, gt_boxes, cast(torch.BoolTensor, gt_is_crowd)) category_no_match = labels.expand(len(pred_labels), -1) != pred_labels.view(-1, 1) From 4b6afdd4fa3b8b74dbc5156ac7a3978ee6f099d1 Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Wed, 4 Sep 2024 04:22:02 +0330 Subject: [PATCH 32/41] Fix a bug related to MPS and torch.double --- .../vision/object_detection_average_precision_recall.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ignite/metrics/vision/object_detection_average_precision_recall.py b/ignite/metrics/vision/object_detection_average_precision_recall.py index fc38cfd8d5f..47a06dca09f 100644 --- a/ignite/metrics/vision/object_detection_average_precision_recall.py +++ b/ignite/metrics/vision/object_detection_average_precision_recall.py @@ -345,7 +345,10 @@ def update(self, output: Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, tor ) ) - self._scores.append(pred["scores"][max_best_detections_index].to(self._device)) + scores = pred["scores"][max_best_detections_index] + if self._device == torch.device("mps") and scores.dtype == torch.double: + scores = scores.to(dtype=torch.float32) + self._scores.append(scores.to(self._device)) self._y_pred_labels.append(pred_labels.to(device=self._device)) @sync_all_reduce("_y_true_count") From d0e82b3eb921a5c8181a41aaed2960cd859f6754 Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Wed, 4 Sep 2024 04:51:40 +0330 Subject: [PATCH 33/41] Fix a bug related to MPS --- .../vision/object_detection_average_precision_recall.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ignite/metrics/vision/object_detection_average_precision_recall.py b/ignite/metrics/vision/object_detection_average_precision_recall.py index 47a06dca09f..c6f02545eb0 100644 --- a/ignite/metrics/vision/object_detection_average_precision_recall.py +++ b/ignite/metrics/vision/object_detection_average_precision_recall.py @@ -349,11 +349,11 @@ def update(self, output: Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, tor if self._device == torch.device("mps") and scores.dtype == torch.double: scores = scores.to(dtype=torch.float32) self._scores.append(scores.to(self._device)) - self._y_pred_labels.append(pred_labels.to(device=self._device)) + self._y_pred_labels.append(pred_labels.to(dtype=torch.int, device=self._device)) @sync_all_reduce("_y_true_count") def _compute(self) -> torch.Tensor: - pred_labels = _cat_and_agg_tensors(self._y_pred_labels, cast(Tuple[int], ()), torch.long, self._device) + pred_labels = _cat_and_agg_tensors(self._y_pred_labels, cast(Tuple[int], ()), torch.int, self._device) TP = _cat_and_agg_tensors(self._tps, (len(self._iou_thresholds),), torch.uint8, self._device) FP = _cat_and_agg_tensors(self._fps, (len(self._iou_thresholds),), torch.uint8, self._device) fp_precision = torch.double if self._device != torch.device("mps") else torch.float32 From 085e0df622d510099d93768918719a1541b07cff Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Wed, 4 Sep 
2024 05:31:35 +0330 Subject: [PATCH 34/41] Fix a bug related to MPS --- .../object_detection_average_precision_recall.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/ignite/metrics/vision/object_detection_average_precision_recall.py b/ignite/metrics/vision/object_detection_average_precision_recall.py index c6f02545eb0..595cef6eb08 100644 --- a/ignite/metrics/vision/object_detection_average_precision_recall.py +++ b/ignite/metrics/vision/object_detection_average_precision_recall.py @@ -104,19 +104,13 @@ def box_iou(pred_boxes: torch.Tensor, gt_boxes: torch.Tensor, iscrowd: torch.Boo except ImportError: raise ModuleNotFoundError("This metric requires torchvision to be installed.") - precision = torch.double if torch.device(device) != torch.device("mps") else torch.float32 - if iou_thresholds is None: - iou_thresholds = torch.linspace(0.5, 0.95, 10, device=device, dtype=precision) + iou_thresholds = torch.linspace(0.5, 0.95, 10, dtype=torch.double) self._iou_thresholds = self._setup_thresholds(iou_thresholds, "iou_thresholds") - self._iou_thresholds = self._iou_thresholds.to(device=device, dtype=precision) if rec_thresholds is None: - rec_thresholds = torch.linspace(0, 1, 101, device=device, dtype=precision) - - self._rec_thresholds = self._setup_thresholds(rec_thresholds, "rec_thresholds") - self._rec_thresholds = self._rec_thresholds.to(device=device, dtype=precision) + rec_thresholds = torch.linspace(0, 1, 101, dtype=torch.double) self._num_classes = num_classes self._area_range = area_range @@ -130,6 +124,8 @@ def box_iou(pred_boxes: torch.Tensor, gt_boxes: torch.Tensor, iscrowd: torch.Boo rec_thresholds=rec_thresholds, class_mean=None, ) + precision = torch.double if torch.device(device) != torch.device("mps") else torch.float32 + self.rec_thresholds = self.rec_thresholds.to(device=device, dtype=precision) @reinit__is_reduced def reset(self) -> None: From 3658f959e8bacf0e7aafaf4ee20baa788a7f455b Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Wed, 4 Sep 2024 05:46:12 +0330 Subject: [PATCH 35/41] Fix a bug related to MPS --- .../object_detection_average_precision_recall.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ignite/metrics/vision/object_detection_average_precision_recall.py b/ignite/metrics/vision/object_detection_average_precision_recall.py index 595cef6eb08..a56a18c164f 100644 --- a/ignite/metrics/vision/object_detection_average_precision_recall.py +++ b/ignite/metrics/vision/object_detection_average_precision_recall.py @@ -124,7 +124,7 @@ def box_iou(pred_boxes: torch.Tensor, gt_boxes: torch.Tensor, iscrowd: torch.Boo rec_thresholds=rec_thresholds, class_mean=None, ) - precision = torch.double if torch.device(device) != torch.device("mps") else torch.float32 + precision = torch.double if torch.device(device).type != "mps" else torch.float32 self.rec_thresholds = self.rec_thresholds.to(device=device, dtype=precision) @reinit__is_reduced @@ -207,12 +207,12 @@ def _compute_recall_and_precision( indices = torch.argsort(scores, dim=-1, stable=True, descending=True) tp = TP[..., indices] tp_summation = tp.cumsum(dim=-1) - if tp_summation.device != torch.device("mps"): + if tp_summation.device.type != "mps": tp_summation = tp_summation.double() fp = FP[..., indices] fp_summation = fp.cumsum(dim=-1) - if fp_summation.device != torch.device("mps"): + if fp_summation.device.type != "mps": fp_summation = fp_summation.double() recall = tp_summation / y_true_count @@ -342,7 +342,7 @@ def update(self, output: 
Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, tor ) scores = pred["scores"][max_best_detections_index] - if self._device == torch.device("mps") and scores.dtype == torch.double: + if self._device.type == "mps" and scores.dtype == torch.double: scores = scores.to(dtype=torch.float32) self._scores.append(scores.to(self._device)) self._y_pred_labels.append(pred_labels.to(dtype=torch.int, device=self._device)) @@ -352,7 +352,7 @@ def _compute(self) -> torch.Tensor: pred_labels = _cat_and_agg_tensors(self._y_pred_labels, cast(Tuple[int], ()), torch.int, self._device) TP = _cat_and_agg_tensors(self._tps, (len(self._iou_thresholds),), torch.uint8, self._device) FP = _cat_and_agg_tensors(self._fps, (len(self._iou_thresholds),), torch.uint8, self._device) - fp_precision = torch.double if self._device != torch.device("mps") else torch.float32 + fp_precision = torch.double if self._device.type != "mps" else torch.float32 scores = _cat_and_agg_tensors(self._scores, cast(Tuple[int], ()), fp_precision, self._device) average_precisions_recalls = -torch.ones( From 0444933caad1014495a182fcd3cd251eeff03e11 Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Wed, 4 Sep 2024 13:15:36 +0330 Subject: [PATCH 36/41] Resolve MPS's lack of cummax --- .../vision/object_detection_average_precision_recall.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/ignite/metrics/vision/object_detection_average_precision_recall.py b/ignite/metrics/vision/object_detection_average_precision_recall.py index a56a18c164f..3a7ae52f6d8 100644 --- a/ignite/metrics/vision/object_detection_average_precision_recall.py +++ b/ignite/metrics/vision/object_detection_average_precision_recall.py @@ -1,3 +1,4 @@ +import os from typing import Callable, cast, Dict, List, Optional, Sequence, Tuple, Union import torch @@ -125,7 +126,7 @@ def box_iou(pred_boxes: torch.Tensor, gt_boxes: torch.Tensor, iscrowd: torch.Boo class_mean=None, ) precision = torch.double if torch.device(device).type != "mps" else torch.float32 - self.rec_thresholds = self.rec_thresholds.to(device=device, dtype=precision) + self.rec_thresholds = cast(torch.Tensor, self.rec_thresholds).to(device=device, dtype=precision) @reinit__is_reduced def reset(self) -> None: @@ -234,7 +235,10 @@ def _compute_average_precision(self, recall: torch.Tensor, precision: torch.Tens Returns: average_precision: (n-1)-dimensional tensor containing the average precision for mean dimensions. 
""" + mps_cpu_fallback = os.environ.get("PYTORCH_ENABLE_MPS_FALLBACK", "0") + os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" precision_integrand = precision.flip(-1).cummax(dim=-1).values.flip(-1) + os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = mps_cpu_fallback rec_thresholds = cast(torch.Tensor, self.rec_thresholds).repeat((*recall.shape[:-1], 1)) rec_thresh_indices = ( torch.searchsorted(recall, rec_thresholds) From c4337187f2cf8a399901a333e69aa2f2edcf8698 Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Wed, 4 Sep 2024 13:58:23 +0330 Subject: [PATCH 37/41] Revert MPS fallback --- docs/source/metrics.rst | 1 + ignite/metrics/__init__.py | 2 ++ .../object_detection_average_precision_recall.py | 10 +++------- .../ignite/metrics/vision/test_object_detection_map.py | 4 ++-- 4 files changed, 8 insertions(+), 9 deletions(-) diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index a620ce15c64..4f643c4b234 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -342,6 +342,7 @@ Complete list of metrics MutualInformation ObjectDetectionAvgPrecisionRecall CommonObjDetectionMetrics + vision.object_detection_average_precision_recall.coco_tensor_list_to_dict_list precision.Precision PSNR recall.Recall diff --git a/ignite/metrics/__init__.py b/ignite/metrics/__init__.py index beac0b44e41..180db8326fb 100644 --- a/ignite/metrics/__init__.py +++ b/ignite/metrics/__init__.py @@ -39,6 +39,7 @@ from ignite.metrics.ssim import SSIM from ignite.metrics.top_k_categorical_accuracy import TopKCategoricalAccuracy from ignite.metrics.vision.object_detection_average_precision_recall import ( + coco_tensor_list_to_dict_list, CommonObjDetectionMetrics, ObjectDetectionAvgPrecisionRecall, ) @@ -94,4 +95,5 @@ "MeanAveragePrecision", "ObjectDetectionAvgPrecisionRecall", "CommonObjDetectionMetrics", + "coco_tensor_list_to_dict_list", ] diff --git a/ignite/metrics/vision/object_detection_average_precision_recall.py b/ignite/metrics/vision/object_detection_average_precision_recall.py index 3a7ae52f6d8..6c9e7b953d1 100644 --- a/ignite/metrics/vision/object_detection_average_precision_recall.py +++ b/ignite/metrics/vision/object_detection_average_precision_recall.py @@ -1,4 +1,3 @@ -import os from typing import Callable, cast, Dict, List, Optional, Sequence, Tuple, Union import torch @@ -10,7 +9,7 @@ from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce -def tensor_list_to_dict_list( +def coco_tensor_list_to_dict_list( output: Tuple[ Union[List[torch.Tensor], List[Dict[str, torch.Tensor]]], Union[List[torch.Tensor], List[Dict[str, torch.Tensor]]], @@ -83,8 +82,8 @@ def __init__( max_detections_per_image_per_class: maximum number of detections per class in each image to consider for evaluation. The most confident ones are selected. output_transform: a callable that is used to transform the :class:`~ignite.engine.engine.Engine`'s - ``process_function``'s output into the form expected by the metric. An already provided example - is :func:`~ignite.metrics.vision.object_detection_average_precision_recall.tensor_list_to_dict_list` + ``process_function``'s output into the form expected by the metric. An already provided example is + :func:`~ignite.metrics.vision.object_detection_average_precision_recall.coco_tensor_list_to_dict_list` which accepts `y_pred` and `y` as lists of tensors and transforms them to the expected format. Default is the identity function. device: specifies which device updates are accumulated on. 
Setting the @@ -235,10 +234,7 @@ def _compute_average_precision(self, recall: torch.Tensor, precision: torch.Tens Returns: average_precision: (n-1)-dimensional tensor containing the average precision for mean dimensions. """ - mps_cpu_fallback = os.environ.get("PYTORCH_ENABLE_MPS_FALLBACK", "0") - os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" precision_integrand = precision.flip(-1).cummax(dim=-1).values.flip(-1) - os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = mps_cpu_fallback rec_thresholds = cast(torch.Tensor, self.rec_thresholds).repeat((*recall.shape[:-1], 1)) rec_thresh_indices = ( torch.searchsorted(recall, rec_thresholds) diff --git a/tests/ignite/metrics/vision/test_object_detection_map.py b/tests/ignite/metrics/vision/test_object_detection_map.py index f39e3fdf743..94ce64da4fb 100644 --- a/tests/ignite/metrics/vision/test_object_detection_map.py +++ b/tests/ignite/metrics/vision/test_object_detection_map.py @@ -18,7 +18,7 @@ import ignite.distributed as idist from ignite.engine import Engine from ignite.metrics import CommonObjDetectionMetrics, ObjectDetectionAvgPrecisionRecall -from ignite.metrics.vision.object_detection_average_precision_recall import tensor_list_to_dict_list +from ignite.metrics.vision.object_detection_average_precision_recall import coco_tensor_list_to_dict_list from ignite.utils import manual_seed torch.set_printoptions(linewidth=200) @@ -957,7 +957,7 @@ def test_tensor_list_to_dict_list(): ] for y_pred in y_preds: for y in ys: - y_pred_new, y_new = tensor_list_to_dict_list((y_pred, y)) + y_pred_new, y_new = coco_tensor_list_to_dict_list((y_pred, y)) if isinstance(y_pred[0], dict): assert y_pred_new is y_pred else: From dacf4072b5fff039f07d7d81c35fc04c54742c90 Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Wed, 4 Sep 2024 18:27:07 +0330 Subject: [PATCH 38/41] Apply comments --- docs/source/metrics.rst | 2 +- ignite/metrics/__init__.py | 4 ++-- .../object_detection_average_precision_recall.py | 14 +++++++++++--- .../metrics/vision/test_object_detection_map.py | 6 +++--- 4 files changed, 17 insertions(+), 9 deletions(-) diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 4f643c4b234..1bdddb8a4a6 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -341,7 +341,7 @@ Complete list of metrics MultiLabelConfusionMatrix MutualInformation ObjectDetectionAvgPrecisionRecall - CommonObjDetectionMetrics + CommonObjectDetectionMetrics vision.object_detection_average_precision_recall.coco_tensor_list_to_dict_list precision.Precision PSNR diff --git a/ignite/metrics/__init__.py b/ignite/metrics/__init__.py index 180db8326fb..26c105d4ed7 100644 --- a/ignite/metrics/__init__.py +++ b/ignite/metrics/__init__.py @@ -40,7 +40,7 @@ from ignite.metrics.top_k_categorical_accuracy import TopKCategoricalAccuracy from ignite.metrics.vision.object_detection_average_precision_recall import ( coco_tensor_list_to_dict_list, - CommonObjDetectionMetrics, + CommonObjectDetectionMetrics, ObjectDetectionAvgPrecisionRecall, ) @@ -94,6 +94,6 @@ "ROC_AUC", "MeanAveragePrecision", "ObjectDetectionAvgPrecisionRecall", - "CommonObjDetectionMetrics", + "CommonObjectDetectionMetrics", "coco_tensor_list_to_dict_list", ] diff --git a/ignite/metrics/vision/object_detection_average_precision_recall.py b/ignite/metrics/vision/object_detection_average_precision_recall.py index 6c9e7b953d1..2d8403c1c27 100644 --- a/ignite/metrics/vision/object_detection_average_precision_recall.py +++ b/ignite/metrics/vision/object_detection_average_precision_recall.py @@ -234,7 +234,15 
@@ def _compute_average_precision(self, recall: torch.Tensor, precision: torch.Tens Returns: average_precision: (n-1)-dimensional tensor containing the average precision for mean dimensions. """ - precision_integrand = precision.flip(-1).cummax(dim=-1).values.flip(-1) + if precision.device.type == "mps": + # Manual fallback to CPU if precision is on MPS due to the error: + # NotImplementedError: The operator 'aten::_cummax_helper' is not currently implemented for the MPS device + device = precision.device + precision_integrand = precision.flip(-1).cpu() + precision_integrand = precision_integrand.cummax(dim=-1).values + precision_integrand = precision_integrand.to(device=device).flip(-1) + else: + precision_integrand = precision.flip(-1).cummax(dim=-1).values.flip(-1) rec_thresholds = cast(torch.Tensor, self.rec_thresholds).repeat((*recall.shape[:-1], 1)) rec_thresh_indices = ( torch.searchsorted(recall, rec_thresholds) @@ -386,9 +394,9 @@ def compute(self) -> Tuple[float, float]: return ap, ar -class CommonObjDetectionMetrics(MetricGroup): +class CommonObjectDetectionMetrics(MetricGroup): """ - Common Object detection metrics. Included metrics are as follows: + Common Object Detection metrics. Included metrics are as follows: =============== ========================================== **Metric name** **Description** diff --git a/tests/ignite/metrics/vision/test_object_detection_map.py b/tests/ignite/metrics/vision/test_object_detection_map.py index 94ce64da4fb..4141ed73981 100644 --- a/tests/ignite/metrics/vision/test_object_detection_map.py +++ b/tests/ignite/metrics/vision/test_object_detection_map.py @@ -17,7 +17,7 @@ import ignite.distributed as idist from ignite.engine import Engine -from ignite.metrics import CommonObjDetectionMetrics, ObjectDetectionAvgPrecisionRecall +from ignite.metrics import CommonObjectDetectionMetrics, ObjectDetectionAvgPrecisionRecall from ignite.metrics.vision.object_detection_average_precision_recall import coco_tensor_list_to_dict_list from ignite.utils import manual_seed @@ -895,7 +895,7 @@ def test_compute(sample): print(all_res) assert np.allclose(all_res, sample.mAP) - common_metrics = CommonObjDetectionMetrics(device=device) + common_metrics = CommonObjectDetectionMetrics(device=device) common_metrics.update(sample.data) res = common_metrics.compute() common_metrics_res = [ @@ -1021,7 +1021,7 @@ def test_distrib_update_compute(distributed, sample): all_res = [AP_50_95, AP_50, AP_75, AP_S, AP_M, AP_L, AR_1, AR_10, AR_100, AR_S, AR_M, AR_L] assert np.allclose(all_res, sample.mAP) - common_metrics = CommonObjDetectionMetrics(device=device) + common_metrics = CommonObjectDetectionMetrics(device=device) common_metrics.update((y_pred_rank, y_rank)) res = common_metrics.compute() common_metrics_res = [ From 67e38c406637b2e747a27f70e8772a486d87ef98 Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Fri, 6 Sep 2024 00:29:09 +0330 Subject: [PATCH 39/41] Revert unnecessary changes --- tests/ignite/distributed/utils/__init__.py | 20 ++------------------ 1 file changed, 2 insertions(+), 18 deletions(-) diff --git a/tests/ignite/distributed/utils/__init__.py b/tests/ignite/distributed/utils/__init__.py index 4b33bf0ec92..1f3ad55dd84 100644 --- a/tests/ignite/distributed/utils/__init__.py +++ b/tests/ignite/distributed/utils/__init__.py @@ -6,8 +6,6 @@ from ignite.distributed.utils import all_gather_tensors_with_shapes, sync from ignite.engine import Engine, Events -torch.manual_seed(41) - def _sanity_check(): from ignite.distributed.utils import _model @@ -195,11 
+193,6 @@ def _test_distrib_all_gather(device): true_res[i * 4 : (i + 1) * 4, ...] = torch.arange(100, device=device).reshape(4, 25) * (i + 1) assert (res == true_res).all() - ts = [torch.randn((r + 1, r + 2, r + 3), device=device) for r in range(ws)] - ts_gathered = all_gather_tensors_with_shapes(ts[rank], [list(t.shape) for t in ts]) - for t, t_gathered in zip(ts, ts_gathered): - assert (t == t_gathered).all() - if ws > 1 and idist.backend() != "xla-tpu": t = { "a": [rank + 1, rank + 2, torch.tensor(rank + 3, device=device)], @@ -226,7 +219,7 @@ def _test_distrib_all_gather(device): def _test_distrib_all_gather_group(device): if idist.get_world_size() > 1: - ranks = list(range(1, idist.get_world_size())) + ranks = list(range(idist.get_world_size() - 1, 0, -1)) # [0, 1, 2, 3] -> [3, 2, 1] rank = idist.get_rank() bnd = idist.backend() @@ -253,15 +246,6 @@ def _test_distrib_all_gather_group(device): else: assert res == t - ts = [torch.randn((i + 1, i + 2, i + 3), device=device) for i in range(idist.get_world_size())] - shapes = [list(t.shape) for r, t in enumerate(ts) if r in ranks] - ts_gathered = all_gather_tensors_with_shapes(ts[rank], shapes, ranks) - if rank in ranks: - for i, r in enumerate(ranks): - assert (ts[r] == ts_gathered[i]).all() - else: - assert ts_gathered == [ts[rank]] - t = { "a": [rank + 1, rank + 2, torch.tensor(rank + 3, device=device)], "b": torch.tensor([[rank + 1, rank + 2, rank + 3]], device=device), @@ -437,7 +421,7 @@ def _test_distrib_new_group(device): if rank in ranks: assert g1.rank() == g2.rank() elif idist.has_xla_support and bnd in ("xla-tpu"): - assert idist.new_group(ranks) == ranks + assert idist.new_group(ranks) == [ranks] elif idist.has_hvd_support and bnd in ("horovod"): from horovod.common.process_sets import ProcessSet From 7b43c69134b601c8bb18916f17c76549430151f5 Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Fri, 20 Sep 2024 21:22:42 +0330 Subject: [PATCH 40/41] Apply review comments --- ignite/metrics/mean_average_precision.py | 9 +++ ignite/metrics/metric_group.py | 14 ++++- ...ject_detection_average_precision_recall.py | 33 ++++++++--- .../vision/test_object_detection_map.py | 55 +++++++++++-------- 4 files changed, 78 insertions(+), 33 deletions(-) diff --git a/ignite/metrics/mean_average_precision.py b/ignite/metrics/mean_average_precision.py index 03f84cc8c6c..d82505b5446 100644 --- a/ignite/metrics/mean_average_precision.py +++ b/ignite/metrics/mean_average_precision.py @@ -149,6 +149,7 @@ def __init__( is_multilabel: bool = False, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu"), + skip_unrolling: bool = False, ) -> None: r"""Calculate the mean average precision metric i.e. mean of the averaged-over-recall precision for classification task: @@ -210,12 +211,20 @@ def __init__( device: specifies which device updates are accumulated on. Setting the metric's device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By default, CPU. + skip_unrolling: specifies whether output should be unrolled before being fed to update method. Should be + true for multi-output model, for example, if ``y_pred`` and ``y`` contain multi-ouput as + ``(y_pred_a, y_pred_b)`` and ``(y_a, y_b)``, in which case the update method is called for + ``(y_pred_a, y_a)`` and ``(y_pred_b, y_b)``.Alternatively, ``output_transform`` can be used to handle + this. + + .. 
versionadded:: 0.5.2 """ super(MeanAveragePrecision, self).__init__( output_transform=output_transform, is_multilabel=is_multilabel, device=device, + skip_unrolling=skip_unrolling, ) super(Metric, self).__init__(rec_thresholds=rec_thresholds, class_mean=class_mean) diff --git a/ignite/metrics/metric_group.py b/ignite/metrics/metric_group.py index f6c21a604af..bd4ffbc77ef 100644 --- a/ignite/metrics/metric_group.py +++ b/ignite/metrics/metric_group.py @@ -15,6 +15,11 @@ class MetricGroup(Metric): :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the form expected by the metric. `output_transform` of each metric in the group is also called upon its update. + skip_unrolling: specifies whether output should be unrolled before being fed to update method. Should be + true for multi-output model, for example, if ``y_pred`` and ``y`` contain multi-ouput as + ``(y_pred_a, y_pred_b)`` and ``(y_a, y_b)``, in which case the update method is called for + ``(y_pred_a, y_a)`` and ``(y_pred_b, y_b)``.Alternatively, ``output_transform`` can be used to handle + this. Examples: We construct a group of metrics, attach them to the engine at once and retrieve their result. @@ -34,13 +39,18 @@ class MetricGroup(Metric): # And also altogether state.metrics["eval_metrics"] + + .. versionchanged:: 0.5.2 + ``skip_unrolling`` argument is added. """ _state_dict_all_req_keys: Tuple[str, ...] = ("metrics",) - def __init__(self, metrics: Dict[str, Metric], output_transform: Callable = lambda x: x): + def __init__( + self, metrics: Dict[str, Metric], output_transform: Callable = lambda x: x, skip_unrolling: bool = False + ): self.metrics = metrics - super(MetricGroup, self).__init__(output_transform=output_transform) + super(MetricGroup, self).__init__(output_transform=output_transform, skip_unrolling=skip_unrolling) def reset(self) -> None: for m in self.metrics.values(): diff --git a/ignite/metrics/vision/object_detection_average_precision_recall.py b/ignite/metrics/vision/object_detection_average_precision_recall.py index 2d8403c1c27..881c9ccc5ed 100644 --- a/ignite/metrics/vision/object_detection_average_precision_recall.py +++ b/ignite/metrics/vision/object_detection_average_precision_recall.py @@ -52,11 +52,12 @@ def __init__( self, iou_thresholds: Optional[Union[Sequence[float], torch.Tensor]] = None, rec_thresholds: Optional[Union[Sequence[float], torch.Tensor]] = None, - num_classes: int = 91, + num_classes: int = 80, max_detections_per_image_per_class: int = 100, area_range: Literal["small", "medium", "large", "all"] = "all", output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu"), + skip_unrolling: bool = False, ) -> None: r"""Calculate mean average precision & recall for evaluating an object detector in the COCO way. @@ -77,7 +78,7 @@ def __init__( Values should be between 0 and 1. If not given, COCO's default (.5, .55, ..., .95) would be used. rec_thresholds: sequence of recall thresholds to be considered for computing mean average precision. Values should be between 0 and 1. If not given, COCO's default (.0, .01, .02, ..., 1.) would be used. - num_classes: number of categories. Default is 91, that of the COCO. + num_classes: number of categories. Default is 80, that of the COCO dataset. area_range: area range which only objects therein are considered in evaluation. By default, 'all'. max_detections_per_image_per_class: maximum number of detections per class in each image to consider for evaluation. The most confident ones are selected. 
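For reference while reading the parameter documentation above, here is a minimal usage sketch of the metric added by this patch. It is illustrative only and not part of the diff: apart from "scores", which appears explicitly in this file, the prediction/target dictionary key names ("bbox", "labels", "iscrowd") are assumptions, and the box layout (x1, y1, x2, y2) follows the test fixtures in this PR.

    import torch
    from ignite.metrics import ObjectDetectionAvgPrecisionRecall

    metric = ObjectDetectionAvgPrecisionRecall(num_classes=80, device="cpu")

    # Single image; only the "scores" key is visible in this diff, the other
    # key names are assumed for illustration.
    y_pred = [{
        "bbox": torch.tensor([[10.0, 10.0, 50.0, 60.0]]),
        "scores": torch.tensor([0.9]),
        "labels": torch.tensor([3]),
    }]
    y = [{
        "bbox": torch.tensor([[12.0, 11.0, 49.0, 58.0]]),
        "labels": torch.tensor([3]),
        "iscrowd": torch.tensor([0]),
    }]

    metric.update((y_pred, y))
    ap, ar = metric.compute()  # AP@50..95 and AR-100, per compute() in this file
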
@@ -89,6 +90,13 @@ def __init__( device: specifies which device updates are accumulated on. Setting the metric's device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By default, CPU. + skip_unrolling: specifies whether output should be unrolled before being fed to update method. Should be + true for multi-output model, for example, if ``y_pred`` and ``y`` contain multi-ouput as + ``(y_pred_a, y_pred_b)`` and ``(y_a, y_b)``, in which case the update method is called for + ``(y_pred_a, y_a)`` and ``(y_pred_b, y_b)``.Alternatively, ``output_transform`` can be used to handle + this. + + .. versionadded:: 0.5.2 """ try: from torchvision.ops.boxes import _box_inter_union, box_area @@ -119,6 +127,7 @@ def box_iou(pred_boxes: torch.Tensor, gt_boxes: torch.Tensor, iscrowd: torch.Boo super(ObjectDetectionAvgPrecisionRecall, self).__init__( output_transform=output_transform, device=device, + skip_unrolling=skip_unrolling, ) super(Metric, self).__init__( rec_thresholds=rec_thresholds, @@ -420,6 +429,7 @@ class CommonObjectDetectionMetrics(MetricGroup): AR-10 Average recall with maximum 10 detections =============== ========================================== + .. versionadded:: 0.5.2 """ _state_dict_all_req_keys = ("metrics", "ap_50_95") @@ -428,20 +438,27 @@ class CommonObjectDetectionMetrics(MetricGroup): def __init__( self, + num_classes: int = 80, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu"), + skip_unrolling: bool = True, ): - self.ap_50_95 = ObjectDetectionAvgPrecisionRecall(device=device) + self.ap_50_95 = ObjectDetectionAvgPrecisionRecall(num_classes=num_classes, device=device) super().__init__( { - "S": ObjectDetectionAvgPrecisionRecall(device=device, area_range="small"), - "M": ObjectDetectionAvgPrecisionRecall(device=device, area_range="medium"), - "L": ObjectDetectionAvgPrecisionRecall(device=device, area_range="large"), - "1": ObjectDetectionAvgPrecisionRecall(device=device, max_detections_per_image_per_class=1), - "10": ObjectDetectionAvgPrecisionRecall(device=device, max_detections_per_image_per_class=10), + "S": ObjectDetectionAvgPrecisionRecall(num_classes=num_classes, device=device, area_range="small"), + "M": ObjectDetectionAvgPrecisionRecall(num_classes=num_classes, device=device, area_range="medium"), + "L": ObjectDetectionAvgPrecisionRecall(num_classes=num_classes, device=device, area_range="large"), + "1": ObjectDetectionAvgPrecisionRecall( + num_classes=num_classes, device=device, max_detections_per_image_per_class=1 + ), + "10": ObjectDetectionAvgPrecisionRecall( + num_classes=num_classes, device=device, max_detections_per_image_per_class=10 + ), }, output_transform, + skip_unrolling=skip_unrolling, ) def reset(self) -> None: diff --git a/tests/ignite/metrics/vision/test_object_detection_map.py b/tests/ignite/metrics/vision/test_object_detection_map.py index 4141ed73981..de41d317e56 100644 --- a/tests/ignite/metrics/vision/test_object_detection_map.py +++ b/tests/ignite/metrics/vision/test_object_detection_map.py @@ -423,7 +423,7 @@ def random_sample() -> Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, torch h = 640 * torch.rand((n_gt_box, 1)) x2 = (x1 + w).clip(max=640) y2 = (y1 + h).clip(max=640) - category = torch.randint(91, (n_gt_box, 1)) + category = torch.randint(0, 80, (n_gt_box, 1)) iscrowd = torch.randint(2, (n_gt_box, 1)) targets.append(torch.cat((x1, y1, x2, y2, category, iscrowd), dim=1)) @@ -449,7 +449,7 @@ def random_sample() -> Tuple[List[Dict[str, 
torch.Tensor]], List[Dict[str, torch h = (h + perturb_h).clip(min=0, max=640) x2 = (x1 + w).clip(max=640) y2 = (y1 + h).clip(max=640) - category = (category + perturb_category) % 100 + category = (category + perturb_category) % 80 confidence = torch.rand_like(category, dtype=torch.double) perturbed_gt_boxes = torch.cat((x1, y1, x2, y2, confidence, category), dim=1) @@ -461,7 +461,7 @@ def random_sample() -> Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, torch h = 640 * torch.rand((n_additional_pred_boxes, 1)) x2 = (x1 + w).clip(max=640) y2 = (y1 + h).clip(max=640) - category = torch.randint(100, (n_additional_pred_boxes, 1)) + category = torch.randint(0, 80, (n_additional_pred_boxes, 1)) confidence = torch.rand_like(category, dtype=torch.double) additional_pred_boxes = torch.cat((x1, y1, x2, y2, confidence, category), dim=1) @@ -690,7 +690,7 @@ def test_no_torchvision(): def test_iou(sample): - m = ObjectDetectionAvgPrecisionRecall() + m = ObjectDetectionAvgPrecisionRecall(num_classes=91) from pycocotools.mask import iou as pycoco_iou for pred, tgt in zip(*sample.data): @@ -864,15 +864,18 @@ def test__compute_recall_and_precision(): def test_compute(sample): device = idist.device() + if device == torch.device("mps"): + pytest.skip("Due to MPS backend out of memory") + # AP@.5...95, AP@.5, AP@.75, AP-S, AP-M, AP-L, AR-1, AR-10, AR-100, AR-S, AR-M, AR-L - ap_50_95_ar_100 = ObjectDetectionAvgPrecisionRecall(device=device) - ap_50 = ObjectDetectionAvgPrecisionRecall(iou_thresholds=[0.5], device=device) - ap_75 = ObjectDetectionAvgPrecisionRecall(iou_thresholds=[0.75], device=device) - ap_ar_S = ObjectDetectionAvgPrecisionRecall(device=device, area_range="small") - ap_ar_M = ObjectDetectionAvgPrecisionRecall(device=device, area_range="medium") - ap_ar_L = ObjectDetectionAvgPrecisionRecall(device=device, area_range="large") - ar_1 = ObjectDetectionAvgPrecisionRecall(device=device, max_detections_per_image_per_class=1) - ar_10 = ObjectDetectionAvgPrecisionRecall(device=device, max_detections_per_image_per_class=10) + ap_50_95_ar_100 = ObjectDetectionAvgPrecisionRecall(num_classes=91, device=device) + ap_50 = ObjectDetectionAvgPrecisionRecall(num_classes=91, iou_thresholds=[0.5], device=device) + ap_75 = ObjectDetectionAvgPrecisionRecall(num_classes=91, iou_thresholds=[0.75], device=device) + ap_ar_S = ObjectDetectionAvgPrecisionRecall(num_classes=91, device=device, area_range="small") + ap_ar_M = ObjectDetectionAvgPrecisionRecall(num_classes=91, device=device, area_range="medium") + ap_ar_L = ObjectDetectionAvgPrecisionRecall(num_classes=91, device=device, area_range="large") + ar_1 = ObjectDetectionAvgPrecisionRecall(num_classes=91, device=device, max_detections_per_image_per_class=1) + ar_10 = ObjectDetectionAvgPrecisionRecall(num_classes=91, device=device, max_detections_per_image_per_class=10) metrics = [ap_50_95_ar_100, ap_50, ap_75, ap_ar_S, ap_ar_M, ap_ar_L, ar_1, ar_10] for metric in metrics: @@ -895,7 +898,7 @@ def test_compute(sample): print(all_res) assert np.allclose(all_res, sample.mAP) - common_metrics = CommonObjectDetectionMetrics(device=device) + common_metrics = CommonObjectDetectionMetrics(num_classes=91, device=device) common_metrics.update(sample.data) res = common_metrics.compute() common_metrics_res = [ @@ -928,7 +931,7 @@ def update(engine, i): device = idist.device() metric_device = "cpu" if device.type == "xla" else device - metric_50_95 = ObjectDetectionAvgPrecisionRecall(device=metric_device) + metric_50_95 = ObjectDetectionAvgPrecisionRecall(num_classes=91, 
device=metric_device) metric_50_95.attach(engine, name="mAP[50-95]") n_iter = ceil(sample.length / bs) @@ -988,16 +991,22 @@ def test_distrib_update_compute(distributed, sample): rank_samples_range = slice(rank_samples_cnt * rank, rank_samples_cnt * (rank + 1)) device = idist.device() + + if device == torch.device("mps"): + pytest.skip("Due to MPS backend out of memory") + metric_device = "cpu" if device.type == "xla" else device # AP@.5...95, AP@.5, AP@.75, AP-S, AP-M, AP-L, AR-1, AR-10, AR-100, AR-S, AR-M, AR-L - ap_50_95_ar_100 = ObjectDetectionAvgPrecisionRecall(device=metric_device) - ap_50 = ObjectDetectionAvgPrecisionRecall(iou_thresholds=[0.5], device=metric_device) - ap_75 = ObjectDetectionAvgPrecisionRecall(iou_thresholds=[0.75], device=metric_device) - ap_ar_S = ObjectDetectionAvgPrecisionRecall(device=metric_device, area_range="small") - ap_ar_M = ObjectDetectionAvgPrecisionRecall(device=metric_device, area_range="medium") - ap_ar_L = ObjectDetectionAvgPrecisionRecall(device=metric_device, area_range="large") - ar_1 = ObjectDetectionAvgPrecisionRecall(device=metric_device, max_detections_per_image_per_class=1) - ar_10 = ObjectDetectionAvgPrecisionRecall(device=metric_device, max_detections_per_image_per_class=10) + ap_50_95_ar_100 = ObjectDetectionAvgPrecisionRecall(num_classes=91, device=metric_device) + ap_50 = ObjectDetectionAvgPrecisionRecall(num_classes=91, iou_thresholds=[0.5], device=metric_device) + ap_75 = ObjectDetectionAvgPrecisionRecall(num_classes=91, iou_thresholds=[0.75], device=metric_device) + ap_ar_S = ObjectDetectionAvgPrecisionRecall(num_classes=91, device=metric_device, area_range="small") + ap_ar_M = ObjectDetectionAvgPrecisionRecall(num_classes=91, device=metric_device, area_range="medium") + ap_ar_L = ObjectDetectionAvgPrecisionRecall(num_classes=91, device=metric_device, area_range="large") + ar_1 = ObjectDetectionAvgPrecisionRecall(num_classes=91, device=metric_device, max_detections_per_image_per_class=1) + ar_10 = ObjectDetectionAvgPrecisionRecall( + num_classes=91, device=metric_device, max_detections_per_image_per_class=10 + ) metrics = [ap_50_95_ar_100, ap_50, ap_75, ap_ar_S, ap_ar_M, ap_ar_L, ar_1, ar_10] @@ -1021,7 +1030,7 @@ def test_distrib_update_compute(distributed, sample): all_res = [AP_50_95, AP_50, AP_75, AP_S, AP_M, AP_L, AR_1, AR_10, AR_100, AR_S, AR_M, AR_L] assert np.allclose(all_res, sample.mAP) - common_metrics = CommonObjectDetectionMetrics(device=device) + common_metrics = CommonObjectDetectionMetrics(num_classes=91, device=device) common_metrics.update((y_pred_rank, y_rank)) res = common_metrics.compute() common_metrics_res = [ From 954d1306c05b2d94a3993fec0e7ae33438129f40 Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Sun, 29 Sep 2024 23:45:14 +0330 Subject: [PATCH 41/41] Skip MPS on test_integraion as well --- tests/ignite/metrics/vision/test_object_detection_map.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/ignite/metrics/vision/test_object_detection_map.py b/tests/ignite/metrics/vision/test_object_detection_map.py index de41d317e56..712b2fdebdf 100644 --- a/tests/ignite/metrics/vision/test_object_detection_map.py +++ b/tests/ignite/metrics/vision/test_object_detection_map.py @@ -923,13 +923,16 @@ def test_compute(sample): def test_integration(sample): bs = 3 + device = idist.device() + if device == torch.device("mps"): + pytest.skip("Due to MPS backend out of memory") + def update(engine, i): b = slice(i * bs, (i + 1) * bs) return sample.data[0][b], sample.data[1][b] engine = 
Engine(update) - device = idist.device() metric_device = "cpu" if device.type == "xla" else device metric_50_95 = ObjectDetectionAvgPrecisionRecall(num_classes=91, device=metric_device) metric_50_95.attach(engine, name="mAP[50-95]")
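
Taken together, the last few patches make the common COCO-style metrics usable end to end: `CommonObjectDetectionMetrics` groups the per-range and per-detection-count variants, and `coco_tensor_list_to_dict_list` converts list-of-tensors output into the dictionary format the metrics expect. The sketch below is illustrative and not part of the patches; it assumes the model emits one tensor per image in the same layout as the test fixtures (predictions as rows of x1, y1, x2, y2, confidence, class; targets as rows of x1, y1, x2, y2, class, iscrowd).

    import torch
    from ignite.engine import Engine
    from ignite.metrics import CommonObjectDetectionMetrics, coco_tensor_list_to_dict_list

    def eval_step(engine, batch):
        # batch is (y_pred, y): one tensor per image, laid out as in the test
        # fixtures -- predictions (N, 6) and targets (M, 6) as described above.
        return batch

    evaluator = Engine(eval_step)
    metrics = CommonObjectDetectionMetrics(
        num_classes=80,
        output_transform=coco_tensor_list_to_dict_list,
        device="cpu",
    )
    metrics.attach(evaluator, "coco")

    # After evaluator.run(data), state.metrics["coco"] holds the dict returned by
    # compute() above, keyed by "AP@50..95", "AP@50", "AP@75", "AR-100", "AP-S",
    # "AR-S", "AP-M", "AR-M", "AP-L", "AR-L", "AR-1" and "AR-10".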