diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 29943e98343..64f30537321 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -332,6 +332,7 @@ Complete list of metrics Frequency Loss MeanAbsoluteError + MeanAveragePrecision MeanPairwiseDistance MeanSquaredError metric.Metric @@ -339,6 +340,9 @@ Complete list of metrics metrics_lambda.MetricsLambda MultiLabelConfusionMatrix MutualInformation + ObjectDetectionAvgPrecisionRecall + CommonObjectDetectionMetrics + vision.object_detection_average_precision_recall.coco_tensor_list_to_dict_list precision.Precision PSNR recall.Recall diff --git a/ignite/metrics/__init__.py b/ignite/metrics/__init__.py index 9f2c2303bc8..6d84ff0ccf4 100644 --- a/ignite/metrics/__init__.py +++ b/ignite/metrics/__init__.py @@ -20,6 +20,7 @@ from ignite.metrics.loss import Loss from ignite.metrics.maximum_mean_discrepancy import MaximumMeanDiscrepancy from ignite.metrics.mean_absolute_error import MeanAbsoluteError +from ignite.metrics.mean_average_precision import MeanAveragePrecision from ignite.metrics.mean_pairwise_distance import MeanPairwiseDistance from ignite.metrics.mean_squared_error import MeanSquaredError from ignite.metrics.metric import BatchFiltered, BatchWise, EpochWise, Metric, MetricUsage @@ -38,6 +39,11 @@ from ignite.metrics.running_average import RunningAverage from ignite.metrics.ssim import SSIM from ignite.metrics.top_k_categorical_accuracy import TopKCategoricalAccuracy +from ignite.metrics.vision.object_detection_average_precision_recall import ( + coco_tensor_list_to_dict_list, + CommonObjectDetectionMetrics, + ObjectDetectionAvgPrecisionRecall, +) __all__ = [ "Metric", @@ -88,4 +94,8 @@ "PrecisionRecallCurve", "RocCurve", "ROC_AUC", + "MeanAveragePrecision", + "ObjectDetectionAvgPrecisionRecall", + "CommonObjectDetectionMetrics", + "coco_tensor_list_to_dict_list", ] diff --git a/ignite/metrics/mean_average_precision.py b/ignite/metrics/mean_average_precision.py new file mode 100644 index 00000000000..d82505b5446 --- /dev/null +++ b/ignite/metrics/mean_average_precision.py @@ -0,0 +1,394 @@ +import warnings +from typing import Callable, cast, List, Optional, Sequence, Tuple, Union + +import torch +from typing_extensions import Literal + +import ignite.distributed as idist +from ignite.distributed.utils import all_gather_tensors_with_shapes +from ignite.metrics.metric import Metric, reinit__is_reduced +from ignite.metrics.precision import _BaseClassification +from ignite.utils import to_onehot + + +class _BaseAveragePrecision: + def __init__( + self, + rec_thresholds: Optional[Union[Sequence[float], torch.Tensor]] = None, + class_mean: Optional[Literal["micro", "macro", "weighted"]] = "macro", + ) -> None: + r"""Base class for Average Precision metric. + + This class contains the methods for setting up the thresholds and computing AP & AR. + + Args: + rec_thresholds: recall thresholds (sensivity levels) to be considered for computing Mean Average Precision. + It could be a 1-dim tensor or a sequence of floats. Its values should be between 0 and 1 and don't need + to be sorted. If missing, thresholds are considered automatically using the data. + class_mean: how to compute mean of the average precision across classes or incorporate class + dimension into computing precision. It's ignored in binary classification. Available options are + + None + An 1-dimensional tensor of mean (taken across additional mean dimensions) average precision per class + is returned. 
If there's no ground truth sample for a class, ``0`` is returned for that. + + 'micro' + Precision is computed counting stats of classes/labels altogether. This option + incorporates class in the very precision measurement. + + .. math:: + \text{Micro P} = \frac{\sum_{c=1}^C TP_c}{\sum_{c=1}^C TP_c+FP_c} + + where :math:`C` is the number of classes/labels. :math:`c` in :math:`TP_c` + and :math:`FP_c` means that the terms are computed for class/label :math:`c` (in a one-vs-rest + sense in multiclass case). + + For multiclass inputs, this is equivalent with mean average accuracy. + + 'weighted' + like macro but considers class/label imbalance. For multiclass input, + it computes AP for each class then returns mean of them weighted by + support of classes (number of actual samples in each class). For multilabel input, + it computes AP for each label then returns mean of them weighted by support + of labels (number of actual positive samples in each label). + + 'macro' + computes macro precision which is unweighted mean of AP computed across classes/labels. Default. + """ + if rec_thresholds is not None: + self.rec_thresholds: Optional[torch.Tensor] = self._setup_thresholds(rec_thresholds, "rec_thresholds") + else: + self.rec_thresholds = None + + if class_mean is not None and class_mean not in ("micro", "macro", "weighted"): + raise ValueError(f"Wrong `class_mean` parameter, given {class_mean}") + self.class_mean = class_mean + + def _setup_thresholds(self, thresholds: Union[Sequence[float], torch.Tensor], threshold_type: str) -> torch.Tensor: + if isinstance(thresholds, Sequence): + thresholds = torch.tensor(thresholds, dtype=torch.double) + + if isinstance(thresholds, torch.Tensor): + if thresholds.ndim != 1: + raise ValueError( + f"{threshold_type} should be a one-dimensional tensor or a sequence of floats" + f", given a {thresholds.ndim}-dimensional tensor." + ) + thresholds = thresholds.sort().values + else: + raise TypeError(f"{threshold_type} should be a sequence of floats or a tensor, given {type(thresholds)}.") + + if min(thresholds) < 0 or max(thresholds) > 1: + raise ValueError(f"{threshold_type} values should be between 0 and 1, given {thresholds}") + + return thresholds + + def _compute_average_precision(self, recall: torch.Tensor, precision: torch.Tensor) -> torch.Tensor: + """Measuring average precision. + + Args: + recall: n-dimensional tensor whose last dimension represents confidence thresholds as much as #samples. + Should be ordered in ascending order in its last dimension. + precision: like ``recall`` in the shape. + + Returns: + average_precision: (n-1)-dimensional tensor containing the average precisions. + """ + if self.rec_thresholds is not None: + rec_thresholds = self.rec_thresholds.repeat((*recall.shape[:-1], 1)) + rec_thresh_indices = torch.searchsorted(recall, rec_thresholds) + precision = precision.take_along_dim( + rec_thresh_indices.where(rec_thresh_indices != recall.size(-1), 0), dim=-1 + ).where(rec_thresh_indices != recall.size(-1), 0) + recall = rec_thresholds + recall_differential = recall.diff( + dim=-1, prepend=torch.zeros((*recall.shape[:-1], 1), device=recall.device, dtype=recall.dtype) + ) + return torch.sum(recall_differential * precision, dim=-1) + + +def _cat_and_agg_tensors( + tensors: List[torch.Tensor], + tensor_shape_except_last_dim: Tuple[int], + dtype: torch.dtype, + device: Union[str, torch.device], +) -> torch.Tensor: + """ + Concatenate tensors in ``tensors`` at their last dimension and gather all tensors from across all processes. 
+ All tensors should have the same shape (denoted by ``tensor_shape_except_last_dim``) except at their + last dimension. + """ + num_preds = torch.tensor( + [sum([tensor.shape[-1] for tensor in tensors]) if tensors else 0], + device=device, + ) + all_num_preds = cast(torch.Tensor, idist.all_gather(num_preds)).tolist() + tensor = ( + torch.cat(tensors, dim=-1) + if tensors + else torch.empty((*tensor_shape_except_last_dim, 0), dtype=dtype, device=device) + ) + shape_across_ranks = [(*tensor_shape_except_last_dim, num_pred_in_rank) for num_pred_in_rank in all_num_preds] + return torch.cat( + all_gather_tensors_with_shapes( + tensor, + shape_across_ranks, + ), + dim=-1, + ) + + +class MeanAveragePrecision(_BaseClassification, _BaseAveragePrecision): + _y_pred: List[torch.Tensor] + _y_true: List[torch.Tensor] + + def __init__( + self, + rec_thresholds: Optional[Union[Sequence[float], torch.Tensor]] = None, + class_mean: Optional["Literal['micro', 'macro', 'weighted']"] = "macro", + is_multilabel: bool = False, + output_transform: Callable = lambda x: x, + device: Union[str, torch.device] = torch.device("cpu"), + skip_unrolling: bool = False, + ) -> None: + r"""Calculate the mean average precision metric i.e. mean of the averaged-over-recall precision for + classification task: + + .. math:: + \text{Average Precision} = \sum_{k=1}^{\#rec\_thresholds} (r_k - r_{k-1}) P_k + + Mean average precision attempts to give a measure of detector or classifier precision at various + sensivity levels a.k.a recall thresholds. This is done by summing precisions at different recall + thresholds weighted by the change in recall, as if the area under precision-recall curve is being computed. + Mean average precision is then computed by taking the mean of this average precision over different classes. + + All the binary, multiclass and multilabel data are supported. In the latter case, + ``is_multilabel`` should be set to true. + + `mean` in the mean average precision accounts for mean of the average precision across classes. ``class_mean`` + determines how to take this mean. + + Args: + rec_thresholds: recall thresholds (sensivity levels) to be considered for computing Mean Average Precision. + It could be a 1-dim tensor or a sequence of floats. Its values should be between 0 and 1 and don't need + to be sorted. If missing, thresholds are considered automatically using the data. + class_mean: how to compute mean of the average precision across classes or incorporate class + dimension into computing precision. It's ignored in binary classification. Available options are + + None + A 1-dimensional tensor of mean (taken across additional mean dimensions) average precision per class + is returned. If there's no ground truth sample for a class, ``0`` is returned for that. + + 'micro' + Precision is computed counting stats of classes/labels altogether. This option + incorporates class in the very precision measurement. + + .. math:: + \text{Micro P} = \frac{\sum_{c=1}^C TP_c}{\sum_{c=1}^C TP_c+FP_c} + + where :math:`C` is the number of classes/labels. :math:`c` in :math:`TP_c` + and :math:`FP_c` means that the terms are computed for class/label :math:`c` (in a one-vs-rest + sense in multiclass case). + + For multiclass inputs, this is equivalent to mean average accuracy. + + 'weighted' + like macro but considers class/label imbalance. For multiclass input, + it computes AP for each class then returns mean of them weighted by + support of classes (number of actual samples in each class). 
For multilabel input, + it computes AP for each label then returns mean of them weighted by support + of labels (number of actual positive samples in each label). + + 'macro' + computes macro precision which is unweighted mean of AP computed across classes/labels. Default. + + is_multilabel: determines if the data is multilabel or not. Default False. + output_transform: a callable that is used to transform the + :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the + form expected by the metric. This can be useful if, for example, you have a multi-output model and + you want to compute the metric with respect to one of the outputs. This metric requires the output + as ``(y_pred, y)``. + device: specifies which device updates are accumulated on. Setting the + metric's device to be the same as your ``update`` arguments ensures the ``update`` method is + non-blocking. By default, CPU. + skip_unrolling: specifies whether output should be unrolled before being fed to update method. Should be + true for multi-output model, for example, if ``y_pred`` and ``y`` contain multi-ouput as + ``(y_pred_a, y_pred_b)`` and ``(y_a, y_b)``, in which case the update method is called for + ``(y_pred_a, y_a)`` and ``(y_pred_b, y_b)``.Alternatively, ``output_transform`` can be used to handle + this. + + .. versionadded:: 0.5.2 + """ + + super(MeanAveragePrecision, self).__init__( + output_transform=output_transform, + is_multilabel=is_multilabel, + device=device, + skip_unrolling=skip_unrolling, + ) + super(Metric, self).__init__(rec_thresholds=rec_thresholds, class_mean=class_mean) + + @reinit__is_reduced + def reset(self) -> None: + """ + Reset method of the metric + """ + super().reset() + self._y_pred = [] + self._y_true = [] + + def _check_binary_multilabel_cases(self, output: Sequence[torch.Tensor]) -> None: + # Ignore the check in `_BaseClassification` since `y_pred` consists of probabilities here. + _, y = output + if not torch.equal(y, y**2): + raise ValueError("For binary cases, y must be comprised of 0's and 1's.") + + def _check_type(self, output: Sequence[torch.Tensor]) -> None: + super()._check_type(output) + y_pred, y = output + if y_pred.dtype in (torch.int, torch.long): + raise TypeError(f"`y_pred` should be a float tensor, given {y_pred.dtype}") + if self._type == "multiclass" and y.dtype != torch.long: + warnings.warn("`y` should be of dtype long when entry type is multiclass", RuntimeWarning) + + def _prepare_output(self, output: Sequence[torch.Tensor]) -> Sequence[torch.Tensor]: + """Prepares and returns ``y_pred`` and ``y`` tensors. Input and output shapes of the method is as follows. + ``C`` and ``L`` denote the number of classes and labels in multiclass and multilabel inputs respectively. + + ========== =========== ============ + ``y_pred`` + ----------------------------------- + Data type Input shape Output shape + ========== =========== ============ + Binary (N, ...) (1, N * ...) + Multilabel (N, L, ...) (L, N * ...) + Multiclass (N, C, ...) (C, N * ...) + ========== =========== ============ + + ========== =========== ============ + ``y`` + ----------------------------------- + Data type Input shape Output shape + ========== =========== ============ + Binary (N, ...) (1, N * ...) + Multilabel (N, L, ...) (L, N * ...) + Multiclass (N, ...) (N * ...) 
+ ========== =========== ============ + """ + y_pred, y = output[0].detach(), output[1].detach() + + if self._type == "multilabel": + num_classes = y_pred.size(1) + yp = torch.transpose(y_pred, 1, 0).reshape(num_classes, -1) + yt = torch.transpose(y, 1, 0).reshape(num_classes, -1) + elif self._type == "binary": + yp = y_pred.view(1, -1) + yt = y.view(1, -1) + else: # Multiclass + num_classes = y_pred.size(1) + if y.max() + 1 > num_classes: + raise ValueError( + f"y_pred contains fewer classes than y. Number of classes in prediction is {num_classes}" + f" and an element in y has invalid class = {y.max().item() + 1}." + ) + yt = y.view(-1) + yp = torch.transpose(y_pred, 1, 0).reshape(num_classes, -1) + + return yp, yt + + @reinit__is_reduced + def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None: + """Metric update function using prediction and target. + + Args: + output: a binary tuple consisting of prediction and target tensors + + This metric follows the same rules on ``output`` members shape as the + :meth:`Precision.update <.metrics.precision.Precision.update>` except for ``y_pred`` of binary and + multilabel data which should be comprised of positive class probabilities here. + """ + self._check_shape(output) + self._check_type(output) + yp, yt = self._prepare_output(output) + self._y_pred.append(yp.to(self._device)) + self._y_true.append(yt.to(self._device, dtype=torch.uint8 if self._type != "multiclass" else torch.long)) + + def _compute_recall_and_precision( + self, y_true: torch.Tensor, y_pred: torch.Tensor, y_true_positive_count: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + r"""Measuring recall & precision. + + Shape of function inputs and return values follow the table below. + N is the number of samples. \#unique scores represents number of + unique scores in ``scores`` which is actually the number of thresholds. 
+ + ===================== ======================================= + **Object** **Shape** + ===================== ======================================= + y_true (N,) + y_pred (N,) + y_true_positive_count () (A single float) + recall (\#unique scores,) + precision (\#unique scores,) + ===================== ======================================= + + Returns: + `(recall, precision)` + """ + indices = torch.argsort(y_pred, stable=True, descending=True) + tp_summation = y_true[indices].cumsum(dim=0) + if tp_summation.device != torch.device("mps"): + tp_summation = tp_summation.double() + + # Adopted from Scikit-learn's implementation + unique_scores_indices = torch.nonzero( + y_pred[indices].diff(append=(y_pred.max() + 1).unsqueeze(dim=0)), as_tuple=True + )[0] + tp_summation = tp_summation[..., unique_scores_indices] + fp_summation = (unique_scores_indices + 1) - tp_summation + + if y_true_positive_count == 0: + # To be aligned with Scikit-Learn + recall = torch.ones_like(tp_summation, device=self._device, dtype=torch.float) + else: + recall = tp_summation / y_true_positive_count + + predicted_positive = tp_summation + fp_summation + precision = tp_summation / torch.where(predicted_positive == 0, 1, predicted_positive) + return recall, precision + + def compute(self) -> Union[torch.Tensor, float]: + """ + Compute method of the metric + """ + if self._num_classes is None: + raise RuntimeError("Metric could not be computed without any update method call") + num_classes = self._num_classes + + y_true = _cat_and_agg_tensors( + self._y_true, + cast(Tuple[int], ()) if self._type == "multiclass" else (num_classes,), + torch.long if self._type == "multiclass" else torch.uint8, + self._device, + ) + fp_precision = torch.double if self._device != torch.device("mps") else torch.float32 + y_pred = _cat_and_agg_tensors(self._y_pred, (num_classes,), fp_precision, self._device) + + if self._type == "multiclass": + y_true = to_onehot(y_true, num_classes=num_classes).T + if self.class_mean == "micro": + y_true = y_true.reshape(1, -1) + y_pred = y_pred.view(1, -1) + y_true_positive_count = y_true.sum(dim=-1) + average_precisions = torch.zeros_like(y_true_positive_count, device=self._device, dtype=fp_precision) + for cls in range(y_true_positive_count.size(0)): + recall, precision = self._compute_recall_and_precision(y_true[cls], y_pred[cls], y_true_positive_count[cls]) + average_precisions[cls] = self._compute_average_precision(recall, precision) + if self._type == "binary": + return average_precisions.item() + if self.class_mean is None: + return average_precisions + elif self.class_mean == "weighted": + return torch.sum(y_true_positive_count * average_precisions) / y_true_positive_count.sum() + else: + return average_precisions.mean() diff --git a/ignite/metrics/metric.py b/ignite/metrics/metric.py index 453fb1291e9..b8d35c3184e 100644 --- a/ignite/metrics/metric.py +++ b/ignite/metrics/metric.py @@ -873,7 +873,7 @@ def _is_list_of_tensors_or_numbers(x: Sequence[Union[torch.Tensor, float]]) -> b return isinstance(x, Sequence) and all([isinstance(t, (torch.Tensor, Number)) for t in x]) -def _to_batched_tensor(x: Union[torch.Tensor, float], device: Optional[torch.device] = None) -> torch.Tensor: +def _to_batched_tensor(x: Union[torch.Tensor, Number], device: Optional[torch.device] = None) -> torch.Tensor: if isinstance(x, torch.Tensor): return x.unsqueeze(dim=0) return torch.tensor([x], device=device) diff --git a/ignite/metrics/metric_group.py b/ignite/metrics/metric_group.py index 58a52f658ae..bd4ffbc77ef 
100644 --- a/ignite/metrics/metric_group.py +++ b/ignite/metrics/metric_group.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Sequence +from typing import Any, Callable, Dict, Sequence, Tuple import torch @@ -15,6 +15,11 @@ class MetricGroup(Metric): :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the form expected by the metric. `output_transform` of each metric in the group is also called upon its update. + skip_unrolling: specifies whether output should be unrolled before being fed to update method. Should be + true for multi-output model, for example, if ``y_pred`` and ``y`` contain multi-ouput as + ``(y_pred_a, y_pred_b)`` and ``(y_a, y_b)``, in which case the update method is called for + ``(y_pred_a, y_a)`` and ``(y_pred_b, y_b)``.Alternatively, ``output_transform`` can be used to handle + this. Examples: We construct a group of metrics, attach them to the engine at once and retrieve their result. @@ -34,13 +39,18 @@ class MetricGroup(Metric): # And also altogether state.metrics["eval_metrics"] + + .. versionchanged:: 0.5.2 + ``skip_unrolling`` argument is added. """ - _state_dict_all_req_keys = ("metrics",) + _state_dict_all_req_keys: Tuple[str, ...] = ("metrics",) - def __init__(self, metrics: Dict[str, Metric], output_transform: Callable = lambda x: x): + def __init__( + self, metrics: Dict[str, Metric], output_transform: Callable = lambda x: x, skip_unrolling: bool = False + ): self.metrics = metrics - super(MetricGroup, self).__init__(output_transform=output_transform) + super(MetricGroup, self).__init__(output_transform=output_transform, skip_unrolling=skip_unrolling) def reset(self) -> None: for m in self.metrics.values(): diff --git a/ignite/metrics/vision/__init__.py b/ignite/metrics/vision/__init__.py new file mode 100644 index 00000000000..b5b5d0236bb --- /dev/null +++ b/ignite/metrics/vision/__init__.py @@ -0,0 +1,3 @@ +from ignite.metrics.vision.object_detection_average_precision_recall import ObjectDetectionAvgPrecisionRecall + +__all__ = ["ObjectDetectionAvgPrecisionRecall"] diff --git a/ignite/metrics/vision/object_detection_average_precision_recall.py b/ignite/metrics/vision/object_detection_average_precision_recall.py new file mode 100644 index 00000000000..881c9ccc5ed --- /dev/null +++ b/ignite/metrics/vision/object_detection_average_precision_recall.py @@ -0,0 +1,499 @@ +from typing import Callable, cast, Dict, List, Optional, Sequence, Tuple, Union + +import torch +from typing_extensions import Literal + +from ignite.metrics import MetricGroup + +from ignite.metrics.mean_average_precision import _BaseAveragePrecision, _cat_and_agg_tensors +from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce + + +def coco_tensor_list_to_dict_list( + output: Tuple[ + Union[List[torch.Tensor], List[Dict[str, torch.Tensor]]], + Union[List[torch.Tensor], List[Dict[str, torch.Tensor]]], + ] +) -> Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, torch.Tensor]]]: + """Convert either of output's `y_pred` or `y` from list of `(N, 6)` tensors to list of str-to-tensor dictionaries, + or keep them unchanged if they're already in the deisred format. + + Input format is a `(N, 6)` or (`N, 5)` tensor which `N` is the number of predicted/target bounding boxes for the + image and the second dimension contains `(x1, y1, x2, y2, confidence, class)`/`(x1, y1, x2, y2, class[, iscrowd])`. 
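+
+    For instance, a hypothetical ``y_pred`` entry for an image with two detections could look like
+
+    .. code-block:: python
+
+        # columns: x1, y1, x2, y2, confidence, class (illustrative values only)
+        torch.tensor([
+            [10.0, 20.0, 50.0, 60.0, 0.93, 3.0],
+            [ 5.0,  5.0, 25.0, 40.0, 0.71, 1.0],
+        ])
+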
+ Output format is a str-to-tensor dictionary containing 'bbox' and `class` keys, plus `confidence` key for `y_pred` + and possibly `iscrowd` for `y`. + + Args: + output: `(y_pred,y)` tuple whose members are either list of tensors or list of dicts. + + Returns: + `(y_pred,y)` tuple whose members are list of str-to-tensor dictionaries. + """ + y_pred, y = output + if len(y_pred) > 0 and isinstance(y_pred[0], torch.Tensor): + y_pred = [{"bbox": t[:, :4], "confidence": t[:, 4], "class": t[:, 5]} for t in cast(List[torch.Tensor], y_pred)] + if len(y) > 0 and isinstance(y[0], torch.Tensor): + if y[0].size(1) == 5: + y = [{"bbox": t[:, :4], "class": t[:, 4]} for t in cast(List[torch.Tensor], y)] + else: + y = [{"bbox": t[:, :4], "class": t[:, 4], "iscrowd": t[:, 5]} for t in cast(List[torch.Tensor], y)] + return cast(Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, torch.Tensor]]], (y_pred, y)) + + +class ObjectDetectionAvgPrecisionRecall(Metric, _BaseAveragePrecision): + _tps: List[torch.Tensor] + _fps: List[torch.Tensor] + _scores: List[torch.Tensor] + _y_pred_labels: List[torch.Tensor] + _y_true_count: torch.Tensor + _num_classes: int + + def __init__( + self, + iou_thresholds: Optional[Union[Sequence[float], torch.Tensor]] = None, + rec_thresholds: Optional[Union[Sequence[float], torch.Tensor]] = None, + num_classes: int = 80, + max_detections_per_image_per_class: int = 100, + area_range: Literal["small", "medium", "large", "all"] = "all", + output_transform: Callable = lambda x: x, + device: Union[str, torch.device] = torch.device("cpu"), + skip_unrolling: bool = False, + ) -> None: + r"""Calculate mean average precision & recall for evaluating an object detector in the COCO way. + + + Average precision is computed by averaging precision over increasing levels of recall thresholds. + In COCO, the maximum precision across thresholds greater or equal a recall threshold is + considered as the average summation operand; In other words, the precision peek across lower or equal + sensivity levels is used for a recall threshold: + + .. math:: + \text{Average Precision} = \sum_{k=1}^{\#rec\_thresholds} (r_k - r_{k-1}) max(P_{k:}) + + Average recall is the detector's maximum recall, considering all matched detections as TP, + averaged over classes. + + Args: + iou_thresholds: sequence of IoU thresholds to be considered for computing mean average precision & recall. + Values should be between 0 and 1. If not given, COCO's default (.5, .55, ..., .95) would be used. + rec_thresholds: sequence of recall thresholds to be considered for computing mean average precision. + Values should be between 0 and 1. If not given, COCO's default (.0, .01, .02, ..., 1.) would be used. + num_classes: number of categories. Default is 80, that of the COCO dataset. + area_range: area range which only objects therein are considered in evaluation. By default, 'all'. + max_detections_per_image_per_class: maximum number of detections per class in each image to consider + for evaluation. The most confident ones are selected. + output_transform: a callable that is used to transform the :class:`~ignite.engine.engine.Engine`'s + ``process_function``'s output into the form expected by the metric. An already provided example is + :func:`~ignite.metrics.vision.object_detection_average_precision_recall.coco_tensor_list_to_dict_list` + which accepts `y_pred` and `y` as lists of tensors and transforms them to the expected format. + Default is the identity function. + device: specifies which device updates are accumulated on. 
Setting the + metric's device to be the same as your ``update`` arguments ensures the ``update`` method is + non-blocking. By default, CPU. + skip_unrolling: specifies whether output should be unrolled before being fed to update method. Should be + true for multi-output model, for example, if ``y_pred`` and ``y`` contain multi-ouput as + ``(y_pred_a, y_pred_b)`` and ``(y_a, y_b)``, in which case the update method is called for + ``(y_pred_a, y_a)`` and ``(y_pred_b, y_b)``.Alternatively, ``output_transform`` can be used to handle + this. + + .. versionadded:: 0.5.2 + """ + try: + from torchvision.ops.boxes import _box_inter_union, box_area + + def box_iou(pred_boxes: torch.Tensor, gt_boxes: torch.Tensor, iscrowd: torch.BoolTensor) -> torch.Tensor: + inter, union = _box_inter_union(pred_boxes, gt_boxes) + union[:, iscrowd] = box_area(pred_boxes).reshape(-1, 1) + iou = inter / union + iou[iou.isnan()] = 0 + return iou + + self.box_iou = box_iou + except ImportError: + raise ModuleNotFoundError("This metric requires torchvision to be installed.") + + if iou_thresholds is None: + iou_thresholds = torch.linspace(0.5, 0.95, 10, dtype=torch.double) + + self._iou_thresholds = self._setup_thresholds(iou_thresholds, "iou_thresholds") + + if rec_thresholds is None: + rec_thresholds = torch.linspace(0, 1, 101, dtype=torch.double) + + self._num_classes = num_classes + self._area_range = area_range + self._max_detections_per_image_per_class = max_detections_per_image_per_class + + super(ObjectDetectionAvgPrecisionRecall, self).__init__( + output_transform=output_transform, + device=device, + skip_unrolling=skip_unrolling, + ) + super(Metric, self).__init__( + rec_thresholds=rec_thresholds, + class_mean=None, + ) + precision = torch.double if torch.device(device).type != "mps" else torch.float32 + self.rec_thresholds = cast(torch.Tensor, self.rec_thresholds).to(device=device, dtype=precision) + + @reinit__is_reduced + def reset(self) -> None: + self._tps = [] + self._fps = [] + self._scores = [] + self._y_pred_labels = [] + self._y_true_count = torch.zeros((self._num_classes,), device=self._device) + + def _match_area_range(self, bboxes: torch.Tensor) -> torch.Tensor: + from torchvision.ops.boxes import box_area + + areas = box_area(bboxes) + if self._area_range == "all": + min_area = 0 + max_area = 1e10 + elif self._area_range == "small": + min_area = 0 + max_area = 1024 + elif self._area_range == "medium": + min_area = 1024 + max_area = 9216 + elif self._area_range == "large": + min_area = 9216 + max_area = 1e10 + return torch.logical_and(areas >= min_area, areas <= max_area) + + def _check_matching_input( + self, output: Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, torch.Tensor]]] + ) -> None: + y_pred, y = output + if len(y_pred) != len(y): + raise ValueError(f"y_pred and y should have the same number of samples, given {len(y_pred)} and {len(y)}.") + if len(y_pred) == 0: + raise ValueError("y_pred and y should contain at least one sample.") + + y_pred_keys = {"bbox", "scores", "labels"} + if (y_pred[0].keys() & y_pred_keys) != y_pred_keys: + raise ValueError( + "y_pred sample dictionaries should have 'bbox', 'scores'" + f" and 'labels' keys, given keys: {y_pred[0].keys()}" + ) + + y_keys = {"bbox", "labels"} + if (y[0].keys() & y_keys) != y_keys: + raise ValueError( + "y sample dictionaries should have 'bbox', 'labels'" + f" and optionally 'iscrowd' keys, given keys: {y[0].keys()}" + ) + + def _compute_recall_and_precision( + self, TP: torch.Tensor, FP: torch.Tensor, scores: torch.Tensor, 
y_true_count: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + r"""Measuring recall & precision + + This method is different from that of MeanAveragePrecision since in the pycocotools reference implementation, + when there are predictions with the same scores, they're considered associated with different thresholds in the + course of measuring recall values, although it's not logically correct as those predictions are really + associated with a single threshold, thus outputing a single recall value. + + Shape of function inputs and return values follow the table below. N\ :sub:`pred` is the number of detections + or predictions. ``...`` stands for the possible additional dimensions. Finally, \#unique scores represents + number of unique scores in ``scores`` which is actually the number of thresholds. + + ============== ====================== + **Object** **Shape** + ============== ====================== + TP and FP (..., N\ :sub:`pred`) + scores (N\ :sub:`pred`,) + y_true_count () (A single float, + greater than zero) + recall (..., \#unique scores) + precision (..., \#unique scores) + ============== ====================== + + Returns: + `(recall, precision)` + """ + indices = torch.argsort(scores, dim=-1, stable=True, descending=True) + tp = TP[..., indices] + tp_summation = tp.cumsum(dim=-1) + if tp_summation.device.type != "mps": + tp_summation = tp_summation.double() + + fp = FP[..., indices] + fp_summation = fp.cumsum(dim=-1) + if fp_summation.device.type != "mps": + fp_summation = fp_summation.double() + + recall = tp_summation / y_true_count + predicted_positive = tp_summation + fp_summation + precision = tp_summation / torch.where(predicted_positive == 0, 1, predicted_positive) + + return recall, precision + + def _compute_average_precision(self, recall: torch.Tensor, precision: torch.Tensor) -> torch.Tensor: + """Measuring average precision. + This method is overriden since :math:`1/#recall_thresholds` is used instead of :math:`r_k - r_{k-1}` + as the recall differential in COCO's reference implementation i.e., pycocotools. + + Args: + recall: n-dimensional tensor whose last dimension is the dimension of the samples. Should be ordered in + ascending order in its last dimension. + precision: like ``recall`` in the shape. + + Returns: + average_precision: (n-1)-dimensional tensor containing the average precision for mean dimensions. 
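+
+        As a hypothetical illustration (values chosen only for this docstring): with
+        ``rec_thresholds = (0.0, 0.5, 1.0)``, ``recall = [0.5, 1.0]`` and ``precision = [0.4, 0.8]``,
+        the right-to-left running maximum of precision is ``[0.8, 0.8]``, so every recall threshold
+        picks a precision of ``0.8`` and the result is ``(0.8 + 0.8 + 0.8) / 3 = 0.8``.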
+ """ + if precision.device.type == "mps": + # Manual fallback to CPU if precision is on MPS due to the error: + # NotImplementedError: The operator 'aten::_cummax_helper' is not currently implemented for the MPS device + device = precision.device + precision_integrand = precision.flip(-1).cpu() + precision_integrand = precision_integrand.cummax(dim=-1).values + precision_integrand = precision_integrand.to(device=device).flip(-1) + else: + precision_integrand = precision.flip(-1).cummax(dim=-1).values.flip(-1) + rec_thresholds = cast(torch.Tensor, self.rec_thresholds).repeat((*recall.shape[:-1], 1)) + rec_thresh_indices = ( + torch.searchsorted(recall, rec_thresholds) + if recall.size(-1) != 0 + else torch.LongTensor([], device=self._device) + ) + precision_integrand = precision_integrand.take_along_dim( + rec_thresh_indices.where(rec_thresh_indices != recall.size(-1), 0), dim=-1 + ).where(rec_thresh_indices != recall.size(-1), 0) + return torch.sum(precision_integrand, dim=-1) / len(cast(torch.Tensor, self.rec_thresholds)) + + @reinit__is_reduced + def update(self, output: Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, torch.Tensor]]]) -> None: + r"""Metric update method using prediction and target. + + Args: + output: a tuple, (y_pred, y), of two same-length lists, each one containing + str-to-tensor dictionaries whose items is as follows. N\ :sub:`det` and + N\ :sub:`gt` are number of detections and ground truths for a sample + respectively. + + ======== ================== ================================================= + **y_pred items** + ----------------------------------------------------------------------------- + Key Value shape Description + ======== ================== ================================================= + 'bbox' (N\ :sub:`det`, 4) Bounding boxes of form (x1, y1, x2, y2) + containing top left and bottom right coordinates. + 'scores' (N\ :sub:`det`,) Confidence score of detections. + 'labels' (N\ :sub:`det`,) Predicted category number of detections in + `torch.long` dtype. + ======== ================== ================================================= + + ========= ================= ================================================= + **y items** + ----------------------------------------------------------------------------- + Key Value shape Description + ========= ================= ================================================= + 'bbox' (N\ :sub:`gt`, 4) Bounding boxes of form (x1, y1, x2, y2) + containing top left and bottom right coordinates. + 'labels' (N\ :sub:`gt`,) Category number of ground truths in `torch.long` + dtype. + 'iscrowd' (N\ :sub:`gt`,) Whether ground truth boxes are crowd ones or not. + This key is optional. + ========= ================= ================================================= + """ + self._check_matching_input(output) + for pred, target in zip(*output): + labels = target["labels"] + gt_boxes = target["bbox"] + gt_is_crowd = ( + target["iscrowd"].bool() if "iscrowd" in target else torch.zeros_like(labels, dtype=torch.bool) + ) + gt_ignore = ~self._match_area_range(gt_boxes) | gt_is_crowd + self._y_true_count += torch.bincount(labels[~gt_ignore], minlength=self._num_classes).to( + device=self._device + ) + + # Matching logic of object detection mAP, according to COCO reference implementation. 
+ if len(pred["labels"]): + best_detections_index = torch.argsort(pred["scores"], stable=True, descending=True) + max_best_detections_index = torch.cat( + [ + best_detections_index[pred["labels"][best_detections_index] == c][ + : self._max_detections_per_image_per_class + ] + for c in range(self._num_classes) + ] + ) + pred_boxes = pred["bbox"][max_best_detections_index] + pred_labels = pred["labels"][max_best_detections_index] + if not len(labels): + tp = torch.zeros( + (len(self._iou_thresholds), len(max_best_detections_index)), + dtype=torch.uint8, + device=self._device, + ) + self._tps.append(tp) + self._fps.append(~tp & self._match_area_range(pred_boxes).to(self._device)) + else: + ious = self.box_iou(pred_boxes, gt_boxes, cast(torch.BoolTensor, gt_is_crowd)) + category_no_match = labels.expand(len(pred_labels), -1) != pred_labels.view(-1, 1) + NO_MATCH = -3 + ious[category_no_match] = NO_MATCH + ious = ious.unsqueeze(-1).repeat((1, 1, len(self._iou_thresholds))) + ious[ious < self._iou_thresholds] = NO_MATCH + IGNORANCE = -2 + ious[:, gt_ignore] += IGNORANCE + for i in range(len(pred_labels)): + # Flip is done to give priority to the last item with maximal value, + # as torch.max selects the first one. + match_gts = ious[i].flip(0).max(0) + match_gts_indices = ious.size(1) - 1 - match_gts.indices + for t in range(len(self._iou_thresholds)): + if match_gts.values[t] > NO_MATCH and not gt_is_crowd[match_gts_indices[t]]: + ious[:, match_gts_indices[t], t] = NO_MATCH + ious[i, match_gts_indices[t], t] = match_gts.values[t] + + max_ious = ious.max(1).values + self._tps.append((max_ious >= 0).T.to(dtype=torch.uint8, device=self._device)) + self._fps.append( + ((max_ious <= NO_MATCH).T & self._match_area_range(pred_boxes)).to( + dtype=torch.uint8, device=self._device + ) + ) + + scores = pred["scores"][max_best_detections_index] + if self._device.type == "mps" and scores.dtype == torch.double: + scores = scores.to(dtype=torch.float32) + self._scores.append(scores.to(self._device)) + self._y_pred_labels.append(pred_labels.to(dtype=torch.int, device=self._device)) + + @sync_all_reduce("_y_true_count") + def _compute(self) -> torch.Tensor: + pred_labels = _cat_and_agg_tensors(self._y_pred_labels, cast(Tuple[int], ()), torch.int, self._device) + TP = _cat_and_agg_tensors(self._tps, (len(self._iou_thresholds),), torch.uint8, self._device) + FP = _cat_and_agg_tensors(self._fps, (len(self._iou_thresholds),), torch.uint8, self._device) + fp_precision = torch.double if self._device.type != "mps" else torch.float32 + scores = _cat_and_agg_tensors(self._scores, cast(Tuple[int], ()), fp_precision, self._device) + + average_precisions_recalls = -torch.ones( + (2, self._num_classes, len(self._iou_thresholds)), + device=self._device, + dtype=fp_precision, + ) + for cls in range(self._num_classes): + if self._y_true_count[cls] == 0: + continue + + cls_labels = pred_labels == cls + if sum(cls_labels) == 0: + average_precisions_recalls[:, cls] = 0.0 + continue + + recall, precision = self._compute_recall_and_precision( + TP[..., cls_labels], FP[..., cls_labels], scores[cls_labels], self._y_true_count[cls] + ) + average_precision_for_cls_per_iou_threshold = self._compute_average_precision(recall, precision) + average_precisions_recalls[0, cls] = average_precision_for_cls_per_iou_threshold + average_precisions_recalls[1, cls] = recall[..., -1] + return average_precisions_recalls + + def compute(self) -> Tuple[float, float]: + average_precisions_recalls = self._compute() + if (average_precisions_recalls == 
-1).all(): + return -1.0, -1.0 + ap = average_precisions_recalls[0][average_precisions_recalls[0] > -1].mean().item() + ar = average_precisions_recalls[1][average_precisions_recalls[1] > -1].mean().item() + return ap, ar + + +class CommonObjectDetectionMetrics(MetricGroup): + """ + Common Object Detection metrics. Included metrics are as follows: + + =============== ========================================== + **Metric name** **Description** + =============== ========================================== + AP@50..95 Average precision averaged over + .50 to.95 IOU thresholds + AR-100 Average recall with maximum 100 detections + AP@50 Average precision with IOU threshold=.50 + AP@75 Average precision with IOU threshold=.75 + AP-S Average precision over small objects + (< 32px * 32px) + AR-S Average recall over small objects + AP-M Average precision over medium objects + (S < . < 96px * 96px) + AR-M Average recall over medium objects + AP-L Average precision over large objects + (M < . < 1e5px * 1e5px) + AR-L Average recall over large objects + greater than zero) + AR-1 Average recall with maximum 1 detection + AR-10 Average recall with maximum 10 detections + =============== ========================================== + + .. versionadded:: 0.5.2 + """ + + _state_dict_all_req_keys = ("metrics", "ap_50_95") + + ap_50_95: ObjectDetectionAvgPrecisionRecall + + def __init__( + self, + num_classes: int = 80, + output_transform: Callable = lambda x: x, + device: Union[str, torch.device] = torch.device("cpu"), + skip_unrolling: bool = True, + ): + self.ap_50_95 = ObjectDetectionAvgPrecisionRecall(num_classes=num_classes, device=device) + + super().__init__( + { + "S": ObjectDetectionAvgPrecisionRecall(num_classes=num_classes, device=device, area_range="small"), + "M": ObjectDetectionAvgPrecisionRecall(num_classes=num_classes, device=device, area_range="medium"), + "L": ObjectDetectionAvgPrecisionRecall(num_classes=num_classes, device=device, area_range="large"), + "1": ObjectDetectionAvgPrecisionRecall( + num_classes=num_classes, device=device, max_detections_per_image_per_class=1 + ), + "10": ObjectDetectionAvgPrecisionRecall( + num_classes=num_classes, device=device, max_detections_per_image_per_class=10 + ), + }, + output_transform, + skip_unrolling=skip_unrolling, + ) + + def reset(self) -> None: + super().reset() + self.ap_50_95.reset() + + def update(self, output: Sequence[torch.Tensor]) -> None: + super().update(output) + self.ap_50_95.update(output) + + def compute(self) -> Dict[str, float]: + average_precisions_recalls = self.ap_50_95._compute() + + average_precisions_50 = average_precisions_recalls[0, :, 0] + average_precisions_75 = average_precisions_recalls[0, :, 5] + if (average_precisions_50 == -1).all(): + AP_50 = AP_75 = AP_50_95 = AR_100 = -1.0 + else: + AP_50 = average_precisions_50[average_precisions_50 > -1].mean().item() + AP_75 = average_precisions_75[average_precisions_75 > -1].mean().item() + AP_50_95 = average_precisions_recalls[0][average_precisions_recalls[0] > -1].mean().item() + AR_100 = average_precisions_recalls[1][average_precisions_recalls[1] > -1].mean().item() + + result = super().compute() + return { + "AP@50..95": AP_50_95, + "AR-100": AR_100, + "AP@50": AP_50, + "AP@75": AP_75, + "AP-S": result["S"][0], + "AR-S": result["S"][1], + "AP-M": result["M"][0], + "AR-M": result["M"][1], + "AP-L": result["L"][0], + "AR-L": result["L"][1], + "AR-1": result["1"][1], + "AR-10": result["10"][1], + } diff --git a/requirements-dev.txt b/requirements-dev.txt index 
d475e556cdf..a74e6e55980 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -25,6 +25,7 @@ pynvml clearml scikit-image py-rouge +pycocotools # temporary fix for python=3.12 and v3.8.1 # nltk git+https://github.com/nltk/nltk@aba99c8 diff --git a/tests/ignite/metrics/test_mean_average_precision.py b/tests/ignite/metrics/test_mean_average_precision.py new file mode 100644 index 00000000000..f24f33abb9d --- /dev/null +++ b/tests/ignite/metrics/test_mean_average_precision.py @@ -0,0 +1,205 @@ +import numpy as np +import pytest +import torch +from sklearn.metrics import average_precision_score, precision_recall_curve + +from ignite import distributed as idist +from ignite.engine import Engine +from ignite.metrics import MeanAveragePrecision +from ignite.utils import manual_seed, to_onehot + +manual_seed(41) + + +def test_wrong_input(): + with pytest.raises(ValueError, match="rec_thresholds should be a one-dimensional tensor or a sequence of floats"): + MeanAveragePrecision(rec_thresholds=torch.zeros((2, 2))) + + with pytest.raises(TypeError, match="rec_thresholds should be a sequence of floats or a tensor"): + MeanAveragePrecision(rec_thresholds={0, 0.2, 0.4, 0.6, 0.8}) + + with pytest.raises(ValueError, match="Wrong `class_mean` parameter"): + MeanAveragePrecision(class_mean="samples") + + with pytest.raises(ValueError, match="rec_thresholds values should be between 0 and 1"): + MeanAveragePrecision(rec_thresholds=(0.0, 0.5, 1.0, 1.5)) + + metric = MeanAveragePrecision() + with pytest.raises(RuntimeError, match="Metric could not be computed without any update method call"): + metric.compute() + + +def test_wrong_classification_input(): + metric = MeanAveragePrecision() + + with pytest.raises(TypeError, match="`y_pred` should be a float tensor"): + metric.update((torch.tensor([0, 1, 0]), torch.tensor([1, 0, 1]))) + + metric = MeanAveragePrecision() + with pytest.warns(RuntimeWarning, match="`y` should be of dtype long when entry type is multiclass"): + metric.update((torch.tensor([[0.5, 0.4, 0.1]]), torch.tensor([2.0]))) + + with pytest.raises(ValueError, match="y_pred contains fewer classes than y"): + metric.update((torch.tensor([[0.5, 0.4, 0.1]]), torch.tensor([3]))) + + +def test__prepare_output(): + metric = MeanAveragePrecision() + + metric._type = "binary" + scores, y = metric._prepare_output((torch.rand((5, 4, 3, 2)), torch.randint(0, 2, (5, 4, 3, 2)).bool())) + assert scores.shape == y.shape == (1, 120) + + metric._type = "multiclass" + scores, y = metric._prepare_output((torch.rand((5, 4, 3, 2)), torch.randint(0, 4, (5, 3, 2)))) + assert scores.shape == (4, 30) and y.shape == (30,) + + metric._type = "multilabel" + scores, y = metric._prepare_output((torch.rand((5, 4, 3, 2)), torch.randint(0, 2, (5, 4, 3, 2)).bool())) + assert scores.shape == y.shape == (4, 30) + + +def test_update(): + metric = MeanAveragePrecision() + assert len(metric._y_pred) == len(metric._y_true) == 0 + metric.update((torch.rand((5, 4)), torch.randint(0, 2, (5, 4)).bool())) + assert len(metric._y_pred) == len(metric._y_true) == 1 + + +def test__compute_recall_and_precision(): + m = MeanAveragePrecision() + + scores = torch.rand((50,)) + y_true = torch.randint(0, 2, (50,)).bool() + precision, recall, _ = precision_recall_curve(y_true.numpy(), scores.numpy()) + P = y_true.sum(dim=-1) + ignite_recall, ignite_precision = m._compute_recall_and_precision(y_true, scores, P) + assert (ignite_recall.squeeze().flip(0).numpy() == recall[:-1]).all() + assert (ignite_precision.squeeze().flip(0).numpy() == 
precision[:-1]).all() + + # When there's no actual positive. Numpy expectedly raises warning. + scores = torch.rand((50,)) + y_true = torch.zeros((50,)).bool() + precision, recall, _ = precision_recall_curve(y_true.numpy(), scores.numpy()) + P = torch.tensor(0) + ignite_recall, ignite_precision = m._compute_recall_and_precision(y_true, scores, P) + assert (ignite_recall.flip(0).numpy() == recall[:-1]).all() + assert (ignite_precision.flip(0).numpy() == precision[:-1]).all() + + +def test__compute_average_precision(): + m = MeanAveragePrecision() + + # Binary data + scores = np.random.rand(50) + y_true = np.random.randint(0, 2, 50) + ap = average_precision_score(y_true, scores) + precision, recall, _ = precision_recall_curve(y_true, scores) + ignite_ap = m._compute_average_precision( + torch.from_numpy(recall[:-1]).flip(-1), torch.from_numpy(precision[:-1]).flip(-1) + ) + assert np.allclose(ignite_ap.item(), ap) + + # Multilabel data + scores = np.random.rand(50, 5) + y_true = np.random.randint(0, 2, (50, 5)) + ap = average_precision_score(y_true, scores, average=None) + ignite_ap = [] + for cls in range(scores.shape[1]): + precision, recall, _ = precision_recall_curve(y_true[:, cls], scores[:, cls]) + ignite_ap.append( + m._compute_average_precision( + torch.from_numpy(recall[:-1]).flip(-1), torch.from_numpy(precision[:-1]).flip(-1) + ).item() + ) + ignite_ap = np.array(ignite_ap) + assert np.allclose(ignite_ap, ap) + + +def test_compute_binary_data(): + m = MeanAveragePrecision() + scores = torch.rand((130,)) + y_true = torch.randint(0, 2, (130,)) + + m.update((scores[:50], y_true[:50])) + m.update((scores[50:], y_true[50:])) + ignite_map = m.compute() + + map = average_precision_score(y_true.numpy(), scores.numpy()) + + assert np.allclose(ignite_map, map) + + +@pytest.mark.parametrize("class_mean", [None, "macro", "micro", "weighted"]) +def test_compute_nonbinary_data(class_mean): + scores = torch.rand((130, 5, 2, 2)) + sklearn_scores = scores.transpose(1, -1).reshape(-1, 5).numpy() + + # Multiclass + m = MeanAveragePrecision(class_mean=class_mean) + y_true = torch.randint(0, 5, (130, 2, 2)) + m.update((scores[:50], y_true[:50])) + m.update((scores[50:], y_true[50:])) + ignite_map = m.compute().numpy() + + y_true = to_onehot(y_true, 5).transpose(1, -1).reshape(-1, 5).numpy() + sklearn_map = average_precision_score(y_true, sklearn_scores, average=class_mean) + + assert np.allclose(sklearn_map, ignite_map) + + # Multilabel + m = MeanAveragePrecision(is_multilabel=True, class_mean=class_mean) + y_true = torch.randint(0, 2, (130, 5, 2, 2)).bool() + m.update((scores[:50], y_true[:50])) + m.update((scores[50:], y_true[50:])) + ignite_map = m.compute().numpy() + + y_true = y_true.transpose(1, -1).reshape(-1, 5).numpy() + sklearn_map = average_precision_score(y_true, sklearn_scores, average=class_mean) + + assert np.allclose(sklearn_map, ignite_map) + + +@pytest.mark.parametrize("data_type", ["binary", "multiclass", "multilabel"]) +def test_distrib_integration(distributed, data_type): + rank = idist.get_rank() + world_size = idist.get_world_size() + device = idist.device() + + def _test(metric_device): + def update(_, i): + return ( + y_preds[(2 * rank + i) * 10 : (2 * rank + i + 1) * 10], + y_true[(2 * rank + i) * 10 : (2 * rank + i + 1) * 10], + ) + + engine = Engine(update) + mAP = MeanAveragePrecision(is_multilabel=data_type == "multilabel", device=metric_device) + mAP.attach(engine, "mAP") + + y_true_size = (10 * 2 * world_size, 3, 2) if data_type != "multilabel" else (10 * 2 * world_size, 
4, 3, 2) + y_true = torch.randint(0, 4 if data_type == "multiclass" else 2, size=y_true_size).to(device) + y_preds_size = (10 * 2 * world_size, 4, 3, 2) if data_type != "binary" else (10 * 2 * world_size, 3, 2) + y_preds = torch.rand(y_preds_size).to(device) + + engine.run(range(2), max_epochs=1) + assert "mAP" in engine.state.metrics + + if data_type == "multiclass": + y_true = to_onehot(y_true, 4) + + if data_type == "binary": + y_true = y_true.view(-1) + y_preds = y_preds.view(-1) + else: + y_true = y_true.transpose(1, -1).reshape(-1, 4) + y_preds = y_preds.transpose(1, -1).reshape(-1, 4) + + sklearn_mAP = average_precision_score(y_true.numpy(), y_preds.numpy()) + assert np.allclose(sklearn_mAP, engine.state.metrics["mAP"]) + + metric_devices = [torch.device("cpu")] + if device.type != "xla": + metric_devices.append(idist.device()) + for metric_device in metric_devices: + _test(metric_device) diff --git a/tests/ignite/metrics/vision/__init__.py b/tests/ignite/metrics/vision/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/ignite/metrics/vision/test_object_detection_map.py b/tests/ignite/metrics/vision/test_object_detection_map.py new file mode 100644 index 00000000000..712b2fdebdf --- /dev/null +++ b/tests/ignite/metrics/vision/test_object_detection_map.py @@ -0,0 +1,1054 @@ +import sys +from collections import namedtuple +from math import ceil +from typing import Dict, List, Tuple +from unittest.mock import patch + +import numpy as np +from sklearn.utils.extmath import stable_cumsum + +np.float = float + +import pytest +import torch +from pycocotools.coco import COCO +from pycocotools.cocoeval import COCOeval +from torch.distributions.geometric import Geometric + +import ignite.distributed as idist +from ignite.engine import Engine +from ignite.metrics import CommonObjectDetectionMetrics, ObjectDetectionAvgPrecisionRecall +from ignite.metrics.vision.object_detection_average_precision_recall import coco_tensor_list_to_dict_list +from ignite.utils import manual_seed + +torch.set_printoptions(linewidth=200) +manual_seed(12) +np.set_printoptions(linewidth=200) + + +def coco_val2017_sample() -> Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, torch.Tensor]]]: + """ + Predictions are done using torchvision's `fasterrcnn_resnet50_fpn_v2` + with the following snippet. Note that beforehand, COCO images and annotations + were downloaded and unzipped into "val2017" and "annotations" folders respectively. + + .. 
code-block:: python + + import torch + from torchvision.models import detection as dtv + from torchvision.datasets import CocoDetection + + coco = CocoDetection( + "val2017", + "annotations/instances_val2017.json", + transform=dtv.FasterRCNN_ResNet50_FPN_V2_Weights.COCO_V1.transforms() + ) + model = dtv.fasterrcnn_resnet50_fpn_v2( + weights=dtv.FasterRCNN_ResNet50_FPN_V2_Weights.COCO_V1 + ) + model.eval() + + sample = torch.randint(len(coco), (10,)).tolist() + with torch.no_grad(): + pred = model([coco[s][0] for s in sample]) + + pred = [torch.cat(( + p['boxes'].int(), + p['scores'].reshape(-1, 1), + p['labels'].reshape(-1,1) + ), dim=1) for p in pred] + """ + gt = [ + torch.tensor( + [ + [418.1300, 261.8500, 511.4100, 311.0500, 23.0000, 0.0000], + [269.9200, 256.9700, 311.1200, 283.3000, 23.0000, 0.0000], + [175.1900, 252.0200, 209.2000, 278.7200, 23.0000, 0.0000], + ] + ), + torch.tensor([[196.7700, 199.6900, 301.8300, 373.2700, 2.0000, 0.0000]]), + torch.tensor( + [ + [88.4500, 168.2700, 465.9800, 318.2000, 6.0000, 0.0000], + [2.2900, 224.2100, 31.3000, 282.2200, 1.0000, 0.0000], + [143.9700, 240.8300, 176.8100, 268.6500, 84.0000, 0.0000], + [234.6200, 202.0900, 264.6500, 244.5800, 84.0000, 0.0000], + [303.0400, 193.1300, 339.9900, 242.5600, 84.0000, 0.0000], + [358.0100, 195.3100, 408.9800, 261.2500, 84.0000, 0.0000], + [197.2100, 236.3500, 233.7700, 275.1300, 84.0000, 0.0000], + ] + ), + torch.tensor( + [ + [26.1900, 44.8000, 435.5700, 268.0900, 50.0000, 0.0000], + [190.5400, 244.9400, 322.0100, 309.6500, 53.0000, 0.0000], + [72.8900, 27.2800, 591.3700, 417.8600, 51.0000, 0.0000], + [22.5400, 402.4500, 253.7400, 596.7200, 51.0000, 0.0000], + [441.9100, 168.5000, 494.3700, 213.0100, 53.0000, 0.0000], + [340.0700, 330.6800, 464.6200, 373.5800, 53.0000, 0.0000], + [28.6100, 247.9800, 591.3400, 597.6900, 67.0000, 0.0000], + [313.9300, 108.4000, 441.3300, 234.2600, 57.0000, 0.0000], + ] + ), + torch.tensor([[0.9600, 74.6700, 498.7600, 261.3400, 5.0000, 0.0000]]), + torch.tensor( + [ + [1.3800, 2.7500, 361.9800, 499.6100, 17.0000, 0.0000], + [325.4600, 239.8200, 454.2700, 416.8700, 47.0000, 0.0000], + [165.3900, 307.3400, 332.2200, 548.9500, 47.0000, 0.0000], + [424.0600, 179.6600, 480.0000, 348.0600, 47.0000, 0.0000], + ] + ), + torch.tensor( + [ + [218.2600, 84.2700, 473.6000, 230.0600, 28.0000, 0.0000], + [212.6300, 220.0000, 417.2200, 312.8100, 62.0000, 0.0000], + [104.4600, 311.2200, 222.4600, 375.0000, 62.0000, 0.0000], + ] + ), + torch.tensor( + [ + [144.4400, 132.2200, 639.7700, 253.8800, 5.0000, 0.0000], + [0.0000, 48.4300, 492.8900, 188.0200, 5.0000, 0.0000], + ] + ), + torch.tensor( + [ + [205.9900, 276.6200, 216.8500, 311.2300, 1.0000, 0.0000], + [378.9100, 65.4500, 476.1900, 185.3100, 38.0000, 0.0000], + ] + ), + torch.tensor( + [ + [392.0900, 110.3800, 417.5200, 136.7000, 85.0000, 0.0000], + [594.9000, 223.4600, 640.0000, 473.6600, 1.0000, 0.0000], + [381.1400, 225.8200, 443.3100, 253.1200, 58.0000, 0.0000], + [175.2800, 253.8000, 221.0600, 281.7400, 60.0000, 0.0000], + [96.4100, 357.6800, 168.8800, 392.0300, 60.0000, 0.0000], + [126.6500, 258.3200, 178.3800, 285.8900, 60.0000, 0.0000], + [71.0200, 247.1500, 117.2300, 275.0500, 60.0000, 0.0000], + [34.2300, 265.8500, 82.4000, 296.0700, 60.0000, 0.0000], + [32.9700, 252.5500, 69.1700, 271.6400, 60.0000, 0.0000], + [50.7800, 229.3100, 90.7400, 248.5200, 60.0000, 0.0000], + [82.8900, 218.2800, 126.5900, 237.6700, 60.0000, 0.0000], + [80.6300, 263.9200, 128.1000, 290.4100, 60.0000, 0.0000], + [277.2200, 222.9500, 323.7000, 
243.6300, 60.0000, 0.0000], + [225.5200, 155.3300, 272.0700, 167.2700, 60.0000, 0.0000], + [360.6200, 178.1900, 444.4800, 207.2900, 1.0000, 0.0000], + [171.6000, 149.8100, 219.9900, 169.7400, 60.0000, 0.0000], + [398.0100, 223.2000, 461.1000, 252.6500, 58.0000, 0.0000], + [1.0000, 123.0000, 603.0000, 402.0000, 60.0000, 1.0000], + ] + ), + ] + + pred = [ + torch.tensor( + [ + [418.0000, 260.0000, 513.0000, 310.0000, 0.9975, 23.0000], + [175.0000, 252.0000, 208.0000, 280.0000, 0.9631, 20.0000], + [269.0000, 256.0000, 310.0000, 292.0000, 0.6253, 20.0000], + [269.0000, 256.0000, 311.0000, 286.0000, 0.3523, 21.0000], + [269.0000, 256.0000, 311.0000, 288.0000, 0.2267, 16.0000], + [175.0000, 252.0000, 209.0000, 280.0000, 0.1528, 23.0000], + ] + ), + torch.tensor( + [ + [197.0000, 207.0000, 299.0000, 344.0000, 0.9573, 2.0000], + [0.0000, 0.0000, 371.0000, 477.0000, 0.8672, 7.0000], + [218.0000, 250.0000, 298.0000, 369.0000, 0.6232, 2.0000], + [303.0000, 70.0000, 374.0000, 471.0000, 0.2851, 82.0000], + [200.0000, 203.0000, 282.0000, 290.0000, 0.1519, 2.0000], + [206.0000, 169.0000, 275.0000, 236.0000, 0.0581, 72.0000], + ] + ), + torch.tensor( + [ + [3.0000, 224.0000, 31.0000, 280.0000, 0.9978, 1.0000], + [86.0000, 174.0000, 457.0000, 312.0000, 0.9917, 6.0000], + [90.0000, 176.0000, 460.0000, 313.0000, 0.2907, 8.0000], + [443.0000, 198.0000, 564.0000, 255.0000, 0.2400, 3.0000], + [452.0000, 197.0000, 510.0000, 220.0000, 0.1778, 8.0000], + [462.0000, 197.0000, 513.0000, 220.0000, 0.1750, 3.0000], + [489.0000, 196.0000, 522.0000, 212.0000, 0.1480, 3.0000], + [257.0000, 165.0000, 313.0000, 182.0000, 0.1294, 8.0000], + [438.0000, 195.0000, 556.0000, 256.0000, 0.1198, 8.0000], + [555.0000, 235.0000, 575.0000, 251.0000, 0.0831, 3.0000], + [486.0000, 54.0000, 504.0000, 63.0000, 0.0634, 34.0000], + [565.0000, 245.0000, 573.0000, 251.0000, 0.0605, 3.0000], + [257.0000, 165.0000, 313.0000, 183.0000, 0.0569, 6.0000], + [0.0000, 233.0000, 42.0000, 256.0000, 0.0526, 3.0000], + [26.0000, 237.0000, 42.0000, 250.0000, 0.0515, 28.0000], + ] + ), + torch.tensor( + [ + [62.0000, 25.0000, 596.0000, 406.0000, 0.9789, 51.0000], + [24.0000, 393.0000, 250.0000, 596.0000, 0.9496, 51.0000], + [25.0000, 154.0000, 593.0000, 595.0000, 0.8013, 67.0000], + [27.0000, 112.0000, 246.0000, 263.0000, 0.7848, 50.0000], + [186.0000, 242.0000, 362.0000, 359.0000, 0.7522, 54.0000], + [448.0000, 209.0000, 512.0000, 241.0000, 0.7446, 57.0000], + [197.0000, 188.0000, 297.0000, 235.0000, 0.6616, 57.0000], + [233.0000, 210.0000, 297.0000, 229.0000, 0.5214, 57.0000], + [519.0000, 255.0000, 589.0000, 284.0000, 0.4844, 48.0000], + [316.0000, 59.0000, 360.0000, 132.0000, 0.4211, 52.0000], + [27.0000, 104.0000, 286.0000, 267.0000, 0.3727, 48.0000], + [191.0000, 89.0000, 344.0000, 256.0000, 0.3143, 57.0000], + [182.0000, 80.0000, 529.0000, 353.0000, 0.3046, 57.0000], + [417.0000, 266.0000, 488.0000, 333.0000, 0.2602, 52.0000], + [324.0000, 110.0000, 369.0000, 157.0000, 0.2550, 57.0000], + [314.0000, 61.0000, 361.0000, 135.0000, 0.2517, 53.0000], + [252.0000, 218.0000, 336.0000, 249.0000, 0.2486, 57.0000], + [191.0000, 241.0000, 342.0000, 354.0000, 0.2384, 53.0000], + [194.0000, 174.0000, 327.0000, 247.0000, 0.2121, 57.0000], + [229.0000, 200.0000, 302.0000, 233.0000, 0.2030, 57.0000], + [439.0000, 192.0000, 526.0000, 252.0000, 0.2004, 57.0000], + [203.0000, 144.0000, 523.0000, 357.0000, 0.1937, 52.0000], + [17.0000, 90.0000, 361.0000, 283.0000, 0.1875, 50.0000], + [15.0000, 14.0000, 598.0000, 272.0000, 0.1747, 67.0000], + [319.0000, 63.0000, 
433.0000, 158.0000, 0.1621, 53.0000], + [319.0000, 62.0000, 434.0000, 157.0000, 0.1602, 52.0000], + [253.0000, 85.0000, 311.0000, 147.0000, 0.1562, 57.0000], + [14.0000, 25.0000, 214.0000, 211.0000, 0.1330, 67.0000], + [147.0000, 146.0000, 545.0000, 386.0000, 0.0867, 51.0000], + [324.0000, 174.0000, 455.0000, 292.0000, 0.0761, 52.0000], + [25.0000, 480.0000, 205.0000, 594.0000, 0.0727, 59.0000], + [166.0000, 0.0000, 603.0000, 32.0000, 0.0583, 84.0000], + [519.0000, 255.0000, 589.0000, 285.0000, 0.0578, 50.0000], + ] + ), + torch.tensor( + [ + [0.0000, 58.0000, 495.0000, 258.0000, 0.9917, 5.0000], + [199.0000, 291.0000, 212.0000, 299.0000, 0.5247, 37.0000], + [0.0000, 277.0000, 307.0000, 331.0000, 0.1169, 5.0000], + [0.0000, 284.0000, 302.0000, 308.0000, 0.0984, 5.0000], + [348.0000, 231.0000, 367.0000, 244.0000, 0.0621, 15.0000], + [349.0000, 229.0000, 367.0000, 244.0000, 0.0547, 8.0000], + ] + ), + torch.tensor( + [ + [1.0000, 9.0000, 365.0000, 506.0000, 0.9980, 17.0000], + [170.0000, 304.0000, 335.0000, 542.0000, 0.9867, 47.0000], + [422.0000, 179.0000, 480.0000, 351.0000, 0.9476, 47.0000], + [329.0000, 241.0000, 449.0000, 420.0000, 0.8503, 47.0000], + [0.0000, 352.0000, 141.0000, 635.0000, 0.4145, 74.0000], + [73.0000, 277.0000, 478.0000, 628.0000, 0.3859, 67.0000], + [329.0000, 183.0000, 373.0000, 286.0000, 0.3097, 47.0000], + [0.0000, 345.0000, 145.0000, 631.0000, 0.2359, 9.0000], + [1.0000, 341.0000, 147.0000, 632.0000, 0.2259, 70.0000], + [0.0000, 338.0000, 148.0000, 632.0000, 0.1669, 62.0000], + [339.0000, 154.0000, 410.0000, 248.0000, 0.1474, 47.0000], + [422.0000, 176.0000, 479.0000, 359.0000, 0.1422, 44.0000], + [0.0000, 349.0000, 148.0000, 636.0000, 0.1369, 42.0000], + [1.0000, 347.0000, 149.0000, 633.0000, 0.1118, 1.0000], + [324.0000, 238.0000, 455.0000, 423.0000, 0.0948, 86.0000], + [0.0000, 348.0000, 146.0000, 640.0000, 0.0885, 37.0000], + [0.0000, 342.0000, 140.0000, 626.0000, 0.0812, 81.0000], + [146.0000, 0.0000, 478.0000, 217.0000, 0.0812, 62.0000], + [75.0000, 102.0000, 357.0000, 553.0000, 0.0618, 64.0000], + [2.0000, 356.0000, 145.0000, 635.0000, 0.0608, 51.0000], + [0.0000, 337.0000, 149.0000, 637.0000, 0.0544, 3.0000], + ] + ), + torch.tensor( + [ + [212.0000, 219.0000, 418.0000, 312.0000, 0.9968, 62.0000], + [218.0000, 83.0000, 477.0000, 228.0000, 0.9902, 28.0000], + [113.0000, 221.0000, 476.0000, 368.0000, 0.3940, 62.0000], + [108.0000, 309.0000, 222.0000, 371.0000, 0.2972, 62.0000], + [199.0000, 124.0000, 206.0000, 130.0000, 0.2770, 16.0000], + [213.0000, 154.0000, 447.0000, 301.0000, 0.2698, 28.0000], + [122.0000, 297.0000, 492.0000, 371.0000, 0.2263, 62.0000], + [111.0000, 302.0000, 500.0000, 368.0000, 0.2115, 67.0000], + [319.0000, 220.0000, 424.0000, 307.0000, 0.1761, 62.0000], + [453.0000, 0.0000, 462.0000, 8.0000, 0.1390, 38.0000], + [107.0000, 309.0000, 222.0000, 371.0000, 0.1075, 15.0000], + [109.0000, 309.0000, 225.0000, 372.0000, 0.1028, 67.0000], + [137.0000, 301.0000, 499.0000, 371.0000, 0.0945, 61.0000], + [454.0000, 0.0000, 460.0000, 6.0000, 0.0891, 16.0000], + [162.0000, 102.0000, 167.0000, 105.0000, 0.0851, 16.0000], + [395.0000, 263.0000, 500.0000, 304.0000, 0.0813, 15.0000], + [107.0000, 298.0000, 491.0000, 373.0000, 0.0727, 9.0000], + [157.0000, 78.0000, 488.0000, 332.0000, 0.0573, 28.0000], + [110.0000, 282.0000, 500.0000, 369.0000, 0.0554, 15.0000], + [377.0000, 263.0000, 500.0000, 315.0000, 0.0527, 62.0000], + ] + ), + torch.tensor( + [ + [1.0000, 48.0000, 505.0000, 184.0000, 0.9939, 5.0000], + [152.0000, 60.0000, 633.0000, 255.0000, 0.9552, 
5.0000], + [0.0000, 183.0000, 20.0000, 200.0000, 0.2347, 8.0000], + [0.0000, 185.0000, 7.0000, 202.0000, 0.1005, 8.0000], + [397.0000, 255.0000, 491.0000, 276.0000, 0.0781, 42.0000], + [0.0000, 186.0000, 7.0000, 202.0000, 0.0748, 3.0000], + [259.0000, 154.0000, 640.0000, 254.0000, 0.0630, 5.0000], + ] + ), + torch.tensor( + [ + [203.0000, 277.0000, 215.0000, 312.0000, 0.9953, 1.0000], + [380.0000, 70.0000, 475.0000, 183.0000, 0.9555, 38.0000], + [439.0000, 70.0000, 471.0000, 176.0000, 0.3617, 38.0000], + [379.0000, 143.0000, 390.0000, 158.0000, 0.2418, 38.0000], + [378.0000, 140.0000, 461.0000, 184.0000, 0.1672, 38.0000], + [226.0000, 252.0000, 230.0000, 255.0000, 0.0570, 16.0000], + ] + ), + torch.tensor( + [ + [597.0000, 216.0000, 639.0000, 475.0000, 0.9783, 1.0000], + [80.0000, 263.0000, 128.0000, 291.0000, 0.9571, 60.0000], + [126.0000, 258.0000, 178.0000, 286.0000, 0.9540, 60.0000], + [174.0000, 252.0000, 221.0000, 279.0000, 0.9434, 60.0000], + [248.0000, 323.0000, 300.0000, 354.0000, 0.9359, 60.0000], + [171.0000, 150.0000, 220.0000, 166.0000, 0.9347, 60.0000], + [121.0000, 151.0000, 173.0000, 168.0000, 0.9336, 60.0000], + [394.0000, 111.0000, 417.0000, 135.0000, 0.9256, 85.0000], + [300.0000, 327.0000, 362.0000, 358.0000, 0.9058, 60.0000], + [264.0000, 149.0000, 306.0000, 166.0000, 0.8948, 60.0000], + [306.0000, 150.0000, 350.0000, 165.0000, 0.8798, 60.0000], + [70.0000, 150.0000, 127.0000, 168.0000, 0.8697, 60.0000], + [110.0000, 138.0000, 153.0000, 156.0000, 0.8586, 60.0000], + [223.0000, 154.0000, 270.0000, 166.0000, 0.8576, 60.0000], + [541.0000, 81.0000, 602.0000, 153.0000, 0.8352, 79.0000], + [34.0000, 266.0000, 82.0000, 295.0000, 0.8326, 60.0000], + [444.0000, 302.0000, 484.0000, 325.0000, 0.7900, 60.0000], + [14.0000, 152.0000, 73.0000, 169.0000, 0.7792, 60.0000], + [115.0000, 247.0000, 157.0000, 268.0000, 0.7654, 60.0000], + [168.0000, 350.0000, 237.0000, 385.0000, 0.7241, 60.0000], + [197.0000, 319.0000, 249.0000, 351.0000, 0.7062, 60.0000], + [89.0000, 331.0000, 149.0000, 366.0000, 0.6970, 61.0000], + [66.0000, 143.0000, 109.0000, 153.0000, 0.6787, 60.0000], + [152.0000, 332.0000, 217.0000, 358.0000, 0.6739, 60.0000], + [99.0000, 355.0000, 169.0000, 395.0000, 0.6582, 60.0000], + [583.0000, 205.0000, 594.0000, 218.0000, 0.6428, 47.0000], + [498.0000, 301.0000, 528.0000, 321.0000, 0.6373, 60.0000], + [255.0000, 146.0000, 274.0000, 155.0000, 0.6366, 60.0000], + [148.0000, 231.0000, 192.0000, 250.0000, 0.5984, 60.0000], + [501.0000, 140.0000, 551.0000, 164.0000, 0.5910, 60.0000], + [156.0000, 144.0000, 193.0000, 157.0000, 0.5910, 60.0000], + [381.0000, 225.0000, 444.0000, 254.0000, 0.5737, 60.0000], + [156.0000, 243.0000, 206.0000, 264.0000, 0.5675, 60.0000], + [229.0000, 302.0000, 280.0000, 331.0000, 0.5588, 60.0000], + [492.0000, 134.0000, 516.0000, 142.0000, 0.5492, 60.0000], + [346.0000, 150.0000, 383.0000, 165.0000, 0.5481, 60.0000], + [17.0000, 143.0000, 67.0000, 154.0000, 0.5254, 60.0000], + [283.0000, 308.0000, 330.0000, 334.0000, 0.5141, 60.0000], + [421.0000, 222.0000, 489.0000, 250.0000, 0.4983, 60.0000], + [0.0000, 107.0000, 51.0000, 134.0000, 0.4978, 78.0000], + [70.0000, 248.0000, 113.0000, 270.0000, 0.4884, 60.0000], + [215.0000, 147.0000, 262.0000, 164.0000, 0.4867, 60.0000], + [293.0000, 145.0000, 315.0000, 157.0000, 0.4841, 60.0000], + [523.0000, 272.0000, 548.0000, 288.0000, 0.4728, 60.0000], + [534.0000, 152.0000, 560.0000, 164.0000, 0.4644, 60.0000], + [516.0000, 294.0000, 546.0000, 314.0000, 0.4597, 60.0000], + [352.0000, 319.0000, 395.0000, 342.0000, 
0.4364, 60.0000], + [106.0000, 234.0000, 149.0000, 255.0000, 0.4317, 60.0000], + [326.0000, 136.0000, 357.0000, 147.0000, 0.4281, 60.0000], + [135.0000, 132.0000, 166.0000, 145.0000, 0.4159, 60.0000], + [63.0000, 238.0000, 104.0000, 259.0000, 0.4136, 60.0000], + [472.0000, 221.0000, 527.0000, 246.0000, 0.4090, 60.0000], + [189.0000, 137.0000, 225.0000, 154.0000, 0.4018, 60.0000], + [135.0000, 311.0000, 195.0000, 337.0000, 0.3965, 60.0000], + [9.0000, 148.0000, 68.0000, 164.0000, 0.3915, 60.0000], + [366.0000, 232.0000, 408.0000, 257.0000, 0.3858, 60.0000], + [291.0000, 243.0000, 318.0000, 266.0000, 0.3838, 60.0000], + [494.0000, 276.0000, 524.0000, 300.0000, 0.3727, 60.0000], + [97.0000, 135.0000, 122.0000, 148.0000, 0.3717, 60.0000], + [467.0000, 289.0000, 499.0000, 309.0000, 0.3710, 60.0000], + [150.0000, 134.0000, 188.0000, 146.0000, 0.3705, 60.0000], + [427.0000, 290.0000, 463.0000, 314.0000, 0.3575, 60.0000], + [38.0000, 343.0000, 101.0000, 408.0000, 0.3540, 61.0000], + [76.0000, 313.0000, 128.0000, 343.0000, 0.3429, 61.0000], + [507.0000, 146.0000, 537.0000, 163.0000, 0.3420, 60.0000], + [451.0000, 268.0000, 478.0000, 282.0000, 0.3389, 60.0000], + [545.0000, 292.0000, 578.0000, 314.0000, 0.3252, 60.0000], + [350.0000, 309.0000, 393.0000, 336.0000, 0.3246, 60.0000], + [388.0000, 307.0000, 429.0000, 337.0000, 0.3240, 60.0000], + [34.0000, 253.0000, 67.0000, 270.0000, 0.3228, 60.0000], + [402.0000, 224.0000, 462.0000, 252.0000, 0.3177, 60.0000], + [160.0000, 131.0000, 191.0000, 142.0000, 0.3104, 60.0000], + [132.0000, 310.0000, 197.0000, 340.0000, 0.2923, 61.0000], + [481.0000, 84.0000, 543.0000, 140.0000, 0.2872, 79.0000], + [13.0000, 137.0000, 62.0000, 153.0000, 0.2859, 60.0000], + [98.0000, 355.0000, 171.0000, 395.0000, 0.2843, 61.0000], + [115.0000, 149.0000, 156.0000, 160.0000, 0.2774, 60.0000], + [65.0000, 137.0000, 101.0000, 148.0000, 0.2732, 60.0000], + [314.0000, 242.0000, 341.0000, 264.0000, 0.2714, 60.0000], + [455.0000, 237.0000, 486.0000, 251.0000, 0.2630, 60.0000], + [552.0000, 146.0000, 595.0000, 164.0000, 0.2553, 60.0000], + [50.0000, 133.0000, 78.0000, 145.0000, 0.2485, 60.0000], + [544.0000, 280.0000, 570.0000, 294.0000, 0.2459, 60.0000], + [40.0000, 144.0000, 66.0000, 154.0000, 0.2453, 60.0000], + [289.0000, 254.0000, 312.0000, 268.0000, 0.2374, 60.0000], + [266.0000, 140.0000, 292.0000, 149.0000, 0.2357, 60.0000], + [504.0000, 266.0000, 525.0000, 277.0000, 0.2281, 60.0000], + [304.0000, 285.0000, 346.0000, 309.0000, 0.2256, 60.0000], + [303.0000, 222.0000, 341.0000, 238.0000, 0.2236, 60.0000], + [498.0000, 219.0000, 549.0000, 243.0000, 0.2168, 60.0000], + [89.0000, 333.0000, 144.0000, 352.0000, 0.2159, 61.0000], + [0.0000, 108.0000, 51.0000, 135.0000, 0.2076, 79.0000], + [303.0000, 220.0000, 329.0000, 231.0000, 0.2007, 60.0000], + [0.0000, 131.0000, 38.0000, 150.0000, 0.1967, 60.0000], + [364.0000, 137.0000, 401.0000, 165.0000, 0.1958, 60.0000], + [398.0000, 95.0000, 538.0000, 139.0000, 0.1868, 79.0000], + [334.0000, 243.0000, 357.0000, 263.0000, 0.1835, 60.0000], + [480.0000, 269.0000, 503.0000, 286.0000, 0.1831, 60.0000], + [184.0000, 302.0000, 229.0000, 320.0000, 0.1784, 60.0000], + [522.0000, 286.0000, 548.0000, 300.0000, 0.1752, 60.0000], + ] + ), + ] + + return [{"bbox": p[:, :4].double(), "scores": p[:, 4].double(), "labels": p[:, 5]} for p in pred], [ + {"bbox": g[:, :4].double(), "labels": g[:, 4].long(), "iscrowd": g[:, 5]} for g in gt + ] + + +def random_sample() -> Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, torch.Tensor]]]: + 
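+    """Generate predictions and targets for 30 synthetic images in the COCO-style dict format.
+
+    Ground truth boxes are drawn at random. Predictions are built by keeping a random subset of the
+    ground truth boxes, perturbing their coordinates and labels and assigning them random confidences,
+    and then appending some purely random boxes, so both false negatives and false positives occur.
+    """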
torch.manual_seed(12) + targets: List[torch.Tensor] = [] + preds: List[torch.Tensor] = [] + for _ in range(30): + # Generate some ground truth boxes + n_gt_box = torch.randint(50, (1,)).item() + x1 = torch.randint(641, (n_gt_box, 1)) + y1 = torch.randint(641, (n_gt_box, 1)) + w = 640 * torch.rand((n_gt_box, 1)) + h = 640 * torch.rand((n_gt_box, 1)) + x2 = (x1 + w).clip(max=640) + y2 = (y1 + h).clip(max=640) + category = torch.randint(0, 80, (n_gt_box, 1)) + iscrowd = torch.randint(2, (n_gt_box, 1)) + targets.append(torch.cat((x1, y1, x2, y2, category, iscrowd), dim=1)) + + # Remove some of gt boxes from corresponding predictions + kept_boxes = torch.randint(2, (n_gt_box,), dtype=torch.bool) + n_predicted_box = kept_boxes.sum() + x1 = x1[kept_boxes] + y1 = y1[kept_boxes] + w = w[kept_boxes] + h = h[kept_boxes] + category = category[kept_boxes] + + # Perturb gt boxes in the prediction + perturb_x1 = 640 * (torch.rand_like(x1, dtype=torch.float) - 0.5) + perturb_y1 = 640 * (torch.rand_like(y1, dtype=torch.float) - 0.5) + perturb_w = 640 * (torch.rand_like(w, dtype=torch.float) - 0.5) + perturb_h = 640 * (torch.rand_like(h, dtype=torch.float) - 0.5) + perturb_category = Geometric(0.7).sample((n_predicted_box, 1)) * (2 * torch.randint_like(category, 2) - 1) + + x1 = (x1 + perturb_x1).clip(min=0, max=640) + y1 = (y1 + perturb_y1).clip(min=0, max=640) + w = (w + perturb_w).clip(min=0, max=640) + h = (h + perturb_h).clip(min=0, max=640) + x2 = (x1 + w).clip(max=640) + y2 = (y1 + h).clip(max=640) + category = (category + perturb_category) % 80 + confidence = torch.rand_like(category, dtype=torch.double) + perturbed_gt_boxes = torch.cat((x1, y1, x2, y2, confidence, category), dim=1) + + # Generate some additional prediction boxes + n_additional_pred_boxes = torch.randint(50, (1,)).item() + x1 = torch.randint(641, (n_additional_pred_boxes, 1)) + y1 = torch.randint(641, (n_additional_pred_boxes, 1)) + w = 640 * torch.rand((n_additional_pred_boxes, 1)) + h = 640 * torch.rand((n_additional_pred_boxes, 1)) + x2 = (x1 + w).clip(max=640) + y2 = (y1 + h).clip(max=640) + category = torch.randint(0, 80, (n_additional_pred_boxes, 1)) + confidence = torch.rand_like(category, dtype=torch.double) + additional_pred_boxes = torch.cat((x1, y1, x2, y2, confidence, category), dim=1) + + preds.append(torch.cat((perturbed_gt_boxes, additional_pred_boxes), dim=0)) + + return [{"bbox": p[:, :4], "scores": p[:, 4], "labels": p[:, 5]} for p in preds], [ + {"bbox": g[:, :4], "labels": g[:, 4].long(), "iscrowd": g[:, 5]} for g in targets + ] + + +def create_coco_api( + predictions: List[Dict[str, torch.Tensor]], targets: List[Dict[str, torch.Tensor]] +) -> Tuple[COCO, COCO]: + """Create COCO objects from predictions and targets + + Args: + predictions: list of predictions. Each one is a dict containing "bbox", "scores" and "labels" as its keys. The + associated value to "bbox" is a tensor of shape (n, 4) where n stands for the number of detections. + 4 represents top left and bottom right coordinates of a box in the form (x1, y1, x2, y2). The associated + values to "scores" and "labels" are tensors of shape (n,). + targets: list of targets. Each one is a dict containing "bbox", "labels" and "iscrowd" as its keys. The + associated values to "bbox" and "labels" are the same as those of the ``predictions``. The associated + value to "iscrowd" is a tensor of shape (n,) which determines if ground truth boxes are crowd or not. 
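+
+    Returns:
+        A tuple ``(coco_dt, coco_gt)`` of COCO objects for the detections and the ground truth, respectively.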
+ """ + ann_id = 1 + coco_gt = COCO() + dataset = {"images": [], "categories": [], "annotations": []} + + for idx, target in enumerate(targets): + dataset["images"].append({"id": idx}) + bboxes = target["bbox"].clone() + bboxes[:, 2:4] -= bboxes[:, 0:2] + for i in range(bboxes.shape[0]): + bbox = bboxes[i].tolist() + area = bbox[2] * bbox[3] + ann = { + "image_id": idx, + "bbox": bbox, + "category_id": target["labels"][i].item(), + "area": area, + "iscrowd": target["iscrowd"][i].item(), + "id": ann_id, + } + dataset["annotations"].append(ann) + ann_id += 1 + dataset["categories"] = [{"id": i} for i in range(0, 91)] + coco_gt.dataset = dataset + coco_gt.createIndex() + + prediction_tensors = [] + for idx, prediction in enumerate(predictions): + bboxes = prediction["bbox"].clone() + bboxes[:, 2:4] -= bboxes[:, 0:2] + prediction_tensors.append( + torch.cat( + [ + torch.tensor(idx).repeat(bboxes.shape[0], 1), + bboxes, + prediction["scores"].unsqueeze(1), + prediction["labels"].unsqueeze(1), + ], + dim=1, + ) + ) + predictions = torch.cat(prediction_tensors, dim=0) + coco_dt = coco_gt.loadRes(predictions.numpy()) + return coco_dt, coco_gt + + +def pycoco_mAP(predictions: List[Dict[str, torch.Tensor]], targets: List[Dict[str, torch.Tensor]]) -> np.array: + """ + Returned values are AP@.5...95, AP@.5, AP@.75, AP-S, AP-M, AP-L, AR-1, AR-10, AR-100, AR-S, AR-M, AR-L + """ + coco_dt, coco_gt = create_coco_api(predictions, targets) + eval = COCOeval(coco_gt, coco_dt, iouType="bbox") + eval.evaluate() + eval.accumulate() + eval.summarize() + return eval.stats + + +Sample = namedtuple("Sample", ["data", "mAP", "length"]) + + +@pytest.fixture( + params=[ + ("coco2017", "full"), + ("coco2017", "with_an_empty_pred"), + ("coco2017", "with_an_empty_gt"), + ("coco2017", "with_an_empty_pred_and_gt"), + ("random", "full"), + ("random", "with_an_empty_pred"), + ("random", "with_an_empty_gt"), + ("random", "with_an_empty_pred_and_gt"), + ] +) +def sample(request) -> Sample: + data = coco_val2017_sample() if request.param[0] == "coco2017" else random_sample() + if request.param[1] == "with_an_empty_pred": + data[0][1] = { + "bbox": torch.zeros(0, 4), + "scores": torch.zeros( + 0, + ), + "labels": torch.zeros(0, dtype=torch.long), + } + elif request.param[1] == "with_an_empty_gt": + data[1][0] = { + "bbox": torch.zeros(0, 4), + "labels": torch.zeros(0, dtype=torch.long), + "iscrowd": torch.zeros( + 0, + ), + } + elif request.param[1] == "with_an_empty_pred_and_gt": + data[0][0] = { + "bbox": torch.zeros(0, 4), + "scores": torch.zeros( + 0, + ), + "labels": torch.zeros(0, dtype=torch.long), + } + data[0][1] = { + "bbox": torch.zeros(0, 4), + "scores": torch.zeros( + 0, + ), + "labels": torch.zeros(0, dtype=torch.long), + } + data[1][0] = { + "bbox": torch.zeros(0, 4), + "labels": torch.zeros(0, dtype=torch.long), + "iscrowd": torch.zeros( + 0, + ), + } + data[1][2] = { + "bbox": torch.zeros(0, 4), + "labels": torch.zeros(0, dtype=torch.long), + "iscrowd": torch.zeros( + 0, + ), + } + mAP = pycoco_mAP(*data) + + return Sample(data, mAP, len(data[0])) + + +def test_wrong_input(): + m = ObjectDetectionAvgPrecisionRecall() + + with pytest.raises(ValueError, match="y_pred and y should have the same number of samples"): + m.update(([{"bbox": None, "scores": None}], [])) + with pytest.raises(ValueError, match="y_pred and y should contain at least one sample."): + m.update(([], [])) + with pytest.raises(ValueError, match="y_pred sample dictionaries should have 'bbox', 'scores'"): + m.update(([{"bbox": None, "scores": 
None}], [{"bbox": None, "labels": None}])) + with pytest.raises(ValueError, match="y sample dictionaries should have 'bbox', 'labels'"): + m.update(([{"bbox": None, "scores": None, "labels": None}], [{"labels": None}])) + + +def test_empty_data(): + """ + Note that PyCOCO returns -1 when threre's no ground truth data. + """ + + metric = ObjectDetectionAvgPrecisionRecall() + metric.update( + ( + [{"bbox": torch.zeros((0, 4)), "scores": torch.zeros((0,)), "labels": torch.zeros((0), dtype=torch.long)}], + [{"bbox": torch.zeros((0, 4)), "iscrowd": torch.zeros((0,)), "labels": torch.zeros((0), dtype=torch.long)}], + ) + ) + assert len(metric._tps) == 0 + assert len(metric._fps) == 0 + assert metric.compute() == (-1, -1) + metric.update( + ( + [ + { + "bbox": torch.tensor([[0.0, 0.0, 100.0, 100.0]]), + "scores": torch.ones((1,)), + "labels": torch.ones((1,), dtype=torch.long), + } + ], + [{"bbox": torch.zeros((0, 4)), "iscrowd": torch.zeros((0,)), "labels": torch.zeros((0), dtype=torch.long)}], + ) + ) + assert metric.compute() == (-1, -1) + + metric = ObjectDetectionAvgPrecisionRecall() + metric.update( + ( + [{"bbox": torch.zeros((0, 4)), "scores": torch.zeros((0,)), "labels": torch.zeros((0), dtype=torch.long)}], + [ + { + "bbox": torch.tensor([[0.0, 0.0, 100.0, 100.0]]), + "iscrowd": torch.zeros((1,)), + "labels": torch.ones((1,), dtype=torch.long), + } + ], + ) + ) + assert len(metric._tps) == 0 + assert len(metric._fps) == 0 + assert metric._y_true_count[1] == 1 + assert metric.compute() == (0, 0) + + metric = ObjectDetectionAvgPrecisionRecall() + pred = { + "bbox": torch.tensor([[0.0, 0.0, 100.0, 100.0]]), + "scores": torch.tensor([0.9]), + "labels": torch.tensor([5]), + } + target = {"bbox": torch.zeros((0, 4)), "iscrowd": torch.zeros((0,)), "labels": torch.zeros((0), dtype=torch.long)} + metric.update(([pred], [target])) + assert len(metric._tps) == len(metric._fps) == 1 + pycoco_result = pycoco_mAP([pred], [target]) + assert metric.compute() == (pycoco_result[0], pycoco_result[8]) + + +def test_no_torchvision(): + with patch.dict(sys.modules, {"torchvision.ops.boxes": None}): + with pytest.raises(ModuleNotFoundError, match=r"This metric requires torchvision to be installed."): + ObjectDetectionAvgPrecisionRecall() + + +def test_iou(sample): + m = ObjectDetectionAvgPrecisionRecall(num_classes=91) + from pycocotools.mask import iou as pycoco_iou + + for pred, tgt in zip(*sample.data): + pred_bbox = pred["bbox"].double() + tgt_bbox = tgt["bbox"].double() + if not pred_bbox.shape[0] or not tgt_bbox.shape[0]: + continue + iscrowd = tgt["iscrowd"] + + ignite_iou = m.box_iou(pred_bbox, tgt_bbox, iscrowd.bool()) + + pred_bbox[:, 2:4] -= pred_bbox[:, :2] + tgt_bbox[:, 2:4] -= tgt_bbox[:, :2] + pyc_iou = pycoco_iou(pred_bbox.numpy(), tgt_bbox.numpy(), iscrowd.int()) + + equal = ignite_iou.numpy() == pyc_iou + assert equal.all() + + +def test_iou_thresholding(): + metric = ObjectDetectionAvgPrecisionRecall(iou_thresholds=[0.0, 0.3, 0.5, 0.75]) + + pred = { + "bbox": torch.tensor([[0.0, 0.0, 100.0, 100.0]]), + "scores": torch.tensor([0.8]), + "labels": torch.tensor([1]), + } + gt = {"bbox": torch.tensor([[0.0, 0.0, 50.0, 100.0]]), "iscrowd": torch.zeros((1,)), "labels": torch.tensor([1])} + metric.update(([pred], [gt])) + assert (metric._tps[0] == torch.tensor([[True], [True], [True], [False]])).all() + + pred = { + "bbox": torch.tensor([[0.0, 0.0, 100.0, 100.0]]), + "scores": torch.tensor([0.8]), + "labels": torch.tensor([1]), + } + gt = {"bbox": torch.tensor([[100.0, 0.0, 200.0, 100.0]]), 
"iscrowd": torch.zeros((1,)), "labels": torch.tensor([1])} + metric.update(([pred], [gt])) + assert (metric._tps[1] == torch.tensor([[True], [False], [False], [False]])).all() + + +def test_matching(): + """ + PyCOCO matching rules: + 1. The higher confidence in a prediction, the sooner decision is made for it. + If there's equal confidence in two predictions, the dicision is first made + for the one who comes earlier. + 2. Each ground truth box is matched with at most one prediction. Crowd ground + truth is the exception. + 3. If a ground truth is crowd or out of area range, is set to be ignored. + 4. A prediction matched with a ignored gt would get ignored, in the sense that it becomes + neither tp nor fp. + 5. An unmatched prediction would get ignored if it's out of area range. So doesn't become fp due to rule 4. + 6. Among many plausible ground truth boxes, a prediction is matched with the + one which has the highest mutual IOU. If two ground truth boxes have the + same IOU with a prediction, the later one is matched. + 7. Non-ignored ground truths are given priority over the ignored ones when matching with a prediction + even if their IOU is lower. + """ + metric = ObjectDetectionAvgPrecisionRecall(iou_thresholds=[0.2]) + + pred = { + "bbox": torch.tensor([[0.0, 0.0, 100.0, 100.0], [0.0, 0.0, 100.0, 100.0]]), + "scores": torch.tensor([0.8, 0.9]), + "labels": torch.tensor([1, 1]), + } + gt = { + "bbox": torch.tensor([[0.0, 0.0, 100.0, 100.0]]), + "iscrowd": torch.zeros((1,)), + "labels": torch.tensor([1]), + } + metric.update(([pred], [gt])) + # Preds are sorted by their scores internally + assert (metric._tps[0] == torch.tensor([[True, False]])).all() + assert (metric._fps[0] == torch.tensor([[False, True]])).all() + assert (metric._scores[0] == torch.tensor([[0.9, 0.8]])).all() + + pred["scores"] = torch.tensor([0.9, 0.9]) + metric.update(([pred], [gt])) + assert (metric._tps[1] == torch.tensor([[True, False]])).all() + assert (metric._fps[1] == torch.tensor([[False, True]])).all() + + gt["iscrowd"] = torch.tensor([1]) + metric.update(([pred], [gt])) + assert (metric._tps[2] == torch.tensor([[False, False]])).all() + assert (metric._fps[2] == torch.tensor([[False, False]])).all() + + pred["bbox"] = torch.tensor([[0.0, 0.0, 100.0, 100.0], [100.0, 0.0, 200.0, 100.0]]) + gt["bbox"] = torch.tensor([[0.0, 0.0, 25.0, 50.0], [50.0, 0.0, 150.0, 100.0]]) + gt["iscrowd"] = torch.zeros((2,)) + gt["labels"] = torch.tensor([1, 1]) + metric.update(([pred], [gt])) + assert (metric._tps[3] == torch.tensor([[True, False]])).all() + assert (metric._fps[3] == torch.tensor([[False, True]])).all() + + metric._area_range = "small" + pred["bbox"] = torch.tensor( + [[0.0, 0.0, 100.0, 10.0], [0.0, 0.0, 100.0, 10.0], [0.0, 0.0, 100.0, 11.0], [0.0, 0.0, 100.0, 10.0]] + ) + pred["scores"] = torch.tensor([0.9, 0.9, 0.9, 0.9]) + pred["labels"] = torch.tensor([1, 1, 1, 1]) + gt["bbox"] = torch.tensor([[0.0, 0.0, 100.0, 11.0], [0.0, 0.0, 100.0, 5.0]]) + metric.update(([pred], [gt])) + assert (metric._tps[4] == torch.tensor([[True, False, False, False]])).all() + assert (metric._fps[4] == torch.tensor([[False, False, False, True]])).all() + + pred["scores"] = torch.tensor([0.9, 1.0, 0.9, 0.9]) + metric._max_detections_per_image_per_class = 1 + metric.update(([pred], [gt])) + assert (metric._tps[5] == torch.tensor([[True]])).all() + assert (metric._fps[5] == torch.tensor([[False]])).all() + + +def sklearn_precision_recall_curve_allowing_multiple_recalls_at_single_threshold(y_true, y_score): + y_true = y_true == 1 + + 
desc_score_indices = np.argsort(y_score, kind="mergesort")[::-1] + y_true = y_true[desc_score_indices] + weight = 1.0 + + tps = stable_cumsum(y_true * weight) + fps = stable_cumsum((1 - y_true) * weight) + ps = tps + fps + precision = np.zeros_like(tps) + np.divide(tps, ps, out=precision, where=(ps != 0)) + if tps[-1] == 0: + recall = np.ones_like(tps) + else: + recall = tps / tps[-1] + + sl = slice(None, None, -1) + return np.hstack((precision[sl], 1)), np.hstack((recall[sl], 0)), None + + +def test__compute_recall_and_precision(): + # The case in which detector detects all gt objects but also produces some wrong predictions. + scores = torch.rand((50,)) + y_true = torch.randint(0, 2, (50,)) + m = ObjectDetectionAvgPrecisionRecall() + + ignite_recall, ignite_precision = m._compute_recall_and_precision( + y_true.bool(), ~(y_true.bool()), scores, y_true.sum() + ) + sklearn_precision, sklearn_recall, _ = sklearn_precision_recall_curve_allowing_multiple_recalls_at_single_threshold( + y_true.numpy(), scores.numpy() + ) + assert (ignite_recall.flip(0).numpy() == sklearn_recall[:-1]).all() + assert (ignite_precision.flip(0).numpy() == sklearn_precision[:-1]).all() + + # Like above but with two additional mean dimensions. + scores = torch.rand((50,)) + y_true = torch.zeros((6, 8, 50)) + sklearn_precisions, sklearn_recalls = [], [] + for i in range(6): + for j in range(8): + y_true[i, j, np.random.choice(50, size=15, replace=False)] = 1 + precision, recall, _ = sklearn_precision_recall_curve_allowing_multiple_recalls_at_single_threshold( + y_true[i, j].numpy(), scores.numpy() + ) + sklearn_precisions.append(precision[:-1]) + sklearn_recalls.append(recall[:-1]) + sklearn_precisions = np.array(sklearn_precisions).reshape(6, 8, -1) + sklearn_recalls = np.array(sklearn_recalls).reshape(6, 8, -1) + ignite_recall, ignite_precision = m._compute_recall_and_precision( + y_true.bool(), ~(y_true.bool()), scores, torch.tensor(15) + ) + assert (ignite_recall.flip(-1).numpy() == sklearn_recalls).all() + assert (ignite_precision.flip(-1).numpy() == sklearn_precisions).all() + + +def test_compute(sample): + device = idist.device() + + if device == torch.device("mps"): + pytest.skip("Due to MPS backend out of memory") + + # AP@.5...95, AP@.5, AP@.75, AP-S, AP-M, AP-L, AR-1, AR-10, AR-100, AR-S, AR-M, AR-L + ap_50_95_ar_100 = ObjectDetectionAvgPrecisionRecall(num_classes=91, device=device) + ap_50 = ObjectDetectionAvgPrecisionRecall(num_classes=91, iou_thresholds=[0.5], device=device) + ap_75 = ObjectDetectionAvgPrecisionRecall(num_classes=91, iou_thresholds=[0.75], device=device) + ap_ar_S = ObjectDetectionAvgPrecisionRecall(num_classes=91, device=device, area_range="small") + ap_ar_M = ObjectDetectionAvgPrecisionRecall(num_classes=91, device=device, area_range="medium") + ap_ar_L = ObjectDetectionAvgPrecisionRecall(num_classes=91, device=device, area_range="large") + ar_1 = ObjectDetectionAvgPrecisionRecall(num_classes=91, device=device, max_detections_per_image_per_class=1) + ar_10 = ObjectDetectionAvgPrecisionRecall(num_classes=91, device=device, max_detections_per_image_per_class=10) + + metrics = [ap_50_95_ar_100, ap_50, ap_75, ap_ar_S, ap_ar_M, ap_ar_L, ar_1, ar_10] + for metric in metrics: + metric.update(sample.data) + + ignite_res = [metric.compute() for metric in metrics] + ignite_res_recompute = [metric.compute() for metric in metrics] + + assert all([r1 == r2 for r1, r2 in zip(ignite_res, ignite_res_recompute)]) + + AP_50_95, AR_100 = ignite_res[0] + AP_50 = ignite_res[1][0] + AP_75 = ignite_res[2][0] 
+ AP_S, AR_S = ignite_res[3] + AP_M, AR_M = ignite_res[4] + AP_L, AR_L = ignite_res[5] + AR_1 = ignite_res[6][1] + AR_10 = ignite_res[7][1] + all_res = [AP_50_95, AP_50, AP_75, AP_S, AP_M, AP_L, AR_1, AR_10, AR_100, AR_S, AR_M, AR_L] + print(all_res) + assert np.allclose(all_res, sample.mAP) + + common_metrics = CommonObjectDetectionMetrics(num_classes=91, device=device) + common_metrics.update(sample.data) + res = common_metrics.compute() + common_metrics_res = [ + res["AP@50..95"], + res["AP@50"], + res["AP@75"], + res["AP-S"], + res["AP-M"], + res["AP-L"], + res["AR-1"], + res["AR-10"], + res["AR-100"], + res["AR-S"], + res["AR-M"], + res["AR-L"], + ] + print(common_metrics_res) + assert all_res == common_metrics_res + assert np.allclose(common_metrics_res, sample.mAP) + + +def test_integration(sample): + bs = 3 + + device = idist.device() + if device == torch.device("mps"): + pytest.skip("Due to MPS backend out of memory") + + def update(engine, i): + b = slice(i * bs, (i + 1) * bs) + return sample.data[0][b], sample.data[1][b] + + engine = Engine(update) + + metric_device = "cpu" if device.type == "xla" else device + metric_50_95 = ObjectDetectionAvgPrecisionRecall(num_classes=91, device=metric_device) + metric_50_95.attach(engine, name="mAP[50-95]") + + n_iter = ceil(sample.length / bs) + engine.run(range(n_iter), max_epochs=1) + + res_50_95 = engine.state.metrics["mAP[50-95]"][0] + pycoco_res_50_95 = sample.mAP[0] + + assert np.allclose(res_50_95, pycoco_res_50_95) + + +def test_tensor_list_to_dict_list(): + y_preds = [ + [torch.randn((2, 6)), torch.randn((0, 6)), torch.randn((5, 6))], + [ + {"bbox": torch.randn((2, 4)), "confidence": torch.randn((2,)), "class": torch.randn((2,))}, + {"bbox": torch.randn((5, 4)), "confidence": torch.randn((5,)), "class": torch.randn((5,))}, + ], + ] + ys = [ + [torch.randn((2, 6)), torch.randn((0, 6)), torch.randn((5, 6))], + [ + {"bbox": torch.randn((2, 4)), "class": torch.randn((2,))}, + {"bbox": torch.randn((5, 4)), "class": torch.randn((5,))}, + ], + ] + for y_pred in y_preds: + for y in ys: + y_pred_new, y_new = coco_tensor_list_to_dict_list((y_pred, y)) + if isinstance(y_pred[0], dict): + assert y_pred_new is y_pred + else: + assert all( + [ + (ypn["bbox"] == yp[:, :4]).all() + & (ypn["confidence"] == yp[:, 4]).all() + & (ypn["class"] == yp[:, 5]).all() + for yp, ypn in zip(y_pred, y_pred_new) + ] + ) + if isinstance(y[0], dict): + assert y_new is y + else: + assert all( + [ + (ytn["bbox"] == yt[:, :4]).all() + & (ytn["class"] == yt[:, 4]).all() + & (ytn["iscrowd"] == yt[:, 5]).all() + for yt, ytn in zip(y, y_new) + ] + ) + + +def test_distrib_update_compute(distributed, sample): + rank_samples_cnt = ceil(sample.length / idist.get_world_size()) + rank = idist.get_rank() + rank_samples_range = slice(rank_samples_cnt * rank, rank_samples_cnt * (rank + 1)) + + device = idist.device() + + if device == torch.device("mps"): + pytest.skip("Due to MPS backend out of memory") + + metric_device = "cpu" if device.type == "xla" else device + # AP@.5...95, AP@.5, AP@.75, AP-S, AP-M, AP-L, AR-1, AR-10, AR-100, AR-S, AR-M, AR-L + ap_50_95_ar_100 = ObjectDetectionAvgPrecisionRecall(num_classes=91, device=metric_device) + ap_50 = ObjectDetectionAvgPrecisionRecall(num_classes=91, iou_thresholds=[0.5], device=metric_device) + ap_75 = ObjectDetectionAvgPrecisionRecall(num_classes=91, iou_thresholds=[0.75], device=metric_device) + ap_ar_S = ObjectDetectionAvgPrecisionRecall(num_classes=91, device=metric_device, area_range="small") + ap_ar_M = 
ObjectDetectionAvgPrecisionRecall(num_classes=91, device=metric_device, area_range="medium") + ap_ar_L = ObjectDetectionAvgPrecisionRecall(num_classes=91, device=metric_device, area_range="large") + ar_1 = ObjectDetectionAvgPrecisionRecall(num_classes=91, device=metric_device, max_detections_per_image_per_class=1) + ar_10 = ObjectDetectionAvgPrecisionRecall( + num_classes=91, device=metric_device, max_detections_per_image_per_class=10 + ) + + metrics = [ap_50_95_ar_100, ap_50, ap_75, ap_ar_S, ap_ar_M, ap_ar_L, ar_1, ar_10] + + y_pred_rank = sample.data[0][rank_samples_range] + y_rank = sample.data[1][rank_samples_range] + for metric in metrics: + metric.update((y_pred_rank, y_rank)) + + ignite_res = [metric.compute() for metric in metrics] + ignite_res_recompute = [metric.compute() for metric in metrics] + assert all([r1 == r2 for r1, r2 in zip(ignite_res, ignite_res_recompute)]) + + AP_50_95, AR_100 = ignite_res[0] + AP_50 = ignite_res[1][0] + AP_75 = ignite_res[2][0] + AP_S, AR_S = ignite_res[3] + AP_M, AR_M = ignite_res[4] + AP_L, AR_L = ignite_res[5] + AR_1 = ignite_res[6][1] + AR_10 = ignite_res[7][1] + all_res = [AP_50_95, AP_50, AP_75, AP_S, AP_M, AP_L, AR_1, AR_10, AR_100, AR_S, AR_M, AR_L] + assert np.allclose(all_res, sample.mAP) + + common_metrics = CommonObjectDetectionMetrics(num_classes=91, device=device) + common_metrics.update((y_pred_rank, y_rank)) + res = common_metrics.compute() + common_metrics_res = [ + res["AP@50..95"], + res["AP@50"], + res["AP@75"], + res["AP-S"], + res["AP-M"], + res["AP-L"], + res["AR-1"], + res["AR-10"], + res["AR-100"], + res["AR-S"], + res["AR-M"], + res["AR-L"], + ] + assert all_res == common_metrics_res + assert np.allclose(common_metrics_res, sample.mAP)
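
A minimal usage sketch of the two metrics exercised by these tests, using only calls that appear above; the ``preds``/``targets`` variables are small stand-ins for the per-image dictionaries the tests construct (``"bbox"`` in ``(x1, y1, x2, y2)`` form, plus ``"scores"``/``"labels"`` for predictions and ``"labels"``/``"iscrowd"`` for ground truth)::

    import torch

    from ignite.metrics import CommonObjectDetectionMetrics, ObjectDetectionAvgPrecisionRecall

    # A single image with one predicted and one ground-truth box of class 1.
    preds = [
        {"bbox": torch.tensor([[0.0, 0.0, 100.0, 100.0]]), "scores": torch.tensor([0.9]), "labels": torch.tensor([1])}
    ]
    targets = [
        {"bbox": torch.tensor([[0.0, 0.0, 100.0, 100.0]]), "labels": torch.tensor([1]), "iscrowd": torch.zeros((1,))}
    ]

    # Default configuration: AP@.5...95 and AR-100, returned as a pair.
    metric = ObjectDetectionAvgPrecisionRecall(num_classes=91)
    metric.update((preds, targets))
    ap_50_95, ar_100 = metric.compute()

    # All twelve COCO summary values at once, returned as a dict.
    common = CommonObjectDetectionMetrics(num_classes=91)
    common.update((preds, targets))
    results = common.compute()  # keys: "AP@50..95", "AP@50", "AP@75", "AP-S", "AP-M", "AP-L",
    #       "AR-1", "AR-10", "AR-100", "AR-S", "AR-M", "AR-L"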