-
Notifications
You must be signed in to change notification settings - Fork 323
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add RetinaNet Object detection with Backbones (#529)
* refactor frcnn * start adding retina * complete Retinanet * update requirments * Apply suggestions from code review * Set pytorch-lightning>=1.4.0 * Use LightningCLI Co-authored-by: Akihiro Nitta <[email protected]> Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> Co-authored-by: Jirka <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Nicki Skafte <[email protected]> Co-authored-by: praecipue <[email protected]>
- Loading branch information
1 parent
c79c0ca
commit 51edf4a
Showing
12 changed files
with
238 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from pl_bolts.models.detection.retinanet.backbones import create_retinanet_backbone | ||
from pl_bolts.models.detection.retinanet.retinanet_module import RetinaNet | ||
|
||
__all__ = ["create_retinanet_backbone", "RetinaNet"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from typing import Any, Optional | ||
|
||
import torch.nn as nn | ||
|
||
from pl_bolts.models.detection.components import create_torchvision_backbone | ||
from pl_bolts.utils import _TORCHVISION_AVAILABLE | ||
from pl_bolts.utils.warnings import warn_missing_pkg | ||
|
||
if _TORCHVISION_AVAILABLE: | ||
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone | ||
else: # pragma: no cover | ||
warn_missing_pkg("torchvision") | ||
|
||
|
||
def create_retinanet_backbone( | ||
backbone: str, fpn: bool = True, pretrained: Optional[str] = None, trainable_backbone_layers: int = 3, **kwargs: Any | ||
) -> nn.Module: | ||
""" | ||
Args: | ||
backbone: | ||
Supported backones are: "resnet18", "resnet34","resnet50", "resnet101", "resnet152", | ||
"resnext50_32x4d", "resnext101_32x8d", "wide_resnet50_2", "wide_resnet101_2", | ||
as resnets with fpn backbones. | ||
Without fpn backbones supported are: "resnet18", "resnet34", "resnet50","resnet101", | ||
"resnet152", "resnext101_32x8d", "mobilenet_v2", "vgg11", "vgg13", "vgg16", "vgg19", | ||
fpn: If True then constructs fpn as well. | ||
pretrained: If None creates imagenet weights backbone. | ||
trainable_backbone_layers: number of trainable resnet layers starting from final block. | ||
""" | ||
|
||
if fpn: | ||
# Creates a torchvision resnet model with fpn added. | ||
backbone = resnet_fpn_backbone(backbone, pretrained=True, trainable_layers=trainable_backbone_layers, **kwargs) | ||
else: | ||
# This does not create fpn backbone, it is supported for all models | ||
backbone, _ = create_torchvision_backbone(backbone, pretrained) | ||
return backbone |
131 changes: 131 additions & 0 deletions
131
pl_bolts/models/detection/retinanet/retinanet_module.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
from typing import Any, Optional | ||
|
||
import torch | ||
from pytorch_lightning import LightningModule | ||
|
||
from pl_bolts.models.detection.retinanet import create_retinanet_backbone | ||
from pl_bolts.utils import _TORCHVISION_AVAILABLE | ||
from pl_bolts.utils.warnings import warn_missing_pkg | ||
|
||
if _TORCHVISION_AVAILABLE: | ||
from torchvision.models.detection.retinanet import RetinaNet as torchvision_RetinaNet | ||
from torchvision.models.detection.retinanet import RetinaNetHead, retinanet_resnet50_fpn | ||
from torchvision.ops import box_iou | ||
else: # pragma: no cover | ||
warn_missing_pkg("torchvision") | ||
|
||
|
||
class RetinaNet(LightningModule): | ||
"""PyTorch Lightning implementation of RetinaNet. | ||
Paper: `Focal Loss for Dense Object Detection <https://arxiv.org/abs/1708.02002>`_. | ||
Paper authors: Tsung-Yi Lin, Priya Goyal, Ross Girshick, Kaiming He, Piotr Dollár | ||
Model implemented by: | ||
- `Aditya Oke <https://github.com/oke-aditya>` | ||
During training, the model expects both the input tensors, as well as targets (list of dictionary), containing: | ||
- boxes (`FloatTensor[N, 4]`): the ground truth boxes in `[x1, y1, x2, y2]` format. | ||
- labels (`Int64Tensor[N]`): the class label for each ground truh box | ||
CLI command:: | ||
# PascalVOC using LightningCLI | ||
python retinanet_module.py --trainer.gpus 1 --model.pretrained True | ||
""" | ||
|
||
def __init__( | ||
self, | ||
learning_rate: float = 0.0001, | ||
num_classes: int = 91, | ||
backbone: Optional[str] = None, | ||
fpn: bool = True, | ||
pretrained: bool = False, | ||
pretrained_backbone: bool = True, | ||
trainable_backbone_layers: int = 3, | ||
**kwargs: Any, | ||
): | ||
""" | ||
Args: | ||
learning_rate: the learning rate | ||
num_classes: number of detection classes (including background) | ||
backbone: Pretained backbone CNN architecture. | ||
fpn: If True, creates a Feature Pyramind Network on top of Resnet based CNNs. | ||
pretrained: if true, returns a model pre-trained on COCO train2017 | ||
pretrained_backbone: if true, returns a model with backbone pre-trained on Imagenet | ||
trainable_backbone_layers: number of trainable resnet layers starting from final block | ||
""" | ||
super().__init__() | ||
self.learning_rate = learning_rate | ||
self.num_classes = num_classes | ||
self.backbone = backbone | ||
if backbone is None: | ||
self.model = retinanet_resnet50_fpn(pretrained=pretrained, **kwargs) | ||
|
||
self.model.head = RetinaNetHead( | ||
in_channels=self.model.backbone.out_channels, | ||
num_anchors=self.model.head.classification_head.num_anchors, | ||
num_classes=num_classes, | ||
**kwargs, | ||
) | ||
|
||
else: | ||
backbone_model = create_retinanet_backbone( | ||
self.backbone, fpn, pretrained_backbone, trainable_backbone_layers, **kwargs | ||
) | ||
self.model = torchvision_RetinaNet(backbone_model, num_classes=num_classes, **kwargs) | ||
|
||
def forward(self, x): | ||
self.model.eval() | ||
return self.model(x) | ||
|
||
def training_step(self, batch, batch_idx): | ||
|
||
images, targets = batch | ||
targets = [{k: v for k, v in t.items()} for t in targets] | ||
|
||
# fasterrcnn takes both images and targets for training, returns | ||
loss_dict = self.model(images, targets) | ||
loss = sum(loss for loss in loss_dict.values()) | ||
self.log("loss", loss, prog_bar=True) | ||
return loss | ||
|
||
def validation_step(self, batch, batch_idx): | ||
images, targets = batch | ||
# fasterrcnn takes only images for eval() mode | ||
preds = self.model(images) | ||
iou = torch.stack([self._evaluate_iou(p, t) for p, t in zip(preds, targets)]).mean() | ||
self.log("val_iou", iou, prog_bar=True) | ||
return {"val_iou": iou} | ||
|
||
def validation_epoch_end(self, outs): | ||
avg_iou = torch.stack([o["val_iou"] for o in outs]).mean() | ||
self.log("val_avg_iou", avg_iou) | ||
|
||
def _evaluate_iou(self, preds, targets): | ||
"""Evaluate intersection over union (IOU) for target from dataset and output prediction from model.""" | ||
# no box detected, 0 IOU | ||
if preds["boxes"].shape[0] == 0: | ||
return torch.tensor(0.0, device=preds["boxes"].device) | ||
return box_iou(preds["boxes"], targets["boxes"]).diag().mean() | ||
|
||
def configure_optimizers(self): | ||
return torch.optim.SGD( | ||
self.model.parameters(), | ||
lr=self.learning_rate, | ||
momentum=0.9, | ||
weight_decay=0.005, | ||
) | ||
|
||
|
||
def cli_main(): | ||
from pytorch_lightning.utilities.cli import LightningCLI | ||
|
||
from pl_bolts.datamodules import VOCDetectionDataModule | ||
|
||
LightningCLI(RetinaNet, VOCDetectionDataModule, seed_everything_default=42) | ||
|
||
|
||
if __name__ == "__main__": | ||
cli_main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
torch>=1.7.* | ||
torch>=1.7.1 | ||
torchmetrics>=0.4.1 | ||
pytorch-lightning>=1.3.0 | ||
pytorch-lightning>=1.4.0 | ||
packaging |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
torchvision>=0.8.* | ||
torchvision>=0.8.2 | ||
scikit-learn>=0.23 | ||
Pillow | ||
opencv-python-headless | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters