From 51edf4a69ba5c8af8374db971cc7c7c80b1cf46b Mon Sep 17 00:00:00 2001 From: Aditya Oke <47158509+oke-aditya@users.noreply.github.com> Date: Mon, 20 Dec 2021 14:56:46 +0530 Subject: [PATCH] Add RetinaNet Object detection with Backbones (#529) * refactor frcnn * start adding retina * complete Retinanet * update requirments * Apply suggestions from code review * Set pytorch-lightning>=1.4.0 * Use LightningCLI Co-authored-by: Akihiro Nitta Co-authored-by: Jirka Borovec Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> Co-authored-by: Jirka Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Nicki Skafte Co-authored-by: praecipue <43908213+praecipue@users.noreply.github.com> --- CHANGELOG.md | 3 + .../deprecated/models/object_detection.rst | 8 ++ pl_bolts/models/detection/__init__.py | 2 + .../faster_rcnn/faster_rcnn_module.py | 2 +- .../models/detection/retinanet/__init__.py | 4 + .../models/detection/retinanet/backbones.py | 37 +++++ .../detection/retinanet/retinanet_module.py | 131 ++++++++++++++++++ requirements.txt | 4 +- requirements/models.txt | 2 +- requirements/test.txt | 1 + tests/models/test_detection.py | 38 ++++- tests/models/test_scripts.py | 14 ++ 12 files changed, 238 insertions(+), 8 deletions(-) create mode 100644 pl_bolts/models/detection/retinanet/__init__.py create mode 100644 pl_bolts/models/detection/retinanet/backbones.py create mode 100644 pl_bolts/models/detection/retinanet/retinanet_module.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 3b76e25bce..6f7401ee5c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added nn.Module support for FasterRCNN backbone ([#661](https://github.com/PyTorchLightning/lightning-bolts/pull/661)) +- Added `RetinaNet` with torchvision backbones ([#529](https://github.com/PyTorchLightning/lightning-bolts/pull/529)) + + - Added Python 3.9 support ([#786](https://github.com/PyTorchLightning/lightning-bolts/pull/786)) diff --git a/docs/source/deprecated/models/object_detection.rst b/docs/source/deprecated/models/object_detection.rst index 37ace12a88..3980b25637 100644 --- a/docs/source/deprecated/models/object_detection.rst +++ b/docs/source/deprecated/models/object_detection.rst @@ -13,6 +13,14 @@ Faster R-CNN ------------- +RetinaNet +--------- + +.. autoclass:: pl_bolts.models.detection.retinanet.retinanet_module.RetinaNet + :noindex: + +------------- + YOLO ---- diff --git a/pl_bolts/models/detection/__init__.py b/pl_bolts/models/detection/__init__.py index 2d7a4a2d95..db5525adbc 100644 --- a/pl_bolts/models/detection/__init__.py +++ b/pl_bolts/models/detection/__init__.py @@ -1,5 +1,6 @@ from pl_bolts.models.detection import components from pl_bolts.models.detection.faster_rcnn import FasterRCNN +from pl_bolts.models.detection.retinanet import RetinaNet from pl_bolts.models.detection.yolo.yolo_config import YOLOConfiguration from pl_bolts.models.detection.yolo.yolo_module import YOLO @@ -8,4 +9,5 @@ "FasterRCNN", "YOLOConfiguration", "YOLO", + "RetinaNet", ] diff --git a/pl_bolts/models/detection/faster_rcnn/faster_rcnn_module.py b/pl_bolts/models/detection/faster_rcnn/faster_rcnn_module.py index d9e6fdec41..4231634832 100644 --- a/pl_bolts/models/detection/faster_rcnn/faster_rcnn_module.py +++ b/pl_bolts/models/detection/faster_rcnn/faster_rcnn_module.py @@ -43,7 +43,7 @@ class FasterRCNN(LightningModule): CLI command:: # PascalVOC - python faster_rcnn.py --gpus 1 --pretrained True + python faster_rcnn_module.py --gpus 1 --pretrained True """ def __init__( diff --git a/pl_bolts/models/detection/retinanet/__init__.py b/pl_bolts/models/detection/retinanet/__init__.py new file mode 100644 index 0000000000..031049e90e --- /dev/null +++ b/pl_bolts/models/detection/retinanet/__init__.py @@ -0,0 +1,4 @@ +from pl_bolts.models.detection.retinanet.backbones import create_retinanet_backbone +from pl_bolts.models.detection.retinanet.retinanet_module import RetinaNet + +__all__ = ["create_retinanet_backbone", "RetinaNet"] diff --git a/pl_bolts/models/detection/retinanet/backbones.py b/pl_bolts/models/detection/retinanet/backbones.py new file mode 100644 index 0000000000..c039ea6ac3 --- /dev/null +++ b/pl_bolts/models/detection/retinanet/backbones.py @@ -0,0 +1,37 @@ +from typing import Any, Optional + +import torch.nn as nn + +from pl_bolts.models.detection.components import create_torchvision_backbone +from pl_bolts.utils import _TORCHVISION_AVAILABLE +from pl_bolts.utils.warnings import warn_missing_pkg + +if _TORCHVISION_AVAILABLE: + from torchvision.models.detection.backbone_utils import resnet_fpn_backbone +else: # pragma: no cover + warn_missing_pkg("torchvision") + + +def create_retinanet_backbone( + backbone: str, fpn: bool = True, pretrained: Optional[str] = None, trainable_backbone_layers: int = 3, **kwargs: Any +) -> nn.Module: + """ + Args: + backbone: + Supported backones are: "resnet18", "resnet34","resnet50", "resnet101", "resnet152", + "resnext50_32x4d", "resnext101_32x8d", "wide_resnet50_2", "wide_resnet101_2", + as resnets with fpn backbones. + Without fpn backbones supported are: "resnet18", "resnet34", "resnet50","resnet101", + "resnet152", "resnext101_32x8d", "mobilenet_v2", "vgg11", "vgg13", "vgg16", "vgg19", + fpn: If True then constructs fpn as well. + pretrained: If None creates imagenet weights backbone. + trainable_backbone_layers: number of trainable resnet layers starting from final block. + """ + + if fpn: + # Creates a torchvision resnet model with fpn added. + backbone = resnet_fpn_backbone(backbone, pretrained=True, trainable_layers=trainable_backbone_layers, **kwargs) + else: + # This does not create fpn backbone, it is supported for all models + backbone, _ = create_torchvision_backbone(backbone, pretrained) + return backbone diff --git a/pl_bolts/models/detection/retinanet/retinanet_module.py b/pl_bolts/models/detection/retinanet/retinanet_module.py new file mode 100644 index 0000000000..227575fb6c --- /dev/null +++ b/pl_bolts/models/detection/retinanet/retinanet_module.py @@ -0,0 +1,131 @@ +from typing import Any, Optional + +import torch +from pytorch_lightning import LightningModule + +from pl_bolts.models.detection.retinanet import create_retinanet_backbone +from pl_bolts.utils import _TORCHVISION_AVAILABLE +from pl_bolts.utils.warnings import warn_missing_pkg + +if _TORCHVISION_AVAILABLE: + from torchvision.models.detection.retinanet import RetinaNet as torchvision_RetinaNet + from torchvision.models.detection.retinanet import RetinaNetHead, retinanet_resnet50_fpn + from torchvision.ops import box_iou +else: # pragma: no cover + warn_missing_pkg("torchvision") + + +class RetinaNet(LightningModule): + """PyTorch Lightning implementation of RetinaNet. + + Paper: `Focal Loss for Dense Object Detection `_. + + Paper authors: Tsung-Yi Lin, Priya Goyal, Ross Girshick, Kaiming He, Piotr Dollár + + Model implemented by: + - `Aditya Oke ` + + During training, the model expects both the input tensors, as well as targets (list of dictionary), containing: + - boxes (`FloatTensor[N, 4]`): the ground truth boxes in `[x1, y1, x2, y2]` format. + - labels (`Int64Tensor[N]`): the class label for each ground truh box + + CLI command:: + + # PascalVOC using LightningCLI + python retinanet_module.py --trainer.gpus 1 --model.pretrained True + """ + + def __init__( + self, + learning_rate: float = 0.0001, + num_classes: int = 91, + backbone: Optional[str] = None, + fpn: bool = True, + pretrained: bool = False, + pretrained_backbone: bool = True, + trainable_backbone_layers: int = 3, + **kwargs: Any, + ): + """ + Args: + learning_rate: the learning rate + num_classes: number of detection classes (including background) + backbone: Pretained backbone CNN architecture. + fpn: If True, creates a Feature Pyramind Network on top of Resnet based CNNs. + pretrained: if true, returns a model pre-trained on COCO train2017 + pretrained_backbone: if true, returns a model with backbone pre-trained on Imagenet + trainable_backbone_layers: number of trainable resnet layers starting from final block + """ + super().__init__() + self.learning_rate = learning_rate + self.num_classes = num_classes + self.backbone = backbone + if backbone is None: + self.model = retinanet_resnet50_fpn(pretrained=pretrained, **kwargs) + + self.model.head = RetinaNetHead( + in_channels=self.model.backbone.out_channels, + num_anchors=self.model.head.classification_head.num_anchors, + num_classes=num_classes, + **kwargs, + ) + + else: + backbone_model = create_retinanet_backbone( + self.backbone, fpn, pretrained_backbone, trainable_backbone_layers, **kwargs + ) + self.model = torchvision_RetinaNet(backbone_model, num_classes=num_classes, **kwargs) + + def forward(self, x): + self.model.eval() + return self.model(x) + + def training_step(self, batch, batch_idx): + + images, targets = batch + targets = [{k: v for k, v in t.items()} for t in targets] + + # fasterrcnn takes both images and targets for training, returns + loss_dict = self.model(images, targets) + loss = sum(loss for loss in loss_dict.values()) + self.log("loss", loss, prog_bar=True) + return loss + + def validation_step(self, batch, batch_idx): + images, targets = batch + # fasterrcnn takes only images for eval() mode + preds = self.model(images) + iou = torch.stack([self._evaluate_iou(p, t) for p, t in zip(preds, targets)]).mean() + self.log("val_iou", iou, prog_bar=True) + return {"val_iou": iou} + + def validation_epoch_end(self, outs): + avg_iou = torch.stack([o["val_iou"] for o in outs]).mean() + self.log("val_avg_iou", avg_iou) + + def _evaluate_iou(self, preds, targets): + """Evaluate intersection over union (IOU) for target from dataset and output prediction from model.""" + # no box detected, 0 IOU + if preds["boxes"].shape[0] == 0: + return torch.tensor(0.0, device=preds["boxes"].device) + return box_iou(preds["boxes"], targets["boxes"]).diag().mean() + + def configure_optimizers(self): + return torch.optim.SGD( + self.model.parameters(), + lr=self.learning_rate, + momentum=0.9, + weight_decay=0.005, + ) + + +def cli_main(): + from pytorch_lightning.utilities.cli import LightningCLI + + from pl_bolts.datamodules import VOCDetectionDataModule + + LightningCLI(RetinaNet, VOCDetectionDataModule, seed_everything_default=42) + + +if __name__ == "__main__": + cli_main() diff --git a/requirements.txt b/requirements.txt index 512ca42104..f5d3b0ad10 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -torch>=1.7.* +torch>=1.7.1 torchmetrics>=0.4.1 -pytorch-lightning>=1.3.0 +pytorch-lightning>=1.4.0 packaging diff --git a/requirements/models.txt b/requirements/models.txt index fa668ce484..f7d02b71e7 100644 --- a/requirements/models.txt +++ b/requirements/models.txt @@ -1,4 +1,4 @@ -torchvision>=0.8.* +torchvision>=0.8.2 scikit-learn>=0.23 Pillow opencv-python-headless diff --git a/requirements/test.txt b/requirements/test.txt index ce49c13f57..b8f58715a7 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -14,3 +14,4 @@ atari-py==0.2.6 # needed for RL scikit-learn>=0.23 sparseml ale-py>=0.7 +jsonargparse[signatures] # for LightningCLI diff --git a/tests/models/test_detection.py b/tests/models/test_detection.py index 80012e7b87..fcb14eda9b 100644 --- a/tests/models/test_detection.py +++ b/tests/models/test_detection.py @@ -6,7 +6,7 @@ from torch.utils.data import DataLoader from pl_bolts.datasets import DummyDetectionDataset -from pl_bolts.models.detection import YOLO, FasterRCNN, YOLOConfiguration +from pl_bolts.models.detection import YOLO, FasterRCNN, RetinaNet, YOLOConfiguration from pl_bolts.models.detection.faster_rcnn import create_fasterrcnn_backbone from pl_bolts.models.detection.yolo.yolo_layers import _aligned_iou from tests import TEST_ROOT @@ -30,16 +30,46 @@ def test_fasterrcnn_train(tmpdir): train_dl = DataLoader(DummyDetectionDataset(), collate_fn=_collate_fn) valid_dl = DataLoader(DummyDetectionDataset(), collate_fn=_collate_fn) - trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir) + trainer = Trainer(fast_dev_run=True, logger=False, checkpoint_callback=False, default_root_dir=tmpdir) trainer.fit(model, train_dataloader=train_dl, val_dataloaders=valid_dl) def test_fasterrcnn_bbone_train(tmpdir): - model = FasterRCNN(backbone="resnet18", fpn=True, pretrained_backbone=False, pretrained=False) + model = FasterRCNN(backbone="resnet18", fpn=True, pretrained_backbone=False) train_dl = DataLoader(DummyDetectionDataset(), collate_fn=_collate_fn) valid_dl = DataLoader(DummyDetectionDataset(), collate_fn=_collate_fn) - trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir) + trainer = Trainer(fast_dev_run=True, logger=False, checkpoint_callback=False, default_root_dir=tmpdir) + trainer.fit(model, train_dl, valid_dl) + + +@torch.no_grad() +def test_retinanet(): + model = RetinaNet(pretrained=False) + + image = torch.rand(1, 3, 400, 400) + model(image) + + +def test_retinanet_train(tmpdir): + model = RetinaNet(pretrained=False) + + train_dl = DataLoader(DummyDetectionDataset(), collate_fn=_collate_fn) + valid_dl = DataLoader(DummyDetectionDataset(), collate_fn=_collate_fn) + + trainer = Trainer(fast_dev_run=True, logger=False, checkpoint_callback=False, default_root_dir=tmpdir) + trainer.fit(model, train_dataloader=train_dl, val_dataloaders=valid_dl) + + +def test_retinanet_backbone_train(tmpdir): + model = RetinaNet(backbone="resnet18", fpn=True, pretrained_backbone=False) + train_dl = DataLoader(DummyDetectionDataset(), collate_fn=_collate_fn) + valid_dl = DataLoader(DummyDetectionDataset(), collate_fn=_collate_fn) + + trainer = Trainer(fast_dev_run=True, logger=False, checkpoint_callback=False, default_root_dir=tmpdir) + model = FasterRCNN(backbone="resnet18", fpn=True, pretrained_backbone=False, pretrained=False) + train_dl = DataLoader(DummyDetectionDataset(), collate_fn=_collate_fn) + valid_dl = DataLoader(DummyDetectionDataset(), collate_fn=_collate_fn) trainer.fit(model, train_dl, valid_dl) diff --git a/tests/models/test_scripts.py b/tests/models/test_scripts.py index 35cb1672ec..a918a2a9a2 100644 --- a/tests/models/test_scripts.py +++ b/tests/models/test_scripts.py @@ -5,6 +5,9 @@ from tests import _MARK_REQUIRE_GPU, DATASETS_PATH _DEFAULT_ARGS = f" --data_dir {DATASETS_PATH} --max_epochs 1 --max_steps 2 --batch_size 4" +_DEFAULT_LIGHTNING_CLI_ARGS = ( + f" fit --data.data_dir {DATASETS_PATH} --data.batch_size 4 --trainer.max_epochs 1 --trainer.max_steps 2" +) @pytest.mark.parametrize("dataset_name", ["mnist", "cifar10"]) @@ -105,3 +108,14 @@ def test_cli_run_vision_image_gpt(cli_args): cli_args = cli_args.strip().split(" ") if cli_args else [] with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): cli_main() + + +@pytest.mark.parametrize("cli_args", [_DEFAULT_LIGHTNING_CLI_ARGS + " --trainer.gpus 1"]) +@pytest.mark.skipif(**_MARK_REQUIRE_GPU) +def test_cli_run_retinanet(cli_args): + """Test running CLI for an example with default params.""" + from pl_bolts.models.detection.retinanet.retinanet_module import cli_main + + cli_args = cli_args.strip().split(" ") if cli_args else [] + with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): + cli_main()