Skip to content

Commit

Permalink
Add RetinaNet Object detection with Backbones (#529)
Browse files Browse the repository at this point in the history
* refactor frcnn
* start adding retina
* complete Retinanet
* update requirments
* Apply suggestions from code review
* Set pytorch-lightning>=1.4.0
* Use LightningCLI

Co-authored-by: Akihiro Nitta <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: Jirka <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Nicki Skafte <[email protected]>
Co-authored-by: praecipue <[email protected]>
  • Loading branch information
8 people authored Dec 20, 2021
1 parent c79c0ca commit 51edf4a
Show file tree
Hide file tree
Showing 12 changed files with 238 additions and 8 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added nn.Module support for FasterRCNN backbone ([#661](https://github.com/PyTorchLightning/lightning-bolts/pull/661))


- Added `RetinaNet` with torchvision backbones ([#529](https://github.com/PyTorchLightning/lightning-bolts/pull/529))


- Added Python 3.9 support ([#786](https://github.com/PyTorchLightning/lightning-bolts/pull/786))


Expand Down
8 changes: 8 additions & 0 deletions docs/source/deprecated/models/object_detection.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@ Faster R-CNN

-------------

RetinaNet
---------

.. autoclass:: pl_bolts.models.detection.retinanet.retinanet_module.RetinaNet
:noindex:

-------------

YOLO
----

Expand Down
2 changes: 2 additions & 0 deletions pl_bolts/models/detection/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pl_bolts.models.detection import components
from pl_bolts.models.detection.faster_rcnn import FasterRCNN
from pl_bolts.models.detection.retinanet import RetinaNet
from pl_bolts.models.detection.yolo.yolo_config import YOLOConfiguration
from pl_bolts.models.detection.yolo.yolo_module import YOLO

Expand All @@ -8,4 +9,5 @@
"FasterRCNN",
"YOLOConfiguration",
"YOLO",
"RetinaNet",
]
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class FasterRCNN(LightningModule):
CLI command::
# PascalVOC
python faster_rcnn.py --gpus 1 --pretrained True
python faster_rcnn_module.py --gpus 1 --pretrained True
"""

def __init__(
Expand Down
4 changes: 4 additions & 0 deletions pl_bolts/models/detection/retinanet/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from pl_bolts.models.detection.retinanet.backbones import create_retinanet_backbone
from pl_bolts.models.detection.retinanet.retinanet_module import RetinaNet

__all__ = ["create_retinanet_backbone", "RetinaNet"]
37 changes: 37 additions & 0 deletions pl_bolts/models/detection/retinanet/backbones.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from typing import Any, Optional

import torch.nn as nn

from pl_bolts.models.detection.components import create_torchvision_backbone
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
else: # pragma: no cover
warn_missing_pkg("torchvision")


def create_retinanet_backbone(
backbone: str, fpn: bool = True, pretrained: Optional[str] = None, trainable_backbone_layers: int = 3, **kwargs: Any
) -> nn.Module:
"""
Args:
backbone:
Supported backones are: "resnet18", "resnet34","resnet50", "resnet101", "resnet152",
"resnext50_32x4d", "resnext101_32x8d", "wide_resnet50_2", "wide_resnet101_2",
as resnets with fpn backbones.
Without fpn backbones supported are: "resnet18", "resnet34", "resnet50","resnet101",
"resnet152", "resnext101_32x8d", "mobilenet_v2", "vgg11", "vgg13", "vgg16", "vgg19",
fpn: If True then constructs fpn as well.
pretrained: If None creates imagenet weights backbone.
trainable_backbone_layers: number of trainable resnet layers starting from final block.
"""

if fpn:
# Creates a torchvision resnet model with fpn added.
backbone = resnet_fpn_backbone(backbone, pretrained=True, trainable_layers=trainable_backbone_layers, **kwargs)
else:
# This does not create fpn backbone, it is supported for all models
backbone, _ = create_torchvision_backbone(backbone, pretrained)
return backbone
131 changes: 131 additions & 0 deletions pl_bolts/models/detection/retinanet/retinanet_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
from typing import Any, Optional

import torch
from pytorch_lightning import LightningModule

from pl_bolts.models.detection.retinanet import create_retinanet_backbone
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
from torchvision.models.detection.retinanet import RetinaNet as torchvision_RetinaNet
from torchvision.models.detection.retinanet import RetinaNetHead, retinanet_resnet50_fpn
from torchvision.ops import box_iou
else: # pragma: no cover
warn_missing_pkg("torchvision")


class RetinaNet(LightningModule):
"""PyTorch Lightning implementation of RetinaNet.
Paper: `Focal Loss for Dense Object Detection <https://arxiv.org/abs/1708.02002>`_.
Paper authors: Tsung-Yi Lin, Priya Goyal, Ross Girshick, Kaiming He, Piotr Dollár
Model implemented by:
- `Aditya Oke <https://github.com/oke-aditya>`
During training, the model expects both the input tensors, as well as targets (list of dictionary), containing:
- boxes (`FloatTensor[N, 4]`): the ground truth boxes in `[x1, y1, x2, y2]` format.
- labels (`Int64Tensor[N]`): the class label for each ground truh box
CLI command::
# PascalVOC using LightningCLI
python retinanet_module.py --trainer.gpus 1 --model.pretrained True
"""

def __init__(
self,
learning_rate: float = 0.0001,
num_classes: int = 91,
backbone: Optional[str] = None,
fpn: bool = True,
pretrained: bool = False,
pretrained_backbone: bool = True,
trainable_backbone_layers: int = 3,
**kwargs: Any,
):
"""
Args:
learning_rate: the learning rate
num_classes: number of detection classes (including background)
backbone: Pretained backbone CNN architecture.
fpn: If True, creates a Feature Pyramind Network on top of Resnet based CNNs.
pretrained: if true, returns a model pre-trained on COCO train2017
pretrained_backbone: if true, returns a model with backbone pre-trained on Imagenet
trainable_backbone_layers: number of trainable resnet layers starting from final block
"""
super().__init__()
self.learning_rate = learning_rate
self.num_classes = num_classes
self.backbone = backbone
if backbone is None:
self.model = retinanet_resnet50_fpn(pretrained=pretrained, **kwargs)

self.model.head = RetinaNetHead(
in_channels=self.model.backbone.out_channels,
num_anchors=self.model.head.classification_head.num_anchors,
num_classes=num_classes,
**kwargs,
)

else:
backbone_model = create_retinanet_backbone(
self.backbone, fpn, pretrained_backbone, trainable_backbone_layers, **kwargs
)
self.model = torchvision_RetinaNet(backbone_model, num_classes=num_classes, **kwargs)

def forward(self, x):
self.model.eval()
return self.model(x)

def training_step(self, batch, batch_idx):

images, targets = batch
targets = [{k: v for k, v in t.items()} for t in targets]

# fasterrcnn takes both images and targets for training, returns
loss_dict = self.model(images, targets)
loss = sum(loss for loss in loss_dict.values())
self.log("loss", loss, prog_bar=True)
return loss

def validation_step(self, batch, batch_idx):
images, targets = batch
# fasterrcnn takes only images for eval() mode
preds = self.model(images)
iou = torch.stack([self._evaluate_iou(p, t) for p, t in zip(preds, targets)]).mean()
self.log("val_iou", iou, prog_bar=True)
return {"val_iou": iou}

def validation_epoch_end(self, outs):
avg_iou = torch.stack([o["val_iou"] for o in outs]).mean()
self.log("val_avg_iou", avg_iou)

def _evaluate_iou(self, preds, targets):
"""Evaluate intersection over union (IOU) for target from dataset and output prediction from model."""
# no box detected, 0 IOU
if preds["boxes"].shape[0] == 0:
return torch.tensor(0.0, device=preds["boxes"].device)
return box_iou(preds["boxes"], targets["boxes"]).diag().mean()

def configure_optimizers(self):
return torch.optim.SGD(
self.model.parameters(),
lr=self.learning_rate,
momentum=0.9,
weight_decay=0.005,
)


def cli_main():
from pytorch_lightning.utilities.cli import LightningCLI

from pl_bolts.datamodules import VOCDetectionDataModule

LightningCLI(RetinaNet, VOCDetectionDataModule, seed_everything_default=42)


if __name__ == "__main__":
cli_main()
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
torch>=1.7.*
torch>=1.7.1
torchmetrics>=0.4.1
pytorch-lightning>=1.3.0
pytorch-lightning>=1.4.0
packaging
2 changes: 1 addition & 1 deletion requirements/models.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
torchvision>=0.8.*
torchvision>=0.8.2
scikit-learn>=0.23
Pillow
opencv-python-headless
Expand Down
1 change: 1 addition & 0 deletions requirements/test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ atari-py==0.2.6 # needed for RL
scikit-learn>=0.23
sparseml
ale-py>=0.7
jsonargparse[signatures] # for LightningCLI
38 changes: 34 additions & 4 deletions tests/models/test_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from torch.utils.data import DataLoader

from pl_bolts.datasets import DummyDetectionDataset
from pl_bolts.models.detection import YOLO, FasterRCNN, YOLOConfiguration
from pl_bolts.models.detection import YOLO, FasterRCNN, RetinaNet, YOLOConfiguration
from pl_bolts.models.detection.faster_rcnn import create_fasterrcnn_backbone
from pl_bolts.models.detection.yolo.yolo_layers import _aligned_iou
from tests import TEST_ROOT
Expand All @@ -30,16 +30,46 @@ def test_fasterrcnn_train(tmpdir):
train_dl = DataLoader(DummyDetectionDataset(), collate_fn=_collate_fn)
valid_dl = DataLoader(DummyDetectionDataset(), collate_fn=_collate_fn)

trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir)
trainer = Trainer(fast_dev_run=True, logger=False, checkpoint_callback=False, default_root_dir=tmpdir)
trainer.fit(model, train_dataloader=train_dl, val_dataloaders=valid_dl)


def test_fasterrcnn_bbone_train(tmpdir):
model = FasterRCNN(backbone="resnet18", fpn=True, pretrained_backbone=False, pretrained=False)
model = FasterRCNN(backbone="resnet18", fpn=True, pretrained_backbone=False)
train_dl = DataLoader(DummyDetectionDataset(), collate_fn=_collate_fn)
valid_dl = DataLoader(DummyDetectionDataset(), collate_fn=_collate_fn)

trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir)
trainer = Trainer(fast_dev_run=True, logger=False, checkpoint_callback=False, default_root_dir=tmpdir)
trainer.fit(model, train_dl, valid_dl)


@torch.no_grad()
def test_retinanet():
model = RetinaNet(pretrained=False)

image = torch.rand(1, 3, 400, 400)
model(image)


def test_retinanet_train(tmpdir):
model = RetinaNet(pretrained=False)

train_dl = DataLoader(DummyDetectionDataset(), collate_fn=_collate_fn)
valid_dl = DataLoader(DummyDetectionDataset(), collate_fn=_collate_fn)

trainer = Trainer(fast_dev_run=True, logger=False, checkpoint_callback=False, default_root_dir=tmpdir)
trainer.fit(model, train_dataloader=train_dl, val_dataloaders=valid_dl)


def test_retinanet_backbone_train(tmpdir):
model = RetinaNet(backbone="resnet18", fpn=True, pretrained_backbone=False)
train_dl = DataLoader(DummyDetectionDataset(), collate_fn=_collate_fn)
valid_dl = DataLoader(DummyDetectionDataset(), collate_fn=_collate_fn)

trainer = Trainer(fast_dev_run=True, logger=False, checkpoint_callback=False, default_root_dir=tmpdir)
model = FasterRCNN(backbone="resnet18", fpn=True, pretrained_backbone=False, pretrained=False)
train_dl = DataLoader(DummyDetectionDataset(), collate_fn=_collate_fn)
valid_dl = DataLoader(DummyDetectionDataset(), collate_fn=_collate_fn)
trainer.fit(model, train_dl, valid_dl)


Expand Down
14 changes: 14 additions & 0 deletions tests/models/test_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from tests import _MARK_REQUIRE_GPU, DATASETS_PATH

_DEFAULT_ARGS = f" --data_dir {DATASETS_PATH} --max_epochs 1 --max_steps 2 --batch_size 4"
_DEFAULT_LIGHTNING_CLI_ARGS = (
f" fit --data.data_dir {DATASETS_PATH} --data.batch_size 4 --trainer.max_epochs 1 --trainer.max_steps 2"
)


@pytest.mark.parametrize("dataset_name", ["mnist", "cifar10"])
Expand Down Expand Up @@ -105,3 +108,14 @@ def test_cli_run_vision_image_gpt(cli_args):
cli_args = cli_args.strip().split(" ") if cli_args else []
with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
cli_main()


@pytest.mark.parametrize("cli_args", [_DEFAULT_LIGHTNING_CLI_ARGS + " --trainer.gpus 1"])
@pytest.mark.skipif(**_MARK_REQUIRE_GPU)
def test_cli_run_retinanet(cli_args):
"""Test running CLI for an example with default params."""
from pl_bolts.models.detection.retinanet.retinanet_module import cli_main

cli_args = cli_args.strip().split(" ") if cli_args else []
with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
cli_main()

0 comments on commit 51edf4a

Please sign in to comment.