From a225a69cef3a54500fe9925818cd1b73fc34ec7e Mon Sep 17 00:00:00 2001 From: Jitesh Jain Date: Thu, 8 Jun 2023 16:46:21 +0530 Subject: [PATCH 1/9] :sparkles: Add InternImage-H Results --- GETTING_STARTED.md | 20 + INSTALL.md | 10 +- README.md | 2 + ...r_intern_image_huge_bs16_160k_896x896.yaml | 39 + ...ern_image_huge_bs16_160k_896x896_1024.yaml | 53 + .../oneformer_intern_image_huge_bs16_90k.yaml | 19 + ...ormer_intern_image_huge_bs16_90k_1024.yaml | 33 + ...neformer_intern_image_huge_bs16_100ep.yaml | 24 + ...mer_intern_image_huge_bs16_100ep_1024.yaml | 38 + oneformer/config.py | 29 +- oneformer/modeling/__init__.py | 1 + oneformer/modeling/backbone/intern_image.py | 747 ++++++++++++ .../backbone/ops_dcnv3/functions/__init__.py | 7 + .../ops_dcnv3/functions/dcnv3_func.py | 188 +++ oneformer/modeling/backbone/ops_dcnv3/make.sh | 8 + .../backbone/ops_dcnv3/modules/__init__.py | 7 + .../backbone/ops_dcnv3/modules/dcnv3.py | 345 ++++++ .../modeling/backbone/ops_dcnv3/setup.py | 75 ++ .../backbone/ops_dcnv3/src/cpu/dcnv3_cpu.cpp | 37 + .../backbone/ops_dcnv3/src/cpu/dcnv3_cpu.h | 31 + .../backbone/ops_dcnv3/src/cuda/dcnv3_cuda.cu | 174 +++ .../backbone/ops_dcnv3/src/cuda/dcnv3_cuda.h | 31 + .../ops_dcnv3/src/cuda/dcnv3_im2col_cuda.cuh | 1045 +++++++++++++++++ .../modeling/backbone/ops_dcnv3/src/dcnv3.h | 59 + .../backbone/ops_dcnv3/src/vision.cpp | 17 + oneformer/modeling/backbone/ops_dcnv3/test.py | 263 +++++ tools/README.md | 15 +- train_net.py | 2 + 28 files changed, 3315 insertions(+), 4 deletions(-) create mode 100644 configs/ade20k/intern_image/oneformer_intern_image_huge_bs16_160k_896x896.yaml create mode 100644 configs/ade20k/intern_image/oneformer_intern_image_huge_bs16_160k_896x896_1024.yaml create mode 100644 configs/cityscapes/intern_image/oneformer_intern_image_huge_bs16_90k.yaml create mode 100644 configs/cityscapes/intern_image/oneformer_intern_image_huge_bs16_90k_1024.yaml create mode 100644 configs/coco/intern_image/oneformer_intern_image_huge_bs16_100ep.yaml create mode 100644 configs/coco/intern_image/oneformer_intern_image_huge_bs16_100ep_1024.yaml create mode 100644 oneformer/modeling/backbone/intern_image.py create mode 100644 oneformer/modeling/backbone/ops_dcnv3/functions/__init__.py create mode 100644 oneformer/modeling/backbone/ops_dcnv3/functions/dcnv3_func.py create mode 100755 oneformer/modeling/backbone/ops_dcnv3/make.sh create mode 100644 oneformer/modeling/backbone/ops_dcnv3/modules/__init__.py create mode 100644 oneformer/modeling/backbone/ops_dcnv3/modules/dcnv3.py create mode 100644 oneformer/modeling/backbone/ops_dcnv3/setup.py create mode 100644 oneformer/modeling/backbone/ops_dcnv3/src/cpu/dcnv3_cpu.cpp create mode 100644 oneformer/modeling/backbone/ops_dcnv3/src/cpu/dcnv3_cpu.h create mode 100644 oneformer/modeling/backbone/ops_dcnv3/src/cuda/dcnv3_cuda.cu create mode 100644 oneformer/modeling/backbone/ops_dcnv3/src/cuda/dcnv3_cuda.h create mode 100644 oneformer/modeling/backbone/ops_dcnv3/src/cuda/dcnv3_im2col_cuda.cuh create mode 100644 oneformer/modeling/backbone/ops_dcnv3/src/dcnv3.h create mode 100644 oneformer/modeling/backbone/ops_dcnv3/src/vision.cpp create mode 100644 oneformer/modeling/backbone/ops_dcnv3/test.py diff --git a/GETTING_STARTED.md b/GETTING_STARTED.md index 449302e..b1dd15e 100644 --- a/GETTING_STARTED.md +++ b/GETTING_STARTED.md @@ -26,6 +26,26 @@ python train_net.py --dist-url 'tcp://127.0.0.1:50163' \ OUTPUT_DIR outputs/ade20k_swin_large WANDB.NAME ade20k_swin_large ``` +### Training on Multiple Nodes + +```bash +### Node 1 +python train_net.py --dist-url \ + --num-gpus 8 \ + --num-machines 2 \ + --machine-rank 0 \ + --config-file configs/ade20k/intern_image/oneformer_intern_image_huge_bs16_160k_896x896_1024.yaml \ + OUTPUT_DIR outputs/ade20k_intern_image_huge WANDB.NAME ade20k_intern_image_huge + +### Node 2 +python train_net.py --dist-url \ + --num-gpus 8 \ + --num-machines 2 \ + --machine-rank 1 \ + --config-file configs/ade20k/intern_image/oneformer_intern_image_huge_bs16_160k_896x896_1024.yaml \ + OUTPUT_DIR outputs/ade20k_intern_image_huge WANDB.NAME ade20k_intern_image_huge +``` + ## Evaluation - You need to pass the value of `task` token. `task` belongs to [panoptic, semantic, instance]. diff --git a/INSTALL.md b/INSTALL.md index 5b9f753..a26017a 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -58,4 +58,12 @@ We use an evironment with the following specifications, packages and dependencie sh make.sh cd ../../../.. ``` - + +- Setup CUDA Kernel for DCNv3. Requires CUDA installed. + + ```bash + # Setup DCNv3 + cd oneformer/modeling/backbone/ops_dcnv3 + sh make.sh + cd ../../../.. + ``` diff --git a/README.md b/README.md index eeee56b..21a37e8 100644 --- a/README.md +++ b/README.md @@ -97,6 +97,7 @@ This repo contains the code for our paper **OneFormer: One Transformer to Rule U | OneFormer | DiNAT-L | 1280×1280 | 51.5 | 37.1 | 58.3 | 58.7 | 223M | [config](configs/ade20k/dinat/oneformer_dinat_large_bs16_160k_1280x1280.yaml) | [model](https://shi-labs.com/projects/oneformer/ade20k/1280x1280_250_16_dinat_l_oneformer_ade20k_160k.pth) | | OneFormer (COCO-Pretrained) | DiNAT-L | 1280×1280 | 53.4 | 40.2 | 58.4 | 58.8 | 223M | [config](configs/ade20k/dinat/coco_pretrain_oneformer_dinat_large_bs16_160k_1280x1280_coco_pretrain.yaml) | [model](https://shi-labs.com/projects/oneformer/ade20k/coco_pretrain_1280x1280_150_16_dinat_l_oneformer_ade20k_160k.pth) | [pretrained](https://shi-labs.com/projects/oneformer/coco/150_16_dinat_l_oneformer_coco_100ep.pth) | | OneFormer | ConvNeXt-XL | 640×640 | 50.1 | 36.3 | 57.4 | 58.8 | 372M | [config](configs/ade20k/convnext/oneformer_convnext_xlarge_bs16_160k.yaml) | [model](https://shi-labs.com/projects/oneformer/ade20k/250_16_convnext_xl_oneformer_ade20k_160k.pth) | +| OneFormer | InternImage-H | 896×896 | 54.5 | 40.2 | 60.4 | 60.8 | 1.10B | [config](configs/ade20k/intern_image/oneformer_intern_image_huge_bs16_160k_896x896.yaml) | [model](https://shi-labs.com/projects/oneformer/ade20k/896x896_250_16_intern_image_h_oneformer_ade20k_160k.pth) | ### Cityscapes @@ -108,6 +109,7 @@ This repo contains the code for our paper **OneFormer: One Transformer to Rule U | OneFormer | DiNAT-L | 67.6 | 45.6 | 83.1 | 84.0 | 223M | [config](configs/cityscapes/dinat/oneformer_dinat_large_bs16_90k.yaml) | [model](https://shi-labs.com/projects/oneformer/cityscapes/250_16_dinat_l_oneformer_cityscapes_90k.pth) | | OneFormer | ConvNeXt-XL | 68.4 | 46.7 | 83.6 | 84.6 | 372M | [config](configs/cityscapes/convnext/oneformer_convnext_xlarge_bs16_90k.yaml) | [model](https://shi-labs.com/projects/oneformer/cityscapes/250_16_convnext_xl_oneformer_cityscapes_90k.pth) | | OneFormer (Mapillary Vistas-Pretrained) | ConvNeXt-XL | 69.7 | 48.9 | 84.5 | 85.8 | 372M | [config](configs/cityscapes/convnext/mapillary_pretrain_oneformer_convnext_xlarge_bs16_90k.yaml) | [model](https://shi-labs.com/projects/oneformer/cityscapes/mapillary_pretrain_250_16_convnext_xl_oneformer_cityscapes_90k.pth) | [pretrained](https://shi-labs.com/projects/oneformer/mapillary/mapillary_pretrain_250_16_convnext_xl_oneformer_mapillary_300k.pth) | +| OneFormer | InternImage-H | 70.6 | 50.6 | 85.1 | 85.7 | 1.10B | [config](configs/cityscapes/intern_image/oneformer_intern_image_huge_bs16_90k.yaml) | [model](https://shi-labs.com/projects/oneformer/cityscapes/250_16_intern_image_h_oneformer_cityscapes_90k.pth) | ### COCO diff --git a/configs/ade20k/intern_image/oneformer_intern_image_huge_bs16_160k_896x896.yaml b/configs/ade20k/intern_image/oneformer_intern_image_huge_bs16_160k_896x896.yaml new file mode 100644 index 0000000..186ea6c --- /dev/null +++ b/configs/ade20k/intern_image/oneformer_intern_image_huge_bs16_160k_896x896.yaml @@ -0,0 +1,39 @@ +_BASE_: ../oneformer_R50_bs16_160k.yaml +MODEL: + BACKBONE: + NAME: "D2InternImage" + INTERNIMAGE: + CHANNELS: 320 + DEPTHS: [6, 6, 32, 6] + GROUPS: [10, 20, 40, 80] + WITH_CP: True + MLP_RATIO: 4.0 + DW_KERNEL_SIZE: 5 + LEVEL2_POST_NORM_BLOCK_IDS: [5, 11, 17, 23, 29] + WEIGHTS: "internimage_h_jointto22k_384.pkl" + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.120, 57.375] + ONE_FORMER: + NUM_OBJECT_QUERIES: 250 +INPUT: + MIN_SIZE_TRAIN: !!python/object/apply:eval ["[int(x * 0.1 * 896) for x in range(5, 21)]"] + MIN_SIZE_TRAIN_SAMPLING: "choice" + MIN_SIZE_TEST: 896 + MAX_SIZE_TRAIN: 3584 + MAX_SIZE_TEST: 3584 + CROP: + ENABLED: True + TYPE: "absolute" + SIZE: (896, 896) + SINGLE_CATEGORY_MAX_AREA: 1.0 + COLOR_AUG_SSD: True + SIZE_DIVISIBILITY: 896 # used in dataset mapper + FORMAT: "RGB" +TEST: + DETECTIONS_PER_IMAGE: 250 + EVAL_PERIOD: 5000 + AUG: + ENABLED: False + MIN_SIZES: [448, 678, 896, 1120, 1344, 1568] + MAX_SIZE: 6272 + FLIP: True diff --git a/configs/ade20k/intern_image/oneformer_intern_image_huge_bs16_160k_896x896_1024.yaml b/configs/ade20k/intern_image/oneformer_intern_image_huge_bs16_160k_896x896_1024.yaml new file mode 100644 index 0000000..1c1317c --- /dev/null +++ b/configs/ade20k/intern_image/oneformer_intern_image_huge_bs16_160k_896x896_1024.yaml @@ -0,0 +1,53 @@ +_BASE_: ../oneformer_R50_bs16_160k.yaml +MODEL: + BACKBONE: + NAME: "D2InternImage" + SEM_SEG_HEAD: + NAME: "OneFormerHead" + IGNORE_VALUE: 255 + NUM_CLASSES: 150 + LOSS_WEIGHT: 1.0 + CONVS_DIM: 1024 + MASK_DIM: 1024 + INTERNIMAGE: + CHANNELS: 320 + DEPTHS: [6, 6, 32, 6] + GROUPS: [10, 20, 40, 80] + WITH_CP: True + MLP_RATIO: 4.0 + DW_KERNEL_SIZE: 5 + LEVEL2_POST_NORM_BLOCK_IDS: [5, 11, 17, 23, 29] + WEIGHTS: "internimage_h_jointto22k_384.pkl" + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.120, 57.375] + ONE_FORMER: + HIDDEN_DIM: 1024 + NUM_OBJECT_QUERIES: 250 + NHEADS: 32 + DIM_FEEDFORWARD: 4096 + TEXT_ENCODER: + WIDTH: 1024 + CONTEXT_LENGTH: 77 + N_CTX: 16 +INPUT: + MIN_SIZE_TRAIN: !!python/object/apply:eval ["[int(x * 0.1 * 896) for x in range(5, 21)]"] + MIN_SIZE_TRAIN_SAMPLING: "choice" + MIN_SIZE_TEST: 896 + MAX_SIZE_TRAIN: 3584 + MAX_SIZE_TEST: 3584 + CROP: + ENABLED: True + TYPE: "absolute" + SIZE: (896, 896) + SINGLE_CATEGORY_MAX_AREA: 1.0 + COLOR_AUG_SSD: True + SIZE_DIVISIBILITY: 896 # used in dataset mapper + FORMAT: "RGB" +TEST: + DETECTIONS_PER_IMAGE: 250 + EVAL_PERIOD: 5000 + AUG: + ENABLED: False + MIN_SIZES: [448, 678, 896, 1120, 1344, 1568] + MAX_SIZE: 6272 + FLIP: True diff --git a/configs/cityscapes/intern_image/oneformer_intern_image_huge_bs16_90k.yaml b/configs/cityscapes/intern_image/oneformer_intern_image_huge_bs16_90k.yaml new file mode 100644 index 0000000..31d770e --- /dev/null +++ b/configs/cityscapes/intern_image/oneformer_intern_image_huge_bs16_90k.yaml @@ -0,0 +1,19 @@ +_BASE_: ../oneformer_R50_bs16_90k.yaml +MODEL: + BACKBONE: + NAME: "D2InternImage" + INTERNIMAGE: + CHANNELS: 320 + DEPTHS: [6, 6, 32, 6] + GROUPS: [10, 20, 40, 80] + WITH_CP: True + MLP_RATIO: 4.0 + DW_KERNEL_SIZE: 5 + LEVEL2_POST_NORM_BLOCK_IDS: [5, 11, 17, 23, 29] + WEIGHTS: "internimage_h_jointto22k_384.pkl" + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.120, 57.375] + ONE_FORMER: + NUM_OBJECT_QUERIES: 250 +TEST: + DETECTIONS_PER_IMAGE: 250 \ No newline at end of file diff --git a/configs/cityscapes/intern_image/oneformer_intern_image_huge_bs16_90k_1024.yaml b/configs/cityscapes/intern_image/oneformer_intern_image_huge_bs16_90k_1024.yaml new file mode 100644 index 0000000..66013af --- /dev/null +++ b/configs/cityscapes/intern_image/oneformer_intern_image_huge_bs16_90k_1024.yaml @@ -0,0 +1,33 @@ +_BASE_: ../oneformer_R50_bs16_90k.yaml +MODEL: + BACKBONE: + NAME: "D2InternImage" + SEM_SEG_HEAD: + NAME: "OneFormerHead" + IGNORE_VALUE: 255 + NUM_CLASSES: 150 + LOSS_WEIGHT: 1.0 + CONVS_DIM: 1024 + MASK_DIM: 1024 + INTERNIMAGE: + CHANNELS: 320 + DEPTHS: [6, 6, 32, 6] + GROUPS: [10, 20, 40, 80] + WITH_CP: True + MLP_RATIO: 4.0 + DW_KERNEL_SIZE: 5 + LEVEL2_POST_NORM_BLOCK_IDS: [5, 11, 17, 23, 29] + WEIGHTS: "internimage_h_jointto22k_384.pkl" + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.120, 57.375] + ONE_FORMER: + HIDDEN_DIM: 1024 + NUM_OBJECT_QUERIES: 250 + NHEADS: 32 + DIM_FEEDFORWARD: 4096 + TEXT_ENCODER: + WIDTH: 1024 + CONTEXT_LENGTH: 77 + N_CTX: 16 +TEST: + DETECTIONS_PER_IMAGE: 250 \ No newline at end of file diff --git a/configs/coco/intern_image/oneformer_intern_image_huge_bs16_100ep.yaml b/configs/coco/intern_image/oneformer_intern_image_huge_bs16_100ep.yaml new file mode 100644 index 0000000..c7e00e7 --- /dev/null +++ b/configs/coco/intern_image/oneformer_intern_image_huge_bs16_100ep.yaml @@ -0,0 +1,24 @@ +_BASE_: ../oneformer_R50_bs16_50ep.yaml +MODEL: + BACKBONE: + NAME: "D2InternImage" + INTERNIMAGE: + CHANNELS: 320 + DEPTHS: [6, 6, 32, 6] + GROUPS: [10, 20, 40, 80] + WITH_CP: True + MLP_RATIO: 4.0 + DW_KERNEL_SIZE: 5 + LEVEL2_POST_NORM_BLOCK_IDS: [5, 11, 17, 23, 29] + WEIGHTS: "internimage_h_jointto22k_384.pkl" + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.120, 57.375] + ONE_FORMER: + NUM_OBJECT_QUERIES: 250 +SOLVER: + STEPS: (655556, 735184) + MAX_ITER: 737500 + AMP: + ENABLED: False +TEST: + DETECTIONS_PER_IMAGE: 250 diff --git a/configs/coco/intern_image/oneformer_intern_image_huge_bs16_100ep_1024.yaml b/configs/coco/intern_image/oneformer_intern_image_huge_bs16_100ep_1024.yaml new file mode 100644 index 0000000..e6c9ba1 --- /dev/null +++ b/configs/coco/intern_image/oneformer_intern_image_huge_bs16_100ep_1024.yaml @@ -0,0 +1,38 @@ +_BASE_: ../oneformer_R50_bs16_50ep.yaml +MODEL: + BACKBONE: + NAME: "D2InternImage" + SEM_SEG_HEAD: + NAME: "OneFormerHead" + IGNORE_VALUE: 255 + NUM_CLASSES: 150 + LOSS_WEIGHT: 1.0 + CONVS_DIM: 1024 + MASK_DIM: 1024 + INTERNIMAGE: + CHANNELS: 320 + DEPTHS: [6, 6, 32, 6] + GROUPS: [10, 20, 40, 80] + WITH_CP: True + MLP_RATIO: 4.0 + DW_KERNEL_SIZE: 5 + LEVEL2_POST_NORM_BLOCK_IDS: [5, 11, 17, 23, 29] + WEIGHTS: "internimage_h_jointto22k_384.pkl" + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.120, 57.375] + ONE_FORMER: + HIDDEN_DIM: 1024 + NUM_OBJECT_QUERIES: 250 + NHEADS: 32 + DIM_FEEDFORWARD: 4096 + TEXT_ENCODER: + WIDTH: 1024 + CONTEXT_LENGTH: 77 + N_CTX: 16 +SOLVER: + STEPS: (655556, 735184) + MAX_ITER: 737500 + AMP: + ENABLED: False +TEST: + DETECTIONS_PER_IMAGE: 250 diff --git a/oneformer/config.py b/oneformer/config.py index 0dc8320..5058a74 100644 --- a/oneformer/config.py +++ b/oneformer/config.py @@ -3,7 +3,7 @@ from detectron2.config import CfgNode as CN __all__ = ["add_common_config", "add_oneformer_config", "add_swin_config", - "add_dinat_config", "add_convnext_config"] + "add_dinat_config", "add_convnext_config", "add_internimage_config"] def add_common_config(cfg): """ @@ -207,4 +207,29 @@ def add_convnext_config(cfg): cfg.MODEL.CONVNEXT.DROP_PATH_RATE = 0.4 cfg.MODEL.CONVNEXT.LSIT = 1.0 cfg.MODEL.CONVNEXT.OUT_INDICES = [0, 1, 2, 3] - cfg.MODEL.CONVNEXT.OUT_FEATURES = ["res2", "res3", "res4", "res5"] \ No newline at end of file + cfg.MODEL.CONVNEXT.OUT_FEATURES = ["res2", "res3", "res4", "res5"] + +def add_internimage_config(cfg): + ''' + Add config for InternImage Backbone. + ''' + + cfg.MODEL.INTERNIMAGE = CN() + cfg.MODEL.INTERNIMAGE.CORE_OP = 'DCNv3' + cfg.MODEL.INTERNIMAGE.CHANNELS = 320 + cfg.MODEL.INTERNIMAGE.DEPTHS = [6, 6, 32, 6] + cfg.MODEL.INTERNIMAGE.GROUPS = [10, 20, 40, 80] + cfg.MODEL.INTERNIMAGE.MLP_RATIO = 4. + cfg.MODEL.INTERNIMAGE.DROP_RATE = 0.5 + cfg.MODEL.INTERNIMAGE.norm_layer = 'LN' + cfg.MODEL.INTERNIMAGE.LAYER_SCALE = None + cfg.MODEL.INTERNIMAGE.OFFSET_SCALE = 1.0 + cfg.MODEL.INTERNIMAGE.POST_NORM = False + cfg.MODEL.INTERNIMAGE.DW_KERNEL_SIZE = 5 + cfg.MODEL.INTERNIMAGE.RES_POST_NORM = True + cfg.MODEL.INTERNIMAGE.LEVEL2_POST_NORM = True + cfg.MODEL.INTERNIMAGE.LEVEL2_POST_NORM_BLOCK_IDS = [5, 11, 17, 23, 29] + cfg.MODEL.INTERNIMAGE.CENTER_FEATURE_SCALE = True + cfg.MODEL.INTERNIMAGE.WITH_CP = True + cfg.MODEL.INTERNIMAGE.OUT_INDICES = [0, 1, 2, 3] + cfg.MODEL.INTERNIMAGE.OUT_FEATURES = ["res2", "res3", "res4", "res5"] \ No newline at end of file diff --git a/oneformer/modeling/__init__.py b/oneformer/modeling/__init__.py index 0b0bb8a..599d209 100644 --- a/oneformer/modeling/__init__.py +++ b/oneformer/modeling/__init__.py @@ -1,6 +1,7 @@ from .backbone.swin import D2SwinTransformer from .backbone.dinat import D2DiNAT from .backbone.convnext import D2ConvNeXt +from .backbone.intern_image import D2InternImage from .pixel_decoder.fpn import BasePixelDecoder from .pixel_decoder.msdeformattn import MSDeformAttnPixelDecoder from .meta_arch.oneformer_head import OneFormerHead diff --git a/oneformer/modeling/backbone/intern_image.py b/oneformer/modeling/backbone/intern_image.py new file mode 100644 index 0000000..063515d --- /dev/null +++ b/oneformer/modeling/backbone/intern_image.py @@ -0,0 +1,747 @@ +# -------------------------------------------------------- +# InternImage +# Copyright (c) 2022 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- + +import torch +import torch.nn as nn +from collections import OrderedDict +import torch.utils.checkpoint as checkpoint +from timm.models.layers import trunc_normal_, DropPath +from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec +import torch.nn.functional as F + +from oneformer.modeling.backbone.ops_dcnv3 import modules as opsm + + +class to_channels_first(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, x): + return x.permute(0, 3, 1, 2) + + +class to_channels_last(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, x): + return x.permute(0, 2, 3, 1) + + +def build_norm_layer(dim, + norm_layer, + in_format='channels_last', + out_format='channels_last', + eps=1e-6): + layers = [] + if norm_layer == 'BN': + if in_format == 'channels_last': + layers.append(to_channels_first()) + layers.append(nn.BatchNorm2d(dim)) + if out_format == 'channels_last': + layers.append(to_channels_last()) + elif norm_layer == 'LN': + if in_format == 'channels_first': + layers.append(to_channels_last()) + layers.append(nn.LayerNorm(dim, eps=eps)) + if out_format == 'channels_first': + layers.append(to_channels_first()) + else: + raise NotImplementedError( + f'build_norm_layer does not support {norm_layer}') + return nn.Sequential(*layers) + + +def build_act_layer(act_layer): + if act_layer == 'ReLU': + return nn.ReLU(inplace=True) + elif act_layer == 'SiLU': + return nn.SiLU(inplace=True) + elif act_layer == 'GELU': + return nn.GELU() + + raise NotImplementedError(f'build_act_layer does not support {act_layer}') + + +class CrossAttention(nn.Module): + r""" Cross Attention Module + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. Default: 8 + qkv_bias (bool, optional): If True, add a learnable bias to q, k, v. + Default: False. + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + attn_drop (float, optional): Dropout ratio of attention weight. + Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + attn_head_dim (int, optional): Dimension of attention head. + out_dim (int, optional): Dimension of output. + """ + + def __init__(self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0., + proj_drop=0., + attn_head_dim=None, + out_dim=None): + super().__init__() + if out_dim is None: + out_dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + if attn_head_dim is not None: + head_dim = attn_head_dim + all_head_dim = head_dim * self.num_heads + self.scale = qk_scale or head_dim ** -0.5 + assert all_head_dim == dim + + self.q = nn.Linear(dim, all_head_dim, bias=False) + self.k = nn.Linear(dim, all_head_dim, bias=False) + self.v = nn.Linear(dim, all_head_dim, bias=False) + + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) + self.k_bias = nn.Parameter(torch.zeros(all_head_dim)) + self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) + else: + self.q_bias = None + self.k_bias = None + self.v_bias = None + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(all_head_dim, out_dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, k=None, v=None): + B, N, C = x.shape + N_k = k.shape[1] + N_v = v.shape[1] + + q_bias, k_bias, v_bias = None, None, None + if self.q_bias is not None: + q_bias = self.q_bias + k_bias = self.k_bias + v_bias = self.v_bias + + q = F.linear(input=x, weight=self.q.weight, bias=q_bias) + q = q.reshape(B, N, 1, self.num_heads, + -1).permute(2, 0, 3, 1, + 4).squeeze(0) # (B, N_head, N_q, dim) + + k = F.linear(input=k, weight=self.k.weight, bias=k_bias) + k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1, + 4).squeeze(0) + + v = F.linear(input=v, weight=self.v.weight, bias=v_bias) + v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1, + 4).squeeze(0) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) # (B, N_head, N_q, N_k) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class AttentiveBlock(nn.Module): + r"""Attentive Block + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. Default: 8 + qkv_bias (bool, optional): If True, add a learnable bias to q, k, v. + Default: False. + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + drop (float, optional): Dropout rate. Default: 0.0. + attn_drop (float, optional): Attention dropout rate. Default: 0.0. + drop_path (float | tuple[float], optional): Stochastic depth rate. + Default: 0.0. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm. + attn_head_dim (int, optional): Dimension of attention head. Default: None. + out_dim (int, optional): Dimension of output. Default: None. + """ + + def __init__(self, + dim, + num_heads, + qkv_bias=False, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer="LN", + attn_head_dim=None, + out_dim=None): + super().__init__() + + self.norm1_q = build_norm_layer(dim, norm_layer, eps=1e-6) + self.norm1_k = build_norm_layer(dim, norm_layer, eps=1e-6) + self.norm1_v = build_norm_layer(dim, norm_layer, eps=1e-6) + self.cross_dcn = CrossAttention(dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + attn_head_dim=attn_head_dim, + out_dim=out_dim) + + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, + x_q, + x_kv, + pos_q, + pos_k, + bool_masked_pos, + rel_pos_bias=None): + x_q = self.norm1_q(x_q + pos_q) + x_k = self.norm1_k(x_kv + pos_k) + x_v = self.norm1_v(x_kv) + + x = self.cross_dcn(x_q, k=x_k, v=x_v) + + return x + + +class AttentionPoolingBlock(AttentiveBlock): + + def forward(self, x): + x_q = x.mean(1, keepdim=True) + x_kv = x + pos_q, pos_k = 0, 0 + x = super().forward(x_q, x_kv, pos_q, pos_k, + bool_masked_pos=None, + rel_pos_bias=None) + x = x.squeeze(1) + return x + + +class StemLayer(nn.Module): + r""" Stem layer of InternImage + Args: + in_chans (int): number of input channels + out_chans (int): number of output channels + act_layer (str): activation layer + norm_layer (str): normalization layer + """ + + def __init__(self, + in_chans=3, + out_chans=96, + act_layer='GELU', + norm_layer='BN'): + super().__init__() + self.conv1 = nn.Conv2d(in_chans, + out_chans // 2, + kernel_size=3, + stride=2, + padding=1) + self.norm1 = build_norm_layer(out_chans // 2, norm_layer, + 'channels_first', 'channels_first') + self.act = build_act_layer(act_layer) + self.conv2 = nn.Conv2d(out_chans // 2, + out_chans, + kernel_size=3, + stride=2, + padding=1) + self.norm2 = build_norm_layer(out_chans, norm_layer, 'channels_first', + 'channels_last') + + def forward(self, x): + x = self.conv1(x) + x = self.norm1(x) + x = self.act(x) + x = self.conv2(x) + x = self.norm2(x) + return x + + +class DownsampleLayer(nn.Module): + r""" Downsample layer of InternImage + Args: + channels (int): number of input channels + norm_layer (str): normalization layer + """ + + def __init__(self, channels, norm_layer='LN'): + super().__init__() + self.conv = nn.Conv2d(channels, + 2 * channels, + kernel_size=3, + stride=2, + padding=1, + bias=False) + self.norm = build_norm_layer(2 * channels, norm_layer, + 'channels_first', 'channels_last') + + def forward(self, x): + x = self.conv(x.permute(0, 3, 1, 2)) + x = self.norm(x) + return x + + +class MLPLayer(nn.Module): + r""" MLP layer of InternImage + Args: + in_features (int): number of input features + hidden_features (int): number of hidden features + out_features (int): number of output features + act_layer (str): activation layer + drop (float): dropout rate + """ + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer='GELU', + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = build_act_layer(act_layer) + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class InternImageLayer(nn.Module): + r""" Basic layer of InternImage + Args: + core_op (nn.Module): core operation of InternImage + channels (int): number of input channels + groups (list): Groups of each block. + mlp_ratio (float): ratio of mlp hidden features to input channels + drop (float): dropout rate + drop_path (float): drop path rate + act_layer (str): activation layer + norm_layer (str): normalization layer + post_norm (bool): whether to use post normalization + layer_scale (float): layer scale + offset_scale (float): offset scale + with_cp (bool): whether to use checkpoint + """ + + def __init__(self, + core_op, + channels, + groups, + mlp_ratio=4., + drop=0., + drop_path=0., + act_layer='GELU', + norm_layer='LN', + post_norm=False, + layer_scale=None, + offset_scale=1.0, + with_cp=False, + dw_kernel_size=None, # for InternImage-H/G + res_post_norm=False, # for InternImage-H/G + center_feature_scale=False): # for InternImage-H/G + super().__init__() + self.channels = channels + self.groups = groups + self.mlp_ratio = mlp_ratio + self.with_cp = with_cp + + self.norm1 = build_norm_layer(channels, 'LN') + self.post_norm = post_norm + self.dcn = core_op( + channels=channels, + kernel_size=3, + stride=1, + pad=1, + dilation=1, + group=groups, + offset_scale=offset_scale, + act_layer=act_layer, + norm_layer=norm_layer, + dw_kernel_size=dw_kernel_size, # for InternImage-H/G + center_feature_scale=center_feature_scale) # for InternImage-H/G + self.drop_path = DropPath(drop_path) if drop_path > 0. \ + else nn.Identity() + self.norm2 = build_norm_layer(channels, 'LN') + self.mlp = MLPLayer(in_features=channels, + hidden_features=int(channels * mlp_ratio), + act_layer=act_layer, + drop=drop) + self.layer_scale = layer_scale is not None + if self.layer_scale: + self.gamma1 = nn.Parameter(layer_scale * torch.ones(channels), + requires_grad=True) + self.gamma2 = nn.Parameter(layer_scale * torch.ones(channels), + requires_grad=True) + self.res_post_norm = res_post_norm + if res_post_norm: + self.res_post_norm1 = build_norm_layer(channels, 'LN') + self.res_post_norm2 = build_norm_layer(channels, 'LN') + + def forward(self, x): + + def _inner_forward(x): + if not self.layer_scale: + if self.post_norm: + x = x + self.drop_path(self.norm1(self.dcn(x))) + x = x + self.drop_path(self.norm2(self.mlp(x))) + elif self.res_post_norm: # for InternImage-H/G + x = x + self.drop_path(self.res_post_norm1(self.dcn(self.norm1(x)))) + x = x + self.drop_path(self.res_post_norm2(self.mlp(self.norm2(x)))) + else: + x = x + self.drop_path(self.dcn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + if self.post_norm: + x = x + self.drop_path(self.gamma1 * self.norm1(self.dcn(x))) + x = x + self.drop_path(self.gamma2 * self.norm2(self.mlp(x))) + else: + x = x + self.drop_path(self.gamma1 * self.dcn(self.norm1(x))) + x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x))) + return x + + if self.with_cp and x.requires_grad: + x = checkpoint.checkpoint(_inner_forward, x) + else: + x = _inner_forward(x) + return x + + +class InternImageBlock(nn.Module): + r""" Block of InternImage + Args: + core_op (nn.Module): core operation of InternImage + channels (int): number of input channels + depths (list): Depth of each block. + groups (list): Groups of each block. + mlp_ratio (float): ratio of mlp hidden features to input channels + drop (float): dropout rate + drop_path (float): drop path rate + act_layer (str): activation layer + norm_layer (str): normalization layer + post_norm (bool): whether to use post normalization + layer_scale (float): layer scale + offset_scale (float): offset scale + with_cp (bool): whether to use checkpoint + """ + + def __init__(self, + core_op, + channels, + depth, + groups, + downsample=True, + mlp_ratio=4., + drop=0., + drop_path=0., + act_layer='GELU', + norm_layer='LN', + post_norm=False, + offset_scale=1.0, + layer_scale=None, + with_cp=False, + dw_kernel_size=None, # for InternImage-H/G + post_norm_block_ids=None, # for InternImage-H/G + res_post_norm=False, # for InternImage-H/G + center_feature_scale=False): # for InternImage-H/G + super().__init__() + self.channels = channels + self.depth = depth + self.post_norm = post_norm + self.center_feature_scale = center_feature_scale + + self.blocks = nn.ModuleList([ + InternImageLayer( + core_op=core_op, + channels=channels, + groups=groups, + mlp_ratio=mlp_ratio, + drop=drop, + drop_path=drop_path[i] if isinstance( + drop_path, list) else drop_path, + act_layer=act_layer, + norm_layer=norm_layer, + post_norm=post_norm, + layer_scale=layer_scale, + offset_scale=offset_scale, + with_cp=with_cp, + dw_kernel_size=dw_kernel_size, # for InternImage-H/G + res_post_norm=res_post_norm, # for InternImage-H/G + center_feature_scale=center_feature_scale # for InternImage-H/G + ) for i in range(depth) + ]) + if not self.post_norm or center_feature_scale: + self.norm = build_norm_layer(channels, 'LN') + self.post_norm_block_ids = post_norm_block_ids + if post_norm_block_ids is not None: # for InternImage-H/G + self.post_norms = nn.ModuleList( + [build_norm_layer(channels, 'LN', eps=1e-6) for _ in post_norm_block_ids] + ) + self.downsample = DownsampleLayer( + channels=channels, norm_layer=norm_layer) if downsample else None + + def forward(self, x, return_wo_downsample=False): + for i, blk in enumerate(self.blocks): + x = blk(x) + if (self.post_norm_block_ids is not None) and (i in self.post_norm_block_ids): + index = self.post_norm_block_ids.index(i) + x = self.post_norms[index](x) # for InternImage-H/G + if not self.post_norm or self.center_feature_scale: + x = self.norm(x) + if return_wo_downsample: + x_ = x + if self.downsample is not None: + x = self.downsample(x) + + if return_wo_downsample: + return x, x_ + return x + + +class InternImage(nn.Module): + r""" InternImage + A PyTorch impl of : `InternImage: Exploring Large-Scale Vision Foundation Models with Deformable Convolutions` - + https://arxiv.org/pdf/2103.14030 + Args: + core_op (str): Core operator. Default: 'DCNv3' + channels (int): Number of the first stage. Default: 64 + depths (list): Depth of each block. Default: [3, 4, 18, 5] + groups (list): Groups of each block. Default: [3, 6, 12, 24] + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + drop_rate (float): Probability of an element to be zeroed. Default: 0. + drop_path_rate (float): Stochastic depth rate. Default: 0. + act_layer (str): Activation layer. Default: 'GELU' + norm_layer (str): Normalization layer. Default: 'LN' + layer_scale (bool): Whether to use layer scale. Default: False + cls_scale (bool): Whether to use class scale. Default: False + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + dw_kernel_size (int): Size of the dwconv. Default: None + level2_post_norm (bool): Whether to use level2 post norm. Default: False + level2_post_norm_block_ids (list): Indexes of post norm blocks. Default: None + res_post_norm (bool): Whether to use res post norm. Default: False + center_feature_scale (bool): Whether to use center feature scale. Default: False + """ + + def __init__(self, + core_op='DCNv3', + channels=64, + depths=[3, 4, 18, 5], + groups=[3, 6, 12, 24], + mlp_ratio=4., + drop_rate=0., + drop_path_rate=0.2, + drop_path_type='linear', + act_layer='GELU', + norm_layer='LN', + layer_scale=None, + offset_scale=1.0, + post_norm=False, + with_cp=False, + dw_kernel_size=None, # for InternImage-H/G + level2_post_norm=False, # for InternImage-H/G + level2_post_norm_block_ids=None, # for InternImage-H/G + res_post_norm=False, # for InternImage-H/G + center_feature_scale=False, # for InternImage-H/G + out_indices=(0, 1, 2, 3), + **kwargs): + super().__init__() + self.core_op = core_op + self.num_levels = len(depths) + self.depths = depths + self.channels = channels + self.num_features = [int(channels * 2 ** i) for i in range(self.num_levels)] + self.post_norm = post_norm + self.mlp_ratio = mlp_ratio + self.out_indices = out_indices + self.level2_post_norm_block_ids = level2_post_norm_block_ids + + in_chans = 3 + self.patch_embed = StemLayer(in_chans=in_chans, + out_chans=channels, + act_layer=act_layer, + norm_layer=norm_layer) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] + if drop_path_type == 'uniform': + for i in range(len(dpr)): + dpr[i] = drop_path_rate + + self.levels = nn.ModuleList() + for i in range(self.num_levels): + post_norm_block_ids = level2_post_norm_block_ids if level2_post_norm and ( + i == 2) else None # for InternImage-H/G + level = InternImageBlock( + core_op=getattr(opsm, core_op), + channels=int(channels * 2**i), + depth=depths[i], + groups=groups[i], + mlp_ratio=self.mlp_ratio, + drop=drop_rate, + drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])], + act_layer=act_layer, + norm_layer=norm_layer, + post_norm=post_norm, + downsample=(i < self.num_levels - 1), + layer_scale=layer_scale, + offset_scale=offset_scale, + with_cp=with_cp, + dw_kernel_size=dw_kernel_size, # for InternImage-H/G + post_norm_block_ids=post_norm_block_ids, # for InternImage-H/G + res_post_norm=res_post_norm, # for InternImage-H/G + center_feature_scale=center_feature_scale # for InternImage-H/G + ) + self.levels.append(level) + + self.num_layers = len(depths) + self.apply(self._init_weights) + self.apply(self._init_deform_weights) + + def init_weights(self): + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + if isinstance(m, getattr(opsm, self.core_op)): + m._reset_parameters() + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def _init_deform_weights(self, m): + if isinstance(m, getattr(opsm, self.core_op)): + m._reset_parameters() + + def forward(self, x): + x = self.patch_embed(x) + x = self.pos_drop(x) + + outs = {} + for level_idx, level in enumerate(self.levels): + x, x_ = level(x, return_wo_downsample=True) + if level_idx in self.out_indices: + outs["res{}".format(level_idx + 2)] = x_.permute(0, 3, 1, 2).contiguous() + + return outs + +@BACKBONE_REGISTRY.register() +class D2InternImage(InternImage, Backbone): + def __init__(self, cfg, input_shape): + + core_op = cfg.MODEL.INTERNIMAGE.CORE_OP + channels = cfg.MODEL.INTERNIMAGE.CHANNELS + depths = cfg.MODEL.INTERNIMAGE.DEPTHS + groups = cfg.MODEL.INTERNIMAGE.GROUPS + mlp_ratio = cfg.MODEL.INTERNIMAGE.MLP_RATIO + drop_path_rate = cfg.MODEL.INTERNIMAGE.DROP_RATE + norm_layer = cfg.MODEL.INTERNIMAGE.norm_layer + layer_scale = cfg.MODEL.INTERNIMAGE.LAYER_SCALE + offset_scale = cfg.MODEL.INTERNIMAGE.OFFSET_SCALE + post_norm = cfg.MODEL.INTERNIMAGE.POST_NORM + dw_kernel_size = cfg.MODEL.INTERNIMAGE.DW_KERNEL_SIZE + res_post_norm = cfg.MODEL.INTERNIMAGE.RES_POST_NORM + level2_post_norm = cfg.MODEL.INTERNIMAGE.LEVEL2_POST_NORM + level2_post_norm_block_ids = cfg.MODEL.INTERNIMAGE.LEVEL2_POST_NORM_BLOCK_IDS + center_feature_scale = cfg.MODEL.INTERNIMAGE.CENTER_FEATURE_SCALE + with_cp = cfg.MODEL.INTERNIMAGE.WITH_CP + out_indices = cfg.MODEL.INTERNIMAGE.OUT_INDICES + + super().__init__( + core_op=core_op, + channels=channels, + depths=depths, + groups=groups, + mlp_ratio=mlp_ratio, + drop_path_rate=drop_path_rate, + norm_layer=norm_layer, + layer_scale=layer_scale, + offset_scale=offset_scale, + post_norm=post_norm, + dw_kernel_size=dw_kernel_size, + res_post_norm=res_post_norm, + level2_post_norm=level2_post_norm, + level2_post_norm_block_ids=level2_post_norm_block_ids, + center_feature_scale=center_feature_scale, + with_cp=with_cp, + out_indices=out_indices, + ) + + self._out_features = cfg.MODEL.INTERNIMAGE.OUT_FEATURES + + self._out_feature_strides = { + "res2": 4, + "res3": 8, + "res4": 16, + "res5": 32, + } + self._out_feature_channels = { + "res2": self.num_features[0], + "res3": self.num_features[1], + "res4": self.num_features[2], + "res5": self.num_features[3], + } + + def forward(self, x): + """ + Args: + x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``. + Returns: + dict[str->Tensor]: names and the corresponding features + """ + assert ( + x.dim() == 4 + ), f"InternImage takes an input of shape (N, C, H, W). Got {x.shape} instead!" + outputs = {} + y = super().forward(x) + for k in y.keys(): + if k in self._out_features: + outputs[k] = y[k] + return outputs + + def output_shape(self): + return { + name: ShapeSpec( + channels=self._out_feature_channels[name], stride=self._out_feature_strides[name] + ) + for name in self._out_features + } + + @property + def size_divisibility(self): + return 32 \ No newline at end of file diff --git a/oneformer/modeling/backbone/ops_dcnv3/functions/__init__.py b/oneformer/modeling/backbone/ops_dcnv3/functions/__init__.py new file mode 100644 index 0000000..0634879 --- /dev/null +++ b/oneformer/modeling/backbone/ops_dcnv3/functions/__init__.py @@ -0,0 +1,7 @@ +# -------------------------------------------------------- +# InternImage +# Copyright (c) 2022 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- + +from .dcnv3_func import DCNv3Function, dcnv3_core_pytorch diff --git a/oneformer/modeling/backbone/ops_dcnv3/functions/dcnv3_func.py b/oneformer/modeling/backbone/ops_dcnv3/functions/dcnv3_func.py new file mode 100644 index 0000000..4dac8fb --- /dev/null +++ b/oneformer/modeling/backbone/ops_dcnv3/functions/dcnv3_func.py @@ -0,0 +1,188 @@ +# -------------------------------------------------------- +# InternImage +# Copyright (c) 2022 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import torch +import torch.nn.functional as F +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.cuda.amp import custom_bwd, custom_fwd +import DCNv3 + + +class DCNv3Function(Function): + @staticmethod + @custom_fwd + def forward( + ctx, input, offset, mask, + kernel_h, kernel_w, stride_h, stride_w, + pad_h, pad_w, dilation_h, dilation_w, + group, group_channels, offset_scale, im2col_step): + ctx.kernel_h = kernel_h + ctx.kernel_w = kernel_w + ctx.stride_h = stride_h + ctx.stride_w = stride_w + ctx.pad_h = pad_h + ctx.pad_w = pad_w + ctx.dilation_h = dilation_h + ctx.dilation_w = dilation_w + ctx.group = group + ctx.group_channels = group_channels + ctx.offset_scale = offset_scale + ctx.im2col_step = im2col_step + output = DCNv3.dcnv3_forward( + input, offset, mask, kernel_h, + kernel_w, stride_h, stride_w, pad_h, + pad_w, dilation_h, dilation_w, group, + group_channels, offset_scale, ctx.im2col_step) + ctx.save_for_backward(input, offset, mask) + + return output + + @staticmethod + @once_differentiable + @custom_bwd + def backward(ctx, grad_output): + input, offset, mask = ctx.saved_tensors + grad_input, grad_offset, grad_mask = \ + DCNv3.dcnv3_backward( + input, offset, mask, ctx.kernel_h, + ctx.kernel_w, ctx.stride_h, ctx.stride_w, ctx.pad_h, + ctx.pad_w, ctx.dilation_h, ctx.dilation_w, ctx.group, + ctx.group_channels, ctx.offset_scale, grad_output.contiguous(), ctx.im2col_step) + + return grad_input, grad_offset, grad_mask, \ + None, None, None, None, None, None, None, None, None, None, None, None + + @staticmethod + def symbolic(g, input, offset, mask, kernel_h, kernel_w, stride_h, + stride_w, pad_h, pad_w, dilation_h, dilation_w, group, + group_channels, offset_scale, im2col_step): + """Symbolic function for mmdeploy::DCNv3. + + Returns: + DCNv3 op for onnx. + """ + return g.op( + 'mmdeploy::TRTDCNv3', + input, + offset, + mask, + kernel_h_i=int(kernel_h), + kernel_w_i=int(kernel_w), + stride_h_i=int(stride_h), + stride_w_i=int(stride_w), + pad_h_i=int(pad_h), + pad_w_i=int(pad_w), + dilation_h_i=int(dilation_h), + dilation_w_i=int(dilation_w), + group_i=int(group), + group_channels_i=int(group_channels), + offset_scale_f=float(offset_scale), + im2col_step_i=int(im2col_step), + ) + +def _get_reference_points(spatial_shapes, device, kernel_h, kernel_w, dilation_h, dilation_w, pad_h=0, pad_w=0, stride_h=1, stride_w=1): + _, H_, W_, _ = spatial_shapes + H_out = (H_ - (dilation_h * (kernel_h - 1) + 1)) // stride_h + 1 + W_out = (W_ - (dilation_w * (kernel_w - 1) + 1)) // stride_w + 1 + + ref_y, ref_x = torch.meshgrid( + torch.linspace( + # pad_h + 0.5, + # H_ - pad_h - 0.5, + (dilation_h * (kernel_h - 1)) // 2 + 0.5, + (dilation_h * (kernel_h - 1)) // 2 + 0.5 + (H_out - 1) * stride_h, + H_out, + dtype=torch.float32, + device=device), + torch.linspace( + # pad_w + 0.5, + # W_ - pad_w - 0.5, + (dilation_w * (kernel_w - 1)) // 2 + 0.5, + (dilation_w * (kernel_w - 1)) // 2 + 0.5 + (W_out - 1) * stride_w, + W_out, + dtype=torch.float32, + device=device)) + ref_y = ref_y.reshape(-1)[None] / H_ + ref_x = ref_x.reshape(-1)[None] / W_ + + ref = torch.stack((ref_x, ref_y), -1).reshape( + 1, H_out, W_out, 1, 2) + + return ref + + +def _generate_dilation_grids(spatial_shapes, kernel_h, kernel_w, dilation_h, dilation_w, group, device): + _, H_, W_, _ = spatial_shapes + points_list = [] + x, y = torch.meshgrid( + torch.linspace( + -((dilation_w * (kernel_w - 1)) // 2), + -((dilation_w * (kernel_w - 1)) // 2) + + (kernel_w - 1) * dilation_w, kernel_w, + dtype=torch.float32, + device=device), + torch.linspace( + -((dilation_h * (kernel_h - 1)) // 2), + -((dilation_h * (kernel_h - 1)) // 2) + + (kernel_h - 1) * dilation_h, kernel_h, + dtype=torch.float32, + device=device)) + + points_list.extend([x / W_, y / H_]) + grid = torch.stack(points_list, -1).reshape(-1, 1, 2).\ + repeat(1, group, 1).permute(1, 0, 2) + grid = grid.reshape(1, 1, 1, group * kernel_h * kernel_w, 2) + + return grid + + +def dcnv3_core_pytorch( + input, offset, mask, kernel_h, + kernel_w, stride_h, stride_w, pad_h, + pad_w, dilation_h, dilation_w, group, + group_channels, offset_scale): + # for debug and test only, + # need to use cuda version instead + input = F.pad( + input, + [0, 0, pad_h, pad_h, pad_w, pad_w]) + N_, H_in, W_in, _ = input.shape + _, H_out, W_out, _ = offset.shape + + ref = _get_reference_points( + input.shape, input.device, kernel_h, kernel_w, dilation_h, dilation_w, pad_h, pad_w, stride_h, stride_w) + grid = _generate_dilation_grids( + input.shape, kernel_h, kernel_w, dilation_h, dilation_w, group, input.device) + spatial_norm = torch.tensor([W_in, H_in]).reshape(1, 1, 1, 2).\ + repeat(1, 1, 1, group*kernel_h*kernel_w).to(input.device) + + sampling_locations = (ref + grid * offset_scale).repeat(N_, 1, 1, 1, 1).flatten(3, 4) + \ + offset * offset_scale / spatial_norm + + P_ = kernel_h * kernel_w + sampling_grids = 2 * sampling_locations - 1 + # N_, H_in, W_in, group*group_channels -> N_, H_in*W_in, group*group_channels -> N_, group*group_channels, H_in*W_in -> N_*group, group_channels, H_in, W_in + input_ = input.view(N_, H_in*W_in, group*group_channels).transpose(1, 2).\ + reshape(N_*group, group_channels, H_in, W_in) + # N_, H_out, W_out, group*P_*2 -> N_, H_out*W_out, group, P_, 2 -> N_, group, H_out*W_out, P_, 2 -> N_*group, H_out*W_out, P_, 2 + sampling_grid_ = sampling_grids.view(N_, H_out*W_out, group, P_, 2).transpose(1, 2).\ + flatten(0, 1) + # N_*group, group_channels, H_out*W_out, P_ + sampling_input_ = F.grid_sample( + input_, sampling_grid_, mode='bilinear', padding_mode='zeros', align_corners=False) + + # (N_, H_out, W_out, group*P_) -> N_, H_out*W_out, group, P_ -> (N_, group, H_out*W_out, P_) -> (N_*group, 1, H_out*W_out, P_) + mask = mask.view(N_, H_out*W_out, group, P_).transpose(1, 2).\ + reshape(N_*group, 1, H_out*W_out, P_) + output = (sampling_input_ * mask).sum(-1).view(N_, + group*group_channels, H_out*W_out) + + return output.transpose(1, 2).reshape(N_, H_out, W_out, -1).contiguous() diff --git a/oneformer/modeling/backbone/ops_dcnv3/make.sh b/oneformer/modeling/backbone/ops_dcnv3/make.sh new file mode 100755 index 0000000..9a50179 --- /dev/null +++ b/oneformer/modeling/backbone/ops_dcnv3/make.sh @@ -0,0 +1,8 @@ +#!/usr/bin/env bash +# -------------------------------------------------------- +# InternImage +# Copyright (c) 2022 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- + +python setup.py build install diff --git a/oneformer/modeling/backbone/ops_dcnv3/modules/__init__.py b/oneformer/modeling/backbone/ops_dcnv3/modules/__init__.py new file mode 100644 index 0000000..47216fd --- /dev/null +++ b/oneformer/modeling/backbone/ops_dcnv3/modules/__init__.py @@ -0,0 +1,7 @@ +# -------------------------------------------------------- +# InternImage +# Copyright (c) 2022 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- + +from .dcnv3 import DCNv3, DCNv3_pytorch \ No newline at end of file diff --git a/oneformer/modeling/backbone/ops_dcnv3/modules/dcnv3.py b/oneformer/modeling/backbone/ops_dcnv3/modules/dcnv3.py new file mode 100644 index 0000000..a7d0650 --- /dev/null +++ b/oneformer/modeling/backbone/ops_dcnv3/modules/dcnv3.py @@ -0,0 +1,345 @@ +# -------------------------------------------------------- +# InternImage +# Copyright (c) 2022 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import warnings +import torch +from torch import nn +import torch.nn.functional as F +from torch.nn.init import xavier_uniform_, constant_ +from ..functions import DCNv3Function, dcnv3_core_pytorch + + +class to_channels_first(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, x): + return x.permute(0, 3, 1, 2) + + +class to_channels_last(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, x): + return x.permute(0, 2, 3, 1) + + +def build_norm_layer(dim, + norm_layer, + in_format='channels_last', + out_format='channels_last', + eps=1e-6): + layers = [] + if norm_layer == 'BN': + if in_format == 'channels_last': + layers.append(to_channels_first()) + layers.append(nn.BatchNorm2d(dim)) + if out_format == 'channels_last': + layers.append(to_channels_last()) + elif norm_layer == 'LN': + if in_format == 'channels_first': + layers.append(to_channels_last()) + layers.append(nn.LayerNorm(dim, eps=eps)) + if out_format == 'channels_first': + layers.append(to_channels_first()) + else: + raise NotImplementedError( + f'build_norm_layer does not support {norm_layer}') + return nn.Sequential(*layers) + + +def build_act_layer(act_layer): + if act_layer == 'ReLU': + return nn.ReLU(inplace=True) + elif act_layer == 'SiLU': + return nn.SiLU(inplace=True) + elif act_layer == 'GELU': + return nn.GELU() + + raise NotImplementedError(f'build_act_layer does not support {act_layer}') + + +def _is_power_of_2(n): + if (not isinstance(n, int)) or (n < 0): + raise ValueError( + "invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) + + return (n & (n - 1) == 0) and n != 0 + + +class CenterFeatureScaleModule(nn.Module): + def forward(self, + query, + center_feature_scale_proj_weight, + center_feature_scale_proj_bias): + center_feature_scale = F.linear(query, + weight=center_feature_scale_proj_weight, + bias=center_feature_scale_proj_bias).sigmoid() + return center_feature_scale + + +class DCNv3_pytorch(nn.Module): + def __init__( + self, + channels=64, + kernel_size=3, + dw_kernel_size=None, + stride=1, + pad=1, + dilation=1, + group=4, + offset_scale=1.0, + act_layer='GELU', + norm_layer='LN', + center_feature_scale=False): + """ + DCNv3 Module + :param channels + :param kernel_size + :param stride + :param pad + :param dilation + :param group + :param offset_scale + :param act_layer + :param norm_layer + """ + super().__init__() + if channels % group != 0: + raise ValueError( + f'channels must be divisible by group, but got {channels} and {group}') + _d_per_group = channels // group + dw_kernel_size = dw_kernel_size if dw_kernel_size is not None else kernel_size + # you'd better set _d_per_group to a power of 2 which is more efficient in our CUDA implementation + if not _is_power_of_2(_d_per_group): + warnings.warn( + "You'd better set channels in DCNv3 to make the dimension of each attention head a power of 2 " + "which is more efficient in our CUDA implementation.") + + self.offset_scale = offset_scale + self.channels = channels + self.kernel_size = kernel_size + self.dw_kernel_size = dw_kernel_size + self.stride = stride + self.dilation = dilation + self.pad = pad + self.group = group + self.group_channels = channels // group + self.offset_scale = offset_scale + self.center_feature_scale = center_feature_scale + + self.dw_conv = nn.Sequential( + nn.Conv2d( + channels, + channels, + kernel_size=dw_kernel_size, + stride=1, + padding=(dw_kernel_size - 1) // 2, + groups=channels), + build_norm_layer( + channels, + norm_layer, + 'channels_first', + 'channels_last'), + build_act_layer(act_layer)) + self.offset = nn.Linear( + channels, + group * kernel_size * kernel_size * 2) + self.mask = nn.Linear( + channels, + group * kernel_size * kernel_size) + self.input_proj = nn.Linear(channels, channels) + self.output_proj = nn.Linear(channels, channels) + self._reset_parameters() + + if center_feature_scale: + self.center_feature_scale_proj_weight = nn.Parameter( + torch.zeros((group, channels), dtype=torch.float)) + self.center_feature_scale_proj_bias = nn.Parameter( + torch.tensor(0.0, dtype=torch.float).view((1,)).repeat(group, )) + self.center_feature_scale_module = CenterFeatureScaleModule() + + def _reset_parameters(self): + constant_(self.offset.weight.data, 0.) + constant_(self.offset.bias.data, 0.) + constant_(self.mask.weight.data, 0.) + constant_(self.mask.bias.data, 0.) + xavier_uniform_(self.input_proj.weight.data) + constant_(self.input_proj.bias.data, 0.) + xavier_uniform_(self.output_proj.weight.data) + constant_(self.output_proj.bias.data, 0.) + + def forward(self, input): + """ + :param query (N, H, W, C) + :return output (N, H, W, C) + """ + N, H, W, _ = input.shape + + x = self.input_proj(input) + x_proj = x + + x1 = input.permute(0, 3, 1, 2) + x1 = self.dw_conv(x1) + offset = self.offset(x1) + mask = self.mask(x1).reshape(N, H, W, self.group, -1) + mask = F.softmax(mask, -1).reshape(N, H, W, -1) + + x = dcnv3_core_pytorch( + x, offset, mask, + self.kernel_size, self.kernel_size, + self.stride, self.stride, + self.pad, self.pad, + self.dilation, self.dilation, + self.group, self.group_channels, + self.offset_scale) + if self.center_feature_scale: + center_feature_scale = self.center_feature_scale_module( + x1, self.center_feature_scale_proj_weight, self.center_feature_scale_proj_bias) + # N, H, W, groups -> N, H, W, groups, 1 -> N, H, W, groups, _d_per_group -> N, H, W, channels + center_feature_scale = center_feature_scale[..., None].repeat( + 1, 1, 1, 1, self.channels // self.group).flatten(-2) + x = x * (1 - center_feature_scale) + x_proj * center_feature_scale + x = self.output_proj(x) + + return x + + +class DCNv3(nn.Module): + def __init__( + self, + channels=64, + kernel_size=3, + dw_kernel_size=None, + stride=1, + pad=1, + dilation=1, + group=4, + offset_scale=1.0, + act_layer='GELU', + norm_layer='LN', + center_feature_scale=False): + """ + DCNv3 Module + :param channels + :param kernel_size + :param stride + :param pad + :param dilation + :param group + :param offset_scale + :param act_layer + :param norm_layer + """ + super().__init__() + if channels % group != 0: + raise ValueError( + f'channels must be divisible by group, but got {channels} and {group}') + _d_per_group = channels // group + dw_kernel_size = dw_kernel_size if dw_kernel_size is not None else kernel_size + # you'd better set _d_per_group to a power of 2 which is more efficient in our CUDA implementation + if not _is_power_of_2(_d_per_group): + warnings.warn( + "You'd better set channels in DCNv3 to make the dimension of each attention head a power of 2 " + "which is more efficient in our CUDA implementation.") + + self.offset_scale = offset_scale + self.channels = channels + self.kernel_size = kernel_size + self.dw_kernel_size = dw_kernel_size + self.stride = stride + self.dilation = dilation + self.pad = pad + self.group = group + self.group_channels = channels // group + self.offset_scale = offset_scale + self.center_feature_scale = center_feature_scale + + self.dw_conv = nn.Sequential( + nn.Conv2d( + channels, + channels, + kernel_size=dw_kernel_size, + stride=1, + padding=(dw_kernel_size - 1) // 2, + groups=channels), + build_norm_layer( + channels, + norm_layer, + 'channels_first', + 'channels_last'), + build_act_layer(act_layer)) + self.offset = nn.Linear( + channels, + group * kernel_size * kernel_size * 2) + self.mask = nn.Linear( + channels, + group * kernel_size * kernel_size) + self.input_proj = nn.Linear(channels, channels) + self.output_proj = nn.Linear(channels, channels) + self._reset_parameters() + + if center_feature_scale: + self.center_feature_scale_proj_weight = nn.Parameter( + torch.zeros((group, channels), dtype=torch.float)) + self.center_feature_scale_proj_bias = nn.Parameter( + torch.tensor(0.0, dtype=torch.float).view((1,)).repeat(group, )) + self.center_feature_scale_module = CenterFeatureScaleModule() + + def _reset_parameters(self): + constant_(self.offset.weight.data, 0.) + constant_(self.offset.bias.data, 0.) + constant_(self.mask.weight.data, 0.) + constant_(self.mask.bias.data, 0.) + xavier_uniform_(self.input_proj.weight.data) + constant_(self.input_proj.bias.data, 0.) + xavier_uniform_(self.output_proj.weight.data) + constant_(self.output_proj.bias.data, 0.) + + def forward(self, input): + """ + :param query (N, H, W, C) + :return output (N, H, W, C) + """ + N, H, W, _ = input.shape + + x = self.input_proj(input) + x_proj = x + dtype = x.dtype + + x1 = input.permute(0, 3, 1, 2) + x1 = self.dw_conv(x1) + offset = self.offset(x1) + mask = self.mask(x1).reshape(N, H, W, self.group, -1) + mask = F.softmax(mask, -1).reshape(N, H, W, -1).type(dtype) + + x = DCNv3Function.apply( + x, offset, mask, + self.kernel_size, self.kernel_size, + self.stride, self.stride, + self.pad, self.pad, + self.dilation, self.dilation, + self.group, self.group_channels, + self.offset_scale, + 256) + + if self.center_feature_scale: + center_feature_scale = self.center_feature_scale_module( + x1, self.center_feature_scale_proj_weight, self.center_feature_scale_proj_bias) + # N, H, W, groups -> N, H, W, groups, 1 -> N, H, W, groups, _d_per_group -> N, H, W, channels + center_feature_scale = center_feature_scale[..., None].repeat( + 1, 1, 1, 1, self.channels // self.group).flatten(-2) + x = x * (1 - center_feature_scale) + x_proj * center_feature_scale + x = self.output_proj(x) + + return x diff --git a/oneformer/modeling/backbone/ops_dcnv3/setup.py b/oneformer/modeling/backbone/ops_dcnv3/setup.py new file mode 100644 index 0000000..2bd813d --- /dev/null +++ b/oneformer/modeling/backbone/ops_dcnv3/setup.py @@ -0,0 +1,75 @@ +# -------------------------------------------------------- +# InternImage +# Copyright (c) 2022 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- + +import os +import glob + +import torch + +from torch.utils.cpp_extension import CUDA_HOME +from torch.utils.cpp_extension import CppExtension +from torch.utils.cpp_extension import CUDAExtension + +from setuptools import find_packages +from setuptools import setup + +requirements = ["torch", "torchvision"] + + +def get_extensions(): + this_dir = os.path.dirname(os.path.abspath(__file__)) + extensions_dir = os.path.join(this_dir, "src") + + main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) + source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) + source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) + + sources = main_file + source_cpu + extension = CppExtension + extra_compile_args = {"cxx": []} + define_macros = [] + + if torch.cuda.is_available() and CUDA_HOME is not None: + extension = CUDAExtension + sources += source_cuda + define_macros += [("WITH_CUDA", None)] + extra_compile_args["nvcc"] = [ + # "-DCUDA_HAS_FP16=1", + # "-D__CUDA_NO_HALF_OPERATORS__", + # "-D__CUDA_NO_HALF_CONVERSIONS__", + # "-D__CUDA_NO_HALF2_OPERATORS__", + ] + else: + raise NotImplementedError('Cuda is not available') + + sources = [os.path.join(extensions_dir, s) for s in sources] + include_dirs = [extensions_dir] + ext_modules = [ + extension( + "DCNv3", + sources, + include_dirs=include_dirs, + define_macros=define_macros, + extra_compile_args=extra_compile_args, + ) + ] + return ext_modules + + +setup( + name="DCNv3", + version="1.0", + author="InternImage", + url="https://github.com/OpenGVLab/InternImage", + description= + "PyTorch Wrapper for CUDA Functions of DCNv3", + packages=find_packages(exclude=( + "configs", + "tests", + )), + ext_modules=get_extensions(), + cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, +) diff --git a/oneformer/modeling/backbone/ops_dcnv3/src/cpu/dcnv3_cpu.cpp b/oneformer/modeling/backbone/ops_dcnv3/src/cpu/dcnv3_cpu.cpp new file mode 100644 index 0000000..a3bddc1 --- /dev/null +++ b/oneformer/modeling/backbone/ops_dcnv3/src/cpu/dcnv3_cpu.cpp @@ -0,0 +1,37 @@ +/*! +************************************************************************************************** +* InternImage +* Copyright (c) 2022 OpenGVLab +* Licensed under The MIT License [see LICENSE for details] +************************************************************************************************** +* Modified from +*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#include + +#include +#include + +at::Tensor dcnv3_cpu_forward(const at::Tensor &input, const at::Tensor &offset, + const at::Tensor &mask, const int kernel_h, + const int kernel_w, const int stride_h, + const int stride_w, const int pad_h, + const int pad_w, const int dilation_h, + const int dilation_w, const int group, + const int group_channels, const float offset_scale, + const int im2col_step) { + AT_ERROR("Not implement on cpu"); +} + +std::vector +dcnv3_cpu_backward(const at::Tensor &input, const at::Tensor &offset, + const at::Tensor &mask, const int kernel_h, + const int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int group, + const int group_channels, const float offset_scale, + const at::Tensor &grad_output, const int im2col_step) { + AT_ERROR("Not implement on cpu"); +} diff --git a/oneformer/modeling/backbone/ops_dcnv3/src/cpu/dcnv3_cpu.h b/oneformer/modeling/backbone/ops_dcnv3/src/cpu/dcnv3_cpu.h new file mode 100644 index 0000000..d457bcb --- /dev/null +++ b/oneformer/modeling/backbone/ops_dcnv3/src/cpu/dcnv3_cpu.h @@ -0,0 +1,31 @@ +/*! +************************************************************************************************** +* InternImage +* Copyright (c) 2022 OpenGVLab +* Licensed under The MIT License [see LICENSE for details] +************************************************************************************************** +* Modified from +*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#pragma once +#include + +at::Tensor dcnv3_cpu_forward(const at::Tensor &input, const at::Tensor &offset, + const at::Tensor &mask, const int kernel_h, + const int kernel_w, const int stride_h, + const int stride_w, const int pad_h, + const int pad_w, const int dilation_h, + const int dilation_w, const int group, + const int group_channels, const float offset_scale, + const int im2col_step); + +std::vector +dcnv3_cpu_backward(const at::Tensor &input, const at::Tensor &offset, + const at::Tensor &mask, const int kernel_h, + const int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int group, + const int group_channels, const float offset_scale, + const at::Tensor &grad_output, const int im2col_step); diff --git a/oneformer/modeling/backbone/ops_dcnv3/src/cuda/dcnv3_cuda.cu b/oneformer/modeling/backbone/ops_dcnv3/src/cuda/dcnv3_cuda.cu new file mode 100644 index 0000000..5284095 --- /dev/null +++ b/oneformer/modeling/backbone/ops_dcnv3/src/cuda/dcnv3_cuda.cu @@ -0,0 +1,174 @@ +/*! +************************************************************************************************** +* InternImage +* Copyright (c) 2022 OpenGVLab +* Licensed under The MIT License [see LICENSE for details] +************************************************************************************************** +* Modified from +*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#include "cuda/dcnv3_im2col_cuda.cuh" +#include + +#include +#include +#include +#include +#include + +at::Tensor dcnv3_cuda_forward(const at::Tensor &input, const at::Tensor &offset, + const at::Tensor &mask, const int kernel_h, + const int kernel_w, const int stride_h, + const int stride_w, const int pad_h, + const int pad_w, const int dilation_h, + const int dilation_w, const int group, + const int group_channels, + const float offset_scale, const int im2col_step) { + AT_ASSERTM(input.is_contiguous(), "input tensor has to be contiguous"); + AT_ASSERTM(offset.is_contiguous(), "offset tensor has to be contiguous"); + AT_ASSERTM(mask.is_contiguous(), "mask tensor has to be contiguous"); + AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor"); + AT_ASSERTM(offset.type().is_cuda(), "offset must be a CUDA tensor"); + AT_ASSERTM(mask.type().is_cuda(), "mask must be a CUDA tensor"); + + const int batch = input.size(0); + const int height_in = input.size(1); + const int width_in = input.size(2); + const int channels = input.size(3); + const int height_out = + (height_in + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + + 1; + const int width_out = + (width_in + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + + 1; + const int im2col_step_ = std::min(batch, im2col_step); + + AT_ASSERTM(batch % im2col_step_ == 0, + "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); + AT_ASSERTM( + channels == (group * group_channels), + "Input channels and group times group channels wont match: (%d vs %d).", + channels, group * group_channels); + + auto output = + at::zeros({batch, height_out, width_out, group * group_channels}, + input.options()); + + const int batch_n = im2col_step_; + auto output_n = output.view({batch / batch_n, batch_n, height_out, + width_out, group * group_channels}); + auto per_input_size = height_in * width_in * group * group_channels; + auto per_offset_size = + height_out * width_out * group * kernel_h * kernel_w * 2; + auto per_mask_size = height_out * width_out * group * kernel_h * kernel_w; + for (int n = 0; n < batch / im2col_step_; ++n) { + auto columns = output_n.select(0, n); + // AT_DISPATCH_FLOATING_TYPES( + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.type(), "ms_deform_attn_forward_cuda", ([&] { + dcnv3_im2col_cuda( + at::cuda::getCurrentCUDAStream(), + input.data() + n * im2col_step_ * per_input_size, + offset.data() + + n * im2col_step_ * per_offset_size, + mask.data() + n * im2col_step_ * per_mask_size, + columns.data(), kernel_h, kernel_w, stride_h, + stride_w, pad_h, pad_w, dilation_h, dilation_w, group, + group_channels, batch_n, height_in, width_in, height_out, + width_out, offset_scale); + })); + } + + return output; +} + +std::vector +dcnv3_cuda_backward(const at::Tensor &input, const at::Tensor &offset, + const at::Tensor &mask, const int kernel_h, + const int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int group, + const int group_channels, const float offset_scale, + const at::Tensor &grad_output, const int im2col_step) { + + AT_ASSERTM(input.is_contiguous(), "input tensor has to be contiguous"); + AT_ASSERTM(offset.is_contiguous(), "offset tensor has to be contiguous"); + AT_ASSERTM(mask.is_contiguous(), "mask tensor has to be contiguous"); + AT_ASSERTM(grad_output.is_contiguous(), + "grad_output tensor has to be contiguous"); + AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor"); + AT_ASSERTM(offset.type().is_cuda(), "offset must be a CUDA tensor"); + AT_ASSERTM(mask.type().is_cuda(), "mask must be a CUDA tensor"); + AT_ASSERTM(grad_output.type().is_cuda(), + "grad_output must be a CUDA tensor"); + + const int batch = input.size(0); + const int height_in = input.size(1); + const int width_in = input.size(2); + const int channels = input.size(3); + const int height_out = + (height_in + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + + 1; + const int width_out = + (width_in + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + + 1; + const int im2col_step_ = std::min(batch, im2col_step); + + AT_ASSERTM(batch % im2col_step_ == 0, + "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); + AT_ASSERTM( + channels == (group * group_channels), + "Input channels and group times group channels wont match: (%d vs %d).", + channels, group * group_channels); + + auto dtype = input.dtype(); + if (dtype == at::kHalf) { + dtype = at::kFloat; + } + + auto grad_input = at::zeros_like(input, dtype); + auto grad_offset = at::zeros_like(offset, dtype); + auto grad_mask = at::zeros_like(mask, dtype); + + const int batch_n = im2col_step_; + auto per_input_size = height_in * width_in * group * group_channels; + auto per_offset_size = + height_out * width_out * group * kernel_h * kernel_w * 2; + auto per_mask_size = height_out * width_out * group * kernel_h * kernel_w; + auto grad_output_n = + grad_output.view({batch / im2col_step_, batch_n, height_out * width_out, + group, group_channels}); + + for (int n = 0; n < batch / im2col_step_; ++n) { + auto grad_output_g = grad_output_n.select(0, n); + // AT_DISPATCH_FLOATING_TYPES( + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.type(), "ms_deform_attn_backward_cuda", ([&] { + dcnv3_col2im_cuda( + at::cuda::getCurrentCUDAStream(), + grad_output_g.data(), + input.data() + n * im2col_step_ * per_input_size, + offset.data() + + n * im2col_step_ * per_offset_size, + mask.data() + n * im2col_step_ * per_mask_size, + kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, + dilation_h, dilation_w, group, group_channels, batch_n, + height_in, width_in, height_out, width_out, offset_scale, + grad_input.data() + + n * im2col_step_ * per_input_size, + grad_offset.data() + + n * im2col_step_ * per_offset_size, + grad_mask.data() + + n * im2col_step_ * per_mask_size); + })); + } + + if (input.dtype() == torch::kHalf) { + return {grad_input.to(torch::kHalf), grad_offset.to(torch::kHalf), + grad_mask.to(torch::kHalf)}; + } else { + return {grad_input, grad_offset, grad_mask}; + } +} \ No newline at end of file diff --git a/oneformer/modeling/backbone/ops_dcnv3/src/cuda/dcnv3_cuda.h b/oneformer/modeling/backbone/ops_dcnv3/src/cuda/dcnv3_cuda.h new file mode 100644 index 0000000..069f282 --- /dev/null +++ b/oneformer/modeling/backbone/ops_dcnv3/src/cuda/dcnv3_cuda.h @@ -0,0 +1,31 @@ +/*! +************************************************************************************************** +* InternImage +* Copyright (c) 2022 OpenGVLab +* Licensed under The MIT License [see LICENSE for details] +************************************************************************************************** +* Modified from +*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#pragma once +#include + +at::Tensor dcnv3_cuda_forward(const at::Tensor &input, const at::Tensor &offset, + const at::Tensor &mask, const int kernel_h, + const int kernel_w, const int stride_h, + const int stride_w, const int pad_h, + const int pad_w, const int dilation_h, + const int dilation_w, const int group, + const int group_channels, + const float offset_scale, const int im2col_step); + +std::vector +dcnv3_cuda_backward(const at::Tensor &input, const at::Tensor &offset, + const at::Tensor &mask, const int kernel_h, + const int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int group, + const int group_channels, const float offset_scale, + const at::Tensor &grad_output, const int im2col_step); diff --git a/oneformer/modeling/backbone/ops_dcnv3/src/cuda/dcnv3_im2col_cuda.cuh b/oneformer/modeling/backbone/ops_dcnv3/src/cuda/dcnv3_im2col_cuda.cuh new file mode 100644 index 0000000..b551ba3 --- /dev/null +++ b/oneformer/modeling/backbone/ops_dcnv3/src/cuda/dcnv3_im2col_cuda.cuh @@ -0,0 +1,1045 @@ +/*! +************************************************************************************************** +* InternImage +* Copyright (c) 2022 OpenGVLab +* Licensed under The MIT License [see LICENSE for details] +************************************************************************************************** +* Modified from +*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#include +#include +#include + +#include +#include +#include +#include + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +const int CUDA_NUM_THREADS = 256; +inline int GET_BLOCKS(const int N, const int num_threads) { + return (N + num_threads - 1) / num_threads; +} + +#define opmath_t at::opmath_type + +template +__device__ opmath_t dcnv3_im2col_bilinear(const scalar_t *&bottom_data, + const int &height, const int &width, + const int &group, + const int &group_channels, + const opmath_t &h, const opmath_t &w, + const int &g, const int &c) { + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const opmath_t lh = h - h_low; + const opmath_t lw = w - w_low; + const opmath_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = group * group_channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = g * group_channels + c; + + opmath_t v1 = 0; + if (h_low >= 0 && w_low >= 0) { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + } + opmath_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + } + opmath_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + } + opmath_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + } + const opmath_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + const opmath_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +template +__device__ void dcnv3_col2im_bilinear( + const scalar_t *&bottom_data, const int &height, const int &width, + const int &nheads, const int &group_channels, const opmath_t &h, + const opmath_t &w, const int &m, const int &c, const opmath_t offset_scale, + const opmath_t &top_grad, const opmath_t &mask, opmath_t *&grad_im, + opmath_t *grad_offset, opmath_t *grad_mask) { + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const opmath_t lh = h - h_low; + const opmath_t lw = w - w_low; + const opmath_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * group_channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * group_channels + c; + + const opmath_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + const opmath_t top_grad_im = top_grad * mask; + opmath_t grad_h_weight = 0, grad_w_weight = 0; + + opmath_t v1 = 0; + if (h_low >= 0 && w_low >= 0) { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + grad_h_weight -= hw * v1; + grad_w_weight -= hh * v1; + atomicAdd(grad_im + ptr1, w1 * top_grad_im); + } + opmath_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + grad_h_weight -= lw * v2; + grad_w_weight += hh * v2; + atomicAdd(grad_im + ptr2, w2 * top_grad_im); + } + opmath_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + grad_h_weight += hw * v3; + grad_w_weight -= lh * v3; + atomicAdd(grad_im + ptr3, w3 * top_grad_im); + } + opmath_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + grad_h_weight += lw * v4; + grad_w_weight += lh * v4; + atomicAdd(grad_im + ptr4, w4 * top_grad_im); + } + + const opmath_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + *grad_mask = top_grad * val; + *grad_offset = offset_scale * grad_w_weight * top_grad_im; + *(grad_offset + 1) = offset_scale * grad_h_weight * top_grad_im; +} + +template +__device__ void dcnv3_col2im_bilinear_gm( + const scalar_t *&bottom_data, const int &height, const int &width, + const int &nheads, const int &group_channels, const opmath_t &h, + const opmath_t &w, const int &m, const int &c, const opmath_t offset_scale, + const opmath_t &top_grad, const opmath_t &mask, opmath_t *&grad_im, + opmath_t *grad_offset, opmath_t *grad_mask) { + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const opmath_t lh = h - h_low; + const opmath_t lw = w - w_low; + const opmath_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * group_channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * group_channels + c; + + const opmath_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + const opmath_t top_grad_im = top_grad * mask; + opmath_t grad_h_weight = 0, grad_w_weight = 0; + + opmath_t v1 = 0; + if (h_low >= 0 && w_low >= 0) { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + grad_h_weight -= hw * v1; + grad_w_weight -= hh * v1; + atomicAdd(grad_im + ptr1, w1 * top_grad_im); + } + opmath_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + grad_h_weight -= lw * v2; + grad_w_weight += hh * v2; + atomicAdd(grad_im + ptr2, w2 * top_grad_im); + } + opmath_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + grad_h_weight += hw * v3; + grad_w_weight -= lh * v3; + atomicAdd(grad_im + ptr3, w3 * top_grad_im); + } + opmath_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + grad_h_weight += lw * v4; + grad_w_weight += lh * v4; + atomicAdd(grad_im + ptr4, w4 * top_grad_im); + } + + const opmath_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + atomicAdd(grad_mask, top_grad * val); + atomicAdd(grad_offset, offset_scale * grad_w_weight * top_grad_im); + atomicAdd(grad_offset + 1, offset_scale * grad_h_weight * top_grad_im); +} + +template +__global__ void dcnv3_im2col_gpu_kernel( + const int num_kernels, const scalar_t *data_im, const scalar_t *data_offset, + const scalar_t *data_mask, scalar_t *data_col, const int kernel_h, + const int kernel_w, const int stride_h, const int stride_w, const int pad_h, + const int pad_w, const int dilation_h, const int dilation_w, + const int group, const int group_channels, const int height_in, + const int width_in, const int height_out, const int width_out, + const opmath_t offset_scale) { + CUDA_KERNEL_LOOP(index, num_kernels) { + int _temp = index; + const int c_col = _temp % group_channels; + _temp /= group_channels; + const int sampling_index = _temp; + const int g_col = _temp % group; + _temp /= group; + const int p0_w = ((dilation_w * (kernel_w - 1)) >> 1) - pad_w + + (_temp % width_out) * stride_w; + _temp /= width_out; + const int p0_h = ((dilation_h * (kernel_h - 1)) >> 1) - pad_h + + (_temp % height_out) * stride_h; + _temp /= height_out; + const int b_col = _temp; + + const int input_size = height_in * width_in; + scalar_t *data_col_ptr = data_col + index; + const int kernel_size = kernel_h * kernel_w; + int data_weight_ptr = sampling_index * kernel_size; + int data_loc_w_ptr = data_weight_ptr << 1; + const int qid_stride = group * group_channels; + opmath_t col = 0; + const scalar_t *data_im_ptr = data_im + b_col * input_size * qid_stride; + // top-left + const opmath_t p0_w_ = + p0_w - ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale; + const opmath_t p0_h_ = + p0_h - ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale; + for (int i = 0; i < kernel_w; ++i) { + for (int j = 0; j < kernel_h; ++j) { + const opmath_t offset_w = data_offset[data_loc_w_ptr]; + const opmath_t offset_h = data_offset[data_loc_w_ptr + 1]; + const opmath_t loc_w = + p0_w_ + (i * dilation_w + offset_w) * offset_scale; + const opmath_t loc_h = + p0_h_ + (j * dilation_h + offset_h) * offset_scale; + const opmath_t weight = data_mask[data_weight_ptr]; + if (loc_h > -1 && loc_w > -1 && loc_h < height_in && + loc_w < width_in) { + col += dcnv3_im2col_bilinear( + data_im_ptr, height_in, width_in, group, + group_channels, loc_h, loc_w, g_col, c_col) * + weight; + } + data_weight_ptr += 1; + data_loc_w_ptr += 2; + } + } + *data_col_ptr = col; + } +} + +// debug +template +__global__ void dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1( + const int num_kernels, const scalar_t *grad_col, const scalar_t *data_im, + const scalar_t *data_offset, const scalar_t *data_mask, const int kernel_h, + const int kernel_w, const int stride_h, const int stride_w, const int pad_h, + const int pad_w, const int dilation_h, const int dilation_w, + const int group, const int group_channels, const int height_in, + const int width_in, const int height_out, const int width_out, + const opmath_t offset_scale, opmath_t *grad_im, opmath_t *grad_offset, + opmath_t *grad_mask) { + CUDA_KERNEL_LOOP(index, num_kernels) { + __shared__ opmath_t cache_grad_offset[blockSize * 2]; + __shared__ opmath_t cache_grad_mask[blockSize]; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % group_channels; + _temp /= group_channels; + const int sampling_index = _temp; + const int g_col = _temp % group; + _temp /= group; + const int p0_w = ((dilation_w * (kernel_w - 1)) >> 1) - pad_w + + (_temp % width_out) * stride_w; + _temp /= width_out; + const int p0_h = ((dilation_h * (kernel_h - 1)) >> 1) - pad_h + + (_temp % height_out) * stride_h; + _temp /= height_out; + const int b_col = _temp; + + const opmath_t top_grad = grad_col[index]; + const int input_size = height_in * width_in; + const int kernel_size = kernel_h * kernel_w; + int data_weight_ptr = sampling_index * kernel_size; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_offset += grad_sampling_ptr << 1; + grad_mask += grad_sampling_ptr; + const int qid_stride = group * group_channels; + const int im_ptr_offset = b_col * input_size * qid_stride; + const scalar_t *data_im_ptr = data_im + im_ptr_offset; + opmath_t *grad_im_ptr = grad_im + im_ptr_offset; + const opmath_t p0_w_ = + p0_w - ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale; + const opmath_t p0_h_ = + p0_h - ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale; + for (int i = 0; i < kernel_w; ++i) { + for (int j = 0; j < kernel_h; ++j) { + const opmath_t offset_w = data_offset[data_loc_w_ptr]; + const opmath_t offset_h = data_offset[data_loc_w_ptr + 1]; + const opmath_t loc_w = + p0_w_ + (i * dilation_w + offset_w) * offset_scale; + const opmath_t loc_h = + p0_h_ + (j * dilation_h + offset_h) * offset_scale; + const opmath_t weight = data_mask[data_weight_ptr]; + *(cache_grad_offset + (threadIdx.x << 1)) = 0; + *(cache_grad_offset + ((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_mask + threadIdx.x) = 0; + if (loc_h > -1 && loc_w > -1 && loc_h < height_in && + loc_w < width_in) { + dcnv3_col2im_bilinear( + data_im_ptr, height_in, width_in, group, group_channels, + loc_h, loc_w, g_col, c_col, offset_scale, top_grad, + weight, grad_im_ptr, + cache_grad_offset + (threadIdx.x << 1), + cache_grad_mask + threadIdx.x); + } + + __syncthreads(); + if (tid == 0) { + opmath_t _grad_w = cache_grad_offset[0], + _grad_h = cache_grad_offset[1], + _grad_a = cache_grad_mask[0]; + int sid = 2; + for (unsigned int tid = 1; tid < blockSize; ++tid) { + _grad_w += cache_grad_offset[sid]; + _grad_h += cache_grad_offset[sid + 1]; + _grad_a += cache_grad_mask[tid]; + sid += 2; + } + + *grad_offset = _grad_w; + *(grad_offset + 1) = _grad_h; + *grad_mask = _grad_a; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_mask += 1; + grad_offset += 2; + } + } + } +} + +template +__global__ void dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2( + const int num_kernels, const scalar_t *grad_col, const scalar_t *data_im, + const scalar_t *data_offset, const scalar_t *data_mask, const int kernel_h, + const int kernel_w, const int stride_h, const int stride_w, const int pad_h, + const int pad_w, const int dilation_h, const int dilation_w, + const int group, const int group_channels, const int height_in, + const int width_in, const int height_out, const int width_out, + const opmath_t offset_scale, opmath_t *grad_im, opmath_t *grad_offset, + opmath_t *grad_mask) { + CUDA_KERNEL_LOOP(index, num_kernels) { + __shared__ opmath_t cache_grad_offset[blockSize * 2]; + __shared__ opmath_t cache_grad_mask[blockSize]; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % group_channels; + _temp /= group_channels; + const int sampling_index = _temp; + const int g_col = _temp % group; + _temp /= group; + const int p0_w = ((dilation_w * (kernel_w - 1)) >> 1) - pad_w + + (_temp % width_out) * stride_w; + _temp /= width_out; + const int p0_h = ((dilation_h * (kernel_h - 1)) >> 1) - pad_h + + (_temp % height_out) * stride_h; + _temp /= height_out; + const int b_col = _temp; + + const opmath_t top_grad = grad_col[index]; + const int input_size = height_in * width_in; + const int kernel_size = kernel_h * kernel_w; + int data_weight_ptr = sampling_index * kernel_size; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_offset += grad_sampling_ptr << 1; + grad_mask += grad_sampling_ptr; + const int qid_stride = group * group_channels; + const int im_ptr_offset = b_col * input_size * qid_stride; + const scalar_t *data_im_ptr = data_im + im_ptr_offset; + opmath_t *grad_im_ptr = grad_im + im_ptr_offset; + const opmath_t p0_w_ = + p0_w - ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale; + const opmath_t p0_h_ = + p0_h - ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale; + for (int i = 0; i < kernel_w; ++i) { + for (int j = 0; j < kernel_h; ++j) { + const opmath_t offset_w = data_offset[data_loc_w_ptr]; + const opmath_t offset_h = data_offset[data_loc_w_ptr + 1]; + const opmath_t loc_w = + p0_w_ + (i * dilation_w + offset_w) * offset_scale; + const opmath_t loc_h = + p0_h_ + (j * dilation_h + offset_h) * offset_scale; + const opmath_t weight = data_mask[data_weight_ptr]; + *(cache_grad_offset + (threadIdx.x << 1)) = 0; + *(cache_grad_offset + ((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_mask + threadIdx.x) = 0; + if (loc_h > -1 && loc_w > -1 && loc_h < height_in && + loc_w < width_in) { + dcnv3_col2im_bilinear( + data_im_ptr, height_in, width_in, group, group_channels, + loc_h, loc_w, g_col, c_col, offset_scale, top_grad, + weight, grad_im_ptr, + cache_grad_offset + (threadIdx.x << 1), + cache_grad_mask + threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s = blockSize / 2; s > 0; s >>= 1) { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_mask[tid] += cache_grad_mask[tid + s]; + cache_grad_offset[xid1] += cache_grad_offset[xid2]; + cache_grad_offset[xid1 + 1] += + cache_grad_offset[xid2 + 1]; + } + __syncthreads(); + } + + if (tid == 0) { + *grad_offset = cache_grad_offset[0]; + *(grad_offset + 1) = cache_grad_offset[1]; + *grad_mask = cache_grad_mask[0]; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_mask += 1; + grad_offset += 2; + } + } + } +} + +template +__global__ void dcnv3_col2im_gpu_kernel_shm_reduce_v1( + const int num_kernels, const scalar_t *grad_col, const scalar_t *data_im, + const scalar_t *data_offset, const scalar_t *data_mask, const int kernel_h, + const int kernel_w, const int stride_h, const int stride_w, const int pad_h, + const int pad_w, const int dilation_h, const int dilation_w, + const int group, const int group_channels, const int height_in, + const int width_in, const int height_out, const int width_out, + const opmath_t offset_scale, opmath_t *grad_im, opmath_t *grad_offset, + opmath_t *grad_mask) { + CUDA_KERNEL_LOOP(index, num_kernels) { + extern __shared__ int _s[]; + opmath_t *cache_grad_offset = (opmath_t *)_s; + opmath_t *cache_grad_mask = cache_grad_offset + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % group_channels; + _temp /= group_channels; + const int sampling_index = _temp; + const int g_col = _temp % group; + _temp /= group; + const int p0_w = ((dilation_w * (kernel_w - 1)) >> 1) - pad_w + + (_temp % width_out) * stride_w; + _temp /= width_out; + const int p0_h = ((dilation_h * (kernel_h - 1)) >> 1) - pad_h + + (_temp % height_out) * stride_h; + _temp /= height_out; + const int b_col = _temp; + + const opmath_t top_grad = grad_col[index]; + const int input_size = height_in * width_in; + const int kernel_size = kernel_h * kernel_w; + int data_weight_ptr = sampling_index * kernel_size; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_offset += grad_sampling_ptr << 1; + grad_mask += grad_sampling_ptr; + const int qid_stride = group * group_channels; + const int im_ptr_offset = b_col * input_size * qid_stride; + const scalar_t *data_im_ptr = data_im + im_ptr_offset; + opmath_t *grad_im_ptr = grad_im + im_ptr_offset; + const opmath_t p0_w_ = + p0_w - ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale; + const opmath_t p0_h_ = + p0_h - ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale; + for (int i = 0; i < kernel_w; ++i) { + for (int j = 0; j < kernel_h; ++j) { + const opmath_t offset_w = data_offset[data_loc_w_ptr]; + const opmath_t offset_h = data_offset[data_loc_w_ptr + 1]; + const opmath_t loc_w = + p0_w_ + (i * dilation_w + offset_w) * offset_scale; + const opmath_t loc_h = + p0_h_ + (j * dilation_h + offset_h) * offset_scale; + const opmath_t weight = data_mask[data_weight_ptr]; + *(cache_grad_offset + (threadIdx.x << 1)) = 0; + *(cache_grad_offset + ((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_mask + threadIdx.x) = 0; + if (loc_h > -1 && loc_w > -1 && loc_h < height_in && + loc_w < width_in) { + dcnv3_col2im_bilinear( + data_im_ptr, height_in, width_in, group, group_channels, + loc_h, loc_w, g_col, c_col, offset_scale, top_grad, + weight, grad_im_ptr, + cache_grad_offset + (threadIdx.x << 1), + cache_grad_mask + threadIdx.x); + } + + __syncthreads(); + if (tid == 0) { + opmath_t _grad_w = cache_grad_offset[0], + _grad_h = cache_grad_offset[1], + _grad_a = cache_grad_mask[0]; + int sid = 2; + for (unsigned int tid = 1; tid < blockDim.x; ++tid) { + _grad_w += cache_grad_offset[sid]; + _grad_h += cache_grad_offset[sid + 1]; + _grad_a += cache_grad_mask[tid]; + sid += 2; + } + + *grad_offset = _grad_w; + *(grad_offset + 1) = _grad_h; + *grad_mask = _grad_a; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_mask += 1; + grad_offset += 2; + } + } + } +} + +template +__global__ void dcnv3_col2im_gpu_kernel_shm_reduce_v2( + const int num_kernels, const scalar_t *grad_col, const scalar_t *data_im, + const scalar_t *data_offset, const scalar_t *data_mask, const int kernel_h, + const int kernel_w, const int stride_h, const int stride_w, const int pad_h, + const int pad_w, const int dilation_h, const int dilation_w, + const int group, const int group_channels, const int height_in, + const int width_in, const int height_out, const int width_out, + const opmath_t offset_scale, opmath_t *grad_im, opmath_t *grad_offset, + opmath_t *grad_mask) { + CUDA_KERNEL_LOOP(index, num_kernels) { + extern __shared__ int _s[]; + opmath_t *cache_grad_offset = (opmath_t *)_s; + opmath_t *cache_grad_mask = cache_grad_offset + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % group_channels; + _temp /= group_channels; + const int sampling_index = _temp; + const int g_col = _temp % group; + _temp /= group; + const int p0_w = ((dilation_w * (kernel_w - 1)) >> 1) - pad_w + + (_temp % width_out) * stride_w; + _temp /= width_out; + const int p0_h = ((dilation_h * (kernel_h - 1)) >> 1) - pad_h + + (_temp % height_out) * stride_h; + _temp /= height_out; + const int b_col = _temp; + + const opmath_t top_grad = grad_col[index]; + const int input_size = height_in * width_in; + const int kernel_size = kernel_h * kernel_w; + int data_weight_ptr = sampling_index * kernel_size; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_offset += grad_sampling_ptr << 1; + grad_mask += grad_sampling_ptr; + const int qid_stride = group * group_channels; + const int im_ptr_offset = b_col * input_size * qid_stride; + const scalar_t *data_im_ptr = data_im + im_ptr_offset; + opmath_t *grad_im_ptr = grad_im + im_ptr_offset; + const opmath_t p0_w_ = + p0_w - ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale; + const opmath_t p0_h_ = + p0_h - ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale; + for (int i = 0; i < kernel_w; ++i) { + for (int j = 0; j < kernel_h; ++j) { + const opmath_t offset_w = data_offset[data_loc_w_ptr]; + const opmath_t offset_h = data_offset[data_loc_w_ptr + 1]; + const opmath_t loc_w = + p0_w_ + (i * dilation_w + offset_w) * offset_scale; + const opmath_t loc_h = + p0_h_ + (j * dilation_h + offset_h) * offset_scale; + const opmath_t weight = data_mask[data_weight_ptr]; + *(cache_grad_offset + (threadIdx.x << 1)) = 0; + *(cache_grad_offset + ((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_mask + threadIdx.x) = 0; + if (loc_h > -1 && loc_w > -1 && loc_h < height_in && + loc_w < width_in) { + dcnv3_col2im_bilinear( + data_im_ptr, height_in, width_in, group, group_channels, + loc_h, loc_w, g_col, c_col, offset_scale, top_grad, + weight, grad_im_ptr, + cache_grad_offset + (threadIdx.x << 1), + cache_grad_mask + threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s = blockDim.x / 2, spre = blockDim.x; s > 0; + s >>= 1, spre >>= 1) { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_mask[tid] += cache_grad_mask[tid + s]; + cache_grad_offset[xid1] += cache_grad_offset[xid2]; + cache_grad_offset[xid1 + 1] += + cache_grad_offset[xid2 + 1]; + if (tid + (s << 1) < spre) { + cache_grad_mask[tid] += + cache_grad_mask[tid + (s << 1)]; + cache_grad_offset[xid1] += + cache_grad_offset[xid2 + (s << 1)]; + cache_grad_offset[xid1 + 1] += + cache_grad_offset[xid2 + 1 + (s << 1)]; + } + } + __syncthreads(); + } + + if (tid == 0) { + *grad_offset = cache_grad_offset[0]; + *(grad_offset + 1) = cache_grad_offset[1]; + *grad_mask = cache_grad_mask[0]; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_mask += 1; + grad_offset += 2; + } + } + } +} + +template +__global__ void dcnv3_col2im_gpu_kernel_shm_reduce_v2_multi_blocks( + const int num_kernels, const scalar_t *grad_col, const scalar_t *data_im, + const scalar_t *data_offset, const scalar_t *data_mask, const int kernel_h, + const int kernel_w, const int stride_h, const int stride_w, const int pad_h, + const int pad_w, const int dilation_h, const int dilation_w, + const int group, const int group_channels, const int height_in, + const int width_in, const int height_out, const int width_out, + const opmath_t offset_scale, opmath_t *grad_im, opmath_t *grad_offset, + opmath_t *grad_mask) { + CUDA_KERNEL_LOOP(index, num_kernels) { + extern __shared__ int _s[]; + opmath_t *cache_grad_offset = (opmath_t *)_s; + opmath_t *cache_grad_mask = cache_grad_offset + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % group_channels; + _temp /= group_channels; + const int sampling_index = _temp; + const int g_col = _temp % group; + _temp /= group; + const int p0_w = ((dilation_w * (kernel_w - 1)) >> 1) - pad_w + + (_temp % width_out) * stride_w; + _temp /= width_out; + const int p0_h = ((dilation_h * (kernel_h - 1)) >> 1) - pad_h + + (_temp % height_out) * stride_h; + _temp /= height_out; + const int b_col = _temp; + + const opmath_t top_grad = grad_col[index]; + const int input_size = height_in * width_in; + const int kernel_size = kernel_h * kernel_w; + int data_weight_ptr = sampling_index * kernel_size; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_offset += grad_sampling_ptr << 1; + grad_mask += grad_sampling_ptr; + const int qid_stride = group * group_channels; + const int im_ptr_offset = b_col * input_size * qid_stride; + const scalar_t *data_im_ptr = data_im + im_ptr_offset; + opmath_t *grad_im_ptr = grad_im + im_ptr_offset; + const opmath_t p0_w_ = + p0_w - ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale; + const opmath_t p0_h_ = + p0_h - ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale; + for (int i = 0; i < kernel_w; ++i) { + for (int j = 0; j < kernel_h; ++j) { + const opmath_t offset_w = data_offset[data_loc_w_ptr]; + const opmath_t offset_h = data_offset[data_loc_w_ptr + 1]; + const opmath_t loc_w = + p0_w_ + (i * dilation_w + offset_w) * offset_scale; + const opmath_t loc_h = + p0_h_ + (j * dilation_h + offset_h) * offset_scale; + const opmath_t weight = data_mask[data_weight_ptr]; + *(cache_grad_offset + (threadIdx.x << 1)) = 0; + *(cache_grad_offset + ((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_mask + threadIdx.x) = 0; + if (loc_h > -1 && loc_w > -1 && loc_h < height_in && + loc_w < width_in) { + dcnv3_col2im_bilinear( + data_im_ptr, height_in, width_in, group, group_channels, + loc_h, loc_w, g_col, c_col, offset_scale, top_grad, + weight, grad_im_ptr, + cache_grad_offset + (threadIdx.x << 1), + cache_grad_mask + threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s = blockDim.x / 2, spre = blockDim.x; s > 0; + s >>= 1, spre >>= 1) { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_mask[tid] += cache_grad_mask[tid + s]; + cache_grad_offset[xid1] += cache_grad_offset[xid2]; + cache_grad_offset[xid1 + 1] += + cache_grad_offset[xid2 + 1]; + if (tid + (s << 1) < spre) { + cache_grad_mask[tid] += + cache_grad_mask[tid + (s << 1)]; + cache_grad_offset[xid1] += + cache_grad_offset[xid2 + (s << 1)]; + cache_grad_offset[xid1 + 1] += + cache_grad_offset[xid2 + 1 + (s << 1)]; + } + } + __syncthreads(); + } + + if (tid == 0) { + atomicAdd(grad_offset, cache_grad_offset[0]); + atomicAdd(grad_offset + 1, cache_grad_offset[1]); + atomicAdd(grad_mask, cache_grad_mask[0]); + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_mask += 1; + grad_offset += 2; + } + } + } +} + +template +__global__ void dcnv3_col2im_gpu_kernel_gm( + const int num_kernels, const scalar_t *grad_col, const scalar_t *data_im, + const scalar_t *data_offset, const scalar_t *data_mask, const int kernel_h, + const int kernel_w, const int stride_h, const int stride_w, const int pad_h, + const int pad_w, const int dilation_h, const int dilation_w, + const int group, const int group_channels, const int height_in, + const int width_in, const int height_out, const int width_out, + const opmath_t offset_scale, opmath_t *grad_im, opmath_t *grad_offset, + opmath_t *grad_mask) { + CUDA_KERNEL_LOOP(index, num_kernels) { + int _temp = index; + const int c_col = _temp % group_channels; + _temp /= group_channels; + const int sampling_index = _temp; + const int g_col = _temp % group; + _temp /= group; + const int p0_w = ((dilation_w * (kernel_w - 1)) >> 1) - pad_w + + (_temp % width_out) * stride_w; + _temp /= width_out; + const int p0_h = ((dilation_h * (kernel_h - 1)) >> 1) - pad_h + + (_temp % height_out) * stride_h; + _temp /= height_out; + const int b_col = _temp; + + const opmath_t top_grad = grad_col[index]; + const int input_size = height_in * width_in; + const int kernel_size = kernel_h * kernel_w; + int data_weight_ptr = sampling_index * kernel_size; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_offset += grad_sampling_ptr << 1; + grad_mask += grad_sampling_ptr; + const int qid_stride = group * group_channels; + const int im_ptr_offset = b_col * input_size * qid_stride; + const scalar_t *data_im_ptr = data_im + im_ptr_offset; + opmath_t *grad_im_ptr = grad_im + im_ptr_offset; + const opmath_t p0_w_ = + p0_w - ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale; + const opmath_t p0_h_ = + p0_h - ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale; + for (int i = 0; i < kernel_w; ++i) { + for (int j = 0; j < kernel_h; ++j) { + const opmath_t offset_w = data_offset[data_loc_w_ptr]; + const opmath_t offset_h = data_offset[data_loc_w_ptr + 1]; + const opmath_t loc_w = + p0_w_ + (i * dilation_w + offset_w) * offset_scale; + const opmath_t loc_h = + p0_h_ + (j * dilation_h + offset_h) * offset_scale; + const opmath_t weight = data_mask[data_weight_ptr]; + if (loc_h > -1 && loc_w > -1 && loc_h < height_in && + loc_w < width_in) { + dcnv3_col2im_bilinear_gm( + data_im_ptr, height_in, width_in, group, group_channels, + loc_h, loc_w, g_col, c_col, offset_scale, top_grad, + weight, grad_im_ptr, grad_offset, grad_mask); + } + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_mask += 1; + grad_offset += 2; + } + } + } +} + +template +void dcnv3_im2col_cuda(cudaStream_t stream, const scalar_t *data_im, + const scalar_t *data_offset, const scalar_t *data_mask, + scalar_t *data_col, const int kernel_h, + const int kernel_w, const int stride_h, + const int stride_w, const int pad_h, const int pad_w, + const int dilation_h, const int dilation_w, + const int group, const int group_channels, + const int batch_n, const int height_in, + const int width_in, const int height_out, + const int width_out, const opmath_t offset_scale) { + const int num_kernels = + batch_n * height_out * width_out * group * group_channels; + const int num_actual_kernels = + batch_n * height_out * width_out * group * group_channels; + const int num_threads = CUDA_NUM_THREADS; + dcnv3_im2col_gpu_kernel + <<>>(num_kernels, data_im, data_offset, data_mask, data_col, + kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, + dilation_h, dilation_w, group, group_channels, height_in, + width_in, height_out, width_out, offset_scale); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("error in dcnv3_im2col_cuda: %s\n", cudaGetErrorString(err)); + } +} + +template +void dcnv3_col2im_cuda( + cudaStream_t stream, const scalar_t *grad_col, const scalar_t *data_im, + const scalar_t *data_offset, const scalar_t *data_mask, const int kernel_h, + const int kernel_w, const int stride_h, const int stride_w, const int pad_h, + const int pad_w, const int dilation_h, const int dilation_w, + const int group, const int group_channels, const int batch_n, + const int height_in, const int width_in, const int height_out, + const int width_out, const opmath_t offset_scale, opmath_t *grad_im, + opmath_t *grad_offset, opmath_t *grad_mask) { + const int num_threads = + (group_channels > CUDA_NUM_THREADS) ? CUDA_NUM_THREADS : group_channels; + const int num_kernels = + batch_n * height_out * width_out * group * group_channels; + const int num_actual_kernels = + batch_n * height_out * width_out * group * group_channels; + if (group_channels > 1024) { + if ((group_channels & 1023) == 0) { + dcnv3_col2im_gpu_kernel_shm_reduce_v2_multi_blocks + <<>>( + num_kernels, grad_col, data_im, data_offset, data_mask, + kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, + dilation_h, dilation_w, group, group_channels, height_in, + width_in, height_out, width_out, offset_scale, grad_im, + grad_offset, grad_mask); + } else { + dcnv3_col2im_gpu_kernel_gm + <<>>(num_kernels, grad_col, data_im, data_offset, + data_mask, kernel_h, kernel_w, stride_h, stride_w, + pad_h, pad_w, dilation_h, dilation_w, group, + group_channels, height_in, width_in, height_out, + width_out, offset_scale, grad_im, grad_offset, + grad_mask); + } + } else { + switch (group_channels) { + case 1: + dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>(num_kernels, grad_col, data_im, data_offset, + data_mask, kernel_h, kernel_w, stride_h, stride_w, + pad_h, pad_w, dilation_h, dilation_w, group, + group_channels, height_in, width_in, height_out, + width_out, offset_scale, grad_im, grad_offset, + grad_mask); + break; + case 2: + dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>(num_kernels, grad_col, data_im, data_offset, + data_mask, kernel_h, kernel_w, stride_h, stride_w, + pad_h, pad_w, dilation_h, dilation_w, group, + group_channels, height_in, width_in, height_out, + width_out, offset_scale, grad_im, grad_offset, + grad_mask); + break; + case 4: + dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>(num_kernels, grad_col, data_im, data_offset, + data_mask, kernel_h, kernel_w, stride_h, stride_w, + pad_h, pad_w, dilation_h, dilation_w, group, + group_channels, height_in, width_in, height_out, + width_out, offset_scale, grad_im, grad_offset, + grad_mask); + break; + case 8: + dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>(num_kernels, grad_col, data_im, data_offset, + data_mask, kernel_h, kernel_w, stride_h, stride_w, + pad_h, pad_w, dilation_h, dilation_w, group, + group_channels, height_in, width_in, height_out, + width_out, offset_scale, grad_im, grad_offset, + grad_mask); + break; + case 16: + dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>(num_kernels, grad_col, data_im, data_offset, + data_mask, kernel_h, kernel_w, stride_h, stride_w, + pad_h, pad_w, dilation_h, dilation_w, group, + group_channels, height_in, width_in, height_out, + width_out, offset_scale, grad_im, grad_offset, + grad_mask); + break; + case 32: + dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>(num_kernels, grad_col, data_im, data_offset, + data_mask, kernel_h, kernel_w, stride_h, stride_w, + pad_h, pad_w, dilation_h, dilation_w, group, + group_channels, height_in, width_in, height_out, + width_out, offset_scale, grad_im, grad_offset, + grad_mask); + break; + case 64: + dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>(num_kernels, grad_col, data_im, data_offset, + data_mask, kernel_h, kernel_w, stride_h, stride_w, + pad_h, pad_w, dilation_h, dilation_w, group, + group_channels, height_in, width_in, height_out, + width_out, offset_scale, grad_im, grad_offset, + grad_mask); + break; + case 128: + dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>(num_kernels, grad_col, data_im, data_offset, + data_mask, kernel_h, kernel_w, stride_h, stride_w, + pad_h, pad_w, dilation_h, dilation_w, group, + group_channels, height_in, width_in, height_out, + width_out, offset_scale, grad_im, grad_offset, + grad_mask); + break; + case 256: + dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>(num_kernels, grad_col, data_im, data_offset, + data_mask, kernel_h, kernel_w, stride_h, stride_w, + pad_h, pad_w, dilation_h, dilation_w, group, + group_channels, height_in, width_in, height_out, + width_out, offset_scale, grad_im, grad_offset, + grad_mask); + break; + case 512: + dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>(num_kernels, grad_col, data_im, data_offset, + data_mask, kernel_h, kernel_w, stride_h, stride_w, + pad_h, pad_w, dilation_h, dilation_w, group, + group_channels, height_in, width_in, height_out, + width_out, offset_scale, grad_im, grad_offset, + grad_mask); + break; + case 1024: + dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>(num_kernels, grad_col, data_im, data_offset, + data_mask, kernel_h, kernel_w, stride_h, stride_w, + pad_h, pad_w, dilation_h, dilation_w, group, + group_channels, height_in, width_in, height_out, + width_out, offset_scale, grad_im, grad_offset, + grad_mask); + break; + default: + if (group_channels < 64) { + dcnv3_col2im_gpu_kernel_shm_reduce_v1 + <<>>( + num_kernels, grad_col, data_im, data_offset, data_mask, + kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, + dilation_h, dilation_w, group, group_channels, + height_in, width_in, height_out, width_out, + offset_scale, grad_im, grad_offset, grad_mask); + } else { + dcnv3_col2im_gpu_kernel_shm_reduce_v2 + <<>>( + num_kernels, grad_col, data_im, data_offset, data_mask, + kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, + dilation_h, dilation_w, group, group_channels, + height_in, width_in, height_out, width_out, + offset_scale, grad_im, grad_offset, grad_mask); + } + } + } + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("error in dcnv3_col2im_cuda: %s\n", cudaGetErrorString(err)); + } +} \ No newline at end of file diff --git a/oneformer/modeling/backbone/ops_dcnv3/src/dcnv3.h b/oneformer/modeling/backbone/ops_dcnv3/src/dcnv3.h new file mode 100644 index 0000000..029648e --- /dev/null +++ b/oneformer/modeling/backbone/ops_dcnv3/src/dcnv3.h @@ -0,0 +1,59 @@ +/*! +************************************************************************************************** +* InternImage +* Copyright (c) 2022 OpenGVLab +* Licensed under The MIT License [see LICENSE for details] +************************************************************************************************** +* Modified from +*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#pragma once + +#include "cpu/dcnv3_cpu.h" + +#ifdef WITH_CUDA +#include "cuda/dcnv3_cuda.h" +#endif + +at::Tensor dcnv3_forward(const at::Tensor &input, const at::Tensor &offset, + const at::Tensor &mask, const int kernel_h, + const int kernel_w, const int stride_h, + const int stride_w, const int pad_h, const int pad_w, + const int dilation_h, const int dilation_w, + const int group, const int group_channels, + const float offset_scale, const int im2col_step) { + if (input.type().is_cuda()) { +#ifdef WITH_CUDA + return dcnv3_cuda_forward(input, offset, mask, kernel_h, kernel_w, + stride_h, stride_w, pad_h, pad_w, dilation_h, + dilation_w, group, group_channels, + offset_scale, im2col_step); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + +std::vector +dcnv3_backward(const at::Tensor &input, const at::Tensor &offset, + const at::Tensor &mask, const int kernel_h, const int kernel_w, + const int stride_h, const int stride_w, const int pad_h, + const int pad_w, const int dilation_h, const int dilation_w, + const int group, const int group_channels, + const float offset_scale, const at::Tensor &grad_output, + const int im2col_step) { + if (input.type().is_cuda()) { +#ifdef WITH_CUDA + return dcnv3_cuda_backward(input, offset, mask, kernel_h, kernel_w, + stride_h, stride_w, pad_h, pad_w, dilation_h, + dilation_w, group, group_channels, + offset_scale, grad_output, im2col_step); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} diff --git a/oneformer/modeling/backbone/ops_dcnv3/src/vision.cpp b/oneformer/modeling/backbone/ops_dcnv3/src/vision.cpp new file mode 100644 index 0000000..1f7a908 --- /dev/null +++ b/oneformer/modeling/backbone/ops_dcnv3/src/vision.cpp @@ -0,0 +1,17 @@ +/*! +************************************************************************************************** +* InternImage +* Copyright (c) 2022 OpenGVLab +* Licensed under The MIT License [see LICENSE for details] +************************************************************************************************** +* Modified from +*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#include "dcnv3.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("dcnv3_forward", &dcnv3_forward, "dcnv3_forward"); + m.def("dcnv3_backward", &dcnv3_backward, "dcnv3_backward"); +} diff --git a/oneformer/modeling/backbone/ops_dcnv3/test.py b/oneformer/modeling/backbone/ops_dcnv3/test.py new file mode 100644 index 0000000..0277bef --- /dev/null +++ b/oneformer/modeling/backbone/ops_dcnv3/test.py @@ -0,0 +1,263 @@ +# -------------------------------------------------------- +# InternImage +# Copyright (c) 2022 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import time +import torch +import torch.nn as nn +import math +from torch.autograd import gradcheck + +from functions.dcnv3_func import DCNv3Function, dcnv3_core_pytorch + +H_in, W_in = 8, 8 +N, M, D = 2, 4, 16 +Kh, Kw = 3, 3 +P = Kh * Kw +offset_scale = 2.0 +pad = 1 +dilation = 1 +stride = 1 +H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1 +W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1 + +torch.manual_seed(3) + + +@torch.no_grad() +def check_forward_equal_with_pytorch_double(): + input = torch.rand(N, H_in, W_in, M*D).cuda() * 0.01 + offset = torch.rand(N, H_out, W_out, M*P*2).cuda() * 10 + mask = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5 + mask /= mask.sum(-1, keepdim=True) + mask = mask.reshape(N, H_out, W_out, M*P) + + output_pytorch = dcnv3_core_pytorch( + input.double(), + offset.double(), + mask.double(), + Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale).detach().cpu() + + im2col_step = 2 + output_cuda = DCNv3Function.apply( + input.double(), + offset.double(), + mask.double(), + Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale, + im2col_step).detach().cpu() + + fwdok = torch.allclose(output_cuda, output_pytorch) + max_abs_err = (output_cuda - output_pytorch).abs().max() + max_rel_err = ((output_cuda - output_pytorch).abs() / + output_pytorch.abs()).max() + print('>>> forward double') + print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') + + +@torch.no_grad() +def check_forward_equal_with_pytorch_float(): + input = torch.rand(N, H_in, W_in, M*D).cuda() * 0.01 + offset = torch.rand(N, H_out, W_out, M*P*2).cuda() * 10 + mask = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5 + mask /= mask.sum(-1, keepdim=True) + mask = mask.reshape(N, H_out, W_out, M*P) + + output_pytorch = dcnv3_core_pytorch( + input, + offset, + mask, + Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale).detach().cpu() + + im2col_step = 2 + output_cuda = DCNv3Function.apply( + input, + offset, + mask, + Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale, + im2col_step).detach().cpu() + + fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3) + max_abs_err = (output_cuda - output_pytorch).abs().max() + max_rel_err = ((output_cuda - output_pytorch).abs() / + output_pytorch.abs()).max() + print('>>> forward float') + print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') + + +def check_backward_equal_with_pytorch_double(channels=4, grad_input=True, grad_offset=True, grad_mask=True): + # H_in, W_in = 4, 4 + N = 2 + M = 2 + H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1 + W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1 + + D = channels + input0 = torch.rand(N, H_in, W_in, M*D).cuda() * 0.01 + offset0 = torch.rand(N, H_out, W_out, M*P*2).cuda() * 10 + mask0 = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5 + mask0 /= mask0.sum(-1, keepdim=True) + mask0 = mask0.reshape(N, H_out, W_out, M*P) + input0.requires_grad = grad_input + offset0.requires_grad = grad_offset + mask0.requires_grad = grad_mask + + output_pytorch = dcnv3_core_pytorch( + input0.double(), + offset0.double(), + mask0.double(), + Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale) + output_pytorch.sum().backward() + + input1 = input0.detach() + offset1 = offset0.detach() + mask1 = mask0.detach() + input1.requires_grad = grad_input + offset1.requires_grad = grad_offset + mask1.requires_grad = grad_mask + + im2col_step = 2 + output_cuda = DCNv3Function.apply( + input1.double(), + offset1.double(), + mask1.double(), + Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale, + im2col_step) + output_cuda.sum().backward() + + print(f'>>> backward double: channels {D}') + bwdok = torch.allclose(input0.grad, input1.grad, rtol=1e-2, atol=1e-3) + max_abs_err = (input0.grad - input1.grad).abs().max() + max_rel_err = ((input0.grad - input1.grad).abs() / + input0.grad.abs()).max() + print( + f'* {bwdok} input_grad check_backward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') + + bwdok = torch.allclose(offset0.grad, offset1.grad, rtol=1e-2, atol=1e-3) + max_abs_err = (offset0.grad - offset1.grad).abs().max() + max_rel_err = ((offset0.grad - offset1.grad).abs() / + offset0.grad.abs()).max() + print( + f'* {bwdok} offset_grad check_backward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') + + bwdok = torch.allclose(mask0.grad, mask1.grad, rtol=1e-2, atol=1e-3) + max_abs_err = (mask0.grad - mask1.grad).abs().max() + max_rel_err = ((mask0.grad - mask1.grad).abs() / + mask0.grad.abs()).max() + print( + f'* {bwdok} mask_grad check_backward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') + + +def check_backward_equal_with_pytorch_float(channels=4, grad_input=True, grad_offset=True, grad_mask=True): + # H_in, W_in = 4, 4 + N = 2 + M = 2 + H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1 + W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1 + + D = channels + input0 = torch.rand(N, H_in, W_in, M*D).cuda() * 0.01 + offset0 = torch.rand(N, H_out, W_out, M*P*2).cuda() * 10 + mask0 = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5 + mask0 /= mask0.sum(-1, keepdim=True) + mask0 = mask0.reshape(N, H_out, W_out, M*P) + input0.requires_grad = grad_input + offset0.requires_grad = grad_offset + mask0.requires_grad = grad_mask + + output_pytorch = dcnv3_core_pytorch( + input0, + offset0, + mask0, + Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale) + output_pytorch.sum().backward() + + input1 = input0.detach() + offset1 = offset0.detach() + mask1 = mask0.detach() + input1.requires_grad = grad_input + offset1.requires_grad = grad_offset + mask1.requires_grad = grad_mask + + im2col_step = 2 + output_cuda = DCNv3Function.apply( + input1, + offset1, + mask1, + Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale, + im2col_step) + output_cuda.sum().backward() + + print(f'>>> backward float: channels {D}') + bwdok = torch.allclose(input0.grad, input1.grad, rtol=1e-2, atol=1e-3) + max_abs_err = (input0.grad - input1.grad).abs().max() + max_rel_err = ((input0.grad - input1.grad).abs() / + input0.grad.abs()).max() + print( + f'* {bwdok} input_grad check_backward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') + + bwdok = torch.allclose(offset0.grad, offset1.grad, rtol=1e-2, atol=1e-3) + max_abs_err = (offset0.grad - offset1.grad).abs().max() + max_rel_err = ((offset0.grad - offset1.grad).abs() / + offset0.grad.abs()).max() + print( + f'* {bwdok} offset_grad check_backward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') + + bwdok = torch.allclose(mask0.grad, mask1.grad, rtol=1e-2, atol=1e-3) + max_abs_err = (mask0.grad - mask1.grad).abs().max() + max_rel_err = ((mask0.grad - mask1.grad).abs() / + mask0.grad.abs()).max() + print( + f'* {bwdok} mask_grad check_backward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') + + +@torch.no_grad() +def check_time_cost(im2col_step=128): + N = 512 + H_in, W_in = 64, 64 + H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1 + W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1 + + input = torch.rand(N, H_in, W_in, M*D).cuda() * 0.01 + offset = torch.rand(N, H_out, W_out, M*P*2).cuda() * 10 + mask = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5 + mask /= mask.sum(-1, keepdim=True) + mask = mask.reshape(N, H_out, W_out, M*P) + print( + f'>>> time cost: im2col_step {im2col_step}; input {input.shape}; points {P} ') + repeat = 100 + for i in range(repeat): + output_cuda = DCNv3Function.apply( + input, + offset, + mask, + Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, 1.0, + im2col_step) + torch.cuda.synchronize() + start = time.time() + for i in range(repeat): + output_cuda = DCNv3Function.apply( + input, + offset, + mask, + Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, 1.0, + im2col_step) + torch.cuda.synchronize() + print(f'foward time cost: {(time.time() - start) / repeat}') + + +if __name__ == '__main__': + check_forward_equal_with_pytorch_double() + check_forward_equal_with_pytorch_float() + for channels in [1, 16, 30, 32, 64, 71, 1025]: + check_backward_equal_with_pytorch_double(channels, True, True, True) + for channels in [1, 16, 30, 32, 64, 71, 1025]: + check_backward_equal_with_pytorch_float(channels, True, True, True) + for i in range(3): + im2col_step = 128 * (2 ** i) + check_time_cost(im2col_step) diff --git a/tools/README.md b/tools/README.md index 071532a..65d4de3 100644 --- a/tools/README.md +++ b/tools/README.md @@ -54,6 +54,19 @@ It's common to initialize from backbone models pre-trained on ImageNet classific +
+InternImage + +- [Official Repo](https://github.com/OpenGVLab/InternImage) +- `convert-pretrained-model-to-d2.py`: Tool to convert InternImage pre-trained weights for D2. + + ```bash + wget https://huggingface.co/OpenGVLab/InternImage/resolve/main/internimage_h_jointto22k_384.pth + python tools/convert-pretrained-model-to-d2.py internimage_h_jointto22k_384.pth internimage_h_jointto22k_384.pkl + ``` + +
+ ## Analyze Model - Tool to analyze model parameters, flops and speed. @@ -76,7 +89,7 @@ python tools/analyze_model.py --num-inputs 100 --tasks [flop speed] \ python tools/calc_throughput.py --dist-url 'tcp://127.0.0.1:50162' \ --num-gpus 8 \ --config-file configs/ade20k/swin/oneformer_swin_large_IN21k_384_bs16_160k.yaml \ -MODEL.WEIGHTS pretrain/swin_large_patch4_window12_384_22kto1k.pkl \ +MODEL.WEIGHTS swin_large_patch4_window12_384_22kto1k.pkl \ OUTPUT_DIR tp_out SOLVER.MAX_ITER 500 rm -rf tp_out diff --git a/train_net.py b/train_net.py index 1507b25..6858f02 100644 --- a/train_net.py +++ b/train_net.py @@ -60,6 +60,7 @@ add_swin_config, add_dinat_config, add_convnext_config, + add_internimage_config, ) from detectron2.utils.events import CommonMetricPrinter, JSONWriter @@ -391,6 +392,7 @@ def setup(args): add_swin_config(cfg) add_dinat_config(cfg) add_convnext_config(cfg) + add_internimage_config(cfg) add_oneformer_config(cfg) cfg.merge_from_file(args.config_file) cfg.merge_from_list(args.opts) From 97ff452808984420f47b17d8c13670ea5432e21f Mon Sep 17 00:00:00 2001 From: Jitesh Jain Date: Thu, 8 Jun 2023 23:09:57 +0530 Subject: [PATCH 2/9] :hammer: Fix dim_feedforward for 1024 emb_dim --- oneformer/modeling/pixel_decoder/msdeformattn.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/oneformer/modeling/pixel_decoder/msdeformattn.py b/oneformer/modeling/pixel_decoder/msdeformattn.py index 9b6d0aa..ba6aee0 100644 --- a/oneformer/modeling/pixel_decoder/msdeformattn.py +++ b/oneformer/modeling/pixel_decoder/msdeformattn.py @@ -305,9 +305,11 @@ def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]): ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM ret["norm"] = cfg.MODEL.SEM_SEG_HEAD.NORM ret["transformer_dropout"] = cfg.MODEL.ONE_FORMER.DROPOUT - ret["transformer_nheads"] = cfg.MODEL.ONE_FORMER.NHEADS - # ret["transformer_dim_feedforward"] = cfg.MODEL.ONE_FORMER.DIM_FEEDFORWARD - ret["transformer_dim_feedforward"] = 1024 # use 1024 for deformable transformer encoder + ret["transformer_nheads"] = cfg.MODEL.ONE_FORMER.NHEADS + if cfg.MODEL.SEM_SEG_HEAD.MASK_DIM != 256: + ret["transformer_dim_feedforward"] = cfg.MODEL.ONE_FORMER.DIM_FEEDFORWARD + else: + ret["transformer_dim_feedforward"] = 1024 ret[ "transformer_enc_layers" ] = cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS # a separate config From 95d604cb18bc104873ff1592f1f983d588cf7a15 Mon Sep 17 00:00:00 2001 From: Jitesh Jain Date: Sat, 10 Jun 2023 15:02:37 +0530 Subject: [PATCH 3/9] :zap: Update Readme --- README.md | 3 +- ...ern_image_huge_bs16_160k_896x896_1024.yaml | 7 ++--- ...ormer_intern_image_huge_bs16_90k_1024.yaml | 31 ++++++++++++++++--- ...mer_intern_image_huge_bs16_100ep_1024.yaml | 5 ++- demo/predictor.py | 6 ++-- 5 files changed, 37 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 21a37e8..2de87f0 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ Equal Contribution -[[`Project Page`](https://praeclarumjj3.github.io/oneformer/)] [[`arXiv`](https://arxiv.org/abs/2211.06220)] [[`pdf`](https://arxiv.org/pdf/2211.06220.pdf)] [[`BibTeX`](#4citation)] +[[`Project Page`](https://praeclarumjj3.github.io/oneformer/)] [[`arXiv`](https://arxiv.org/abs/2211.06220)] [[`pdf`](https://openaccess.thecvf.com/content/CVPR2023/papers/Jain_OneFormer_One_Transformer_To_Rule_Universal_Image_Segmentation_CVPR_2023_paper.pdf)] [[`Slides`](https://drive.google.com/file/d/12XhiOXD08_LwzBwosoLVk7i8D45V8YfW/view?usp=sharing)] [[`Poster`](https://drive.google.com/file/d/1-U3hCYVNVht26NM-zbE87p1V4idc5bCt/view?usp=sharing)] [[`BibTeX`](#4citation)] This repo contains the code for our paper **OneFormer: One Transformer to Rule Universal Image Segmentation**. @@ -38,6 +38,7 @@ This repo contains the code for our paper **OneFormer: One Transformer to Rule U ## News +- **[June 10, 2023]**: OneFormer achieves SOTA performance on ADE20K panoptic segmentation with **54.5 PQ** and on Cityscapes instance segmentation with **50.6 AP** scores. We release the corresponding models with InternImage-H backbone publicly! - **[February 27, 2023]**: OneFormer is accepted to CVPR 2023! - **[January 26, 2023]**: OneFormer sets new SOTA performance on the the Mapillary Vistas val (both panoptic & semantic segmentation) and Cityscapes test (panoptic segmentation) sets. We’ve released the checkpoints too! - **[January 19, 2023]**: OneFormer is now available as a part of the 🤗 **HuggingFace [transformers](https://huggingface.co/docs/transformers/main/en/model_doc/oneformer) library** and **[model hub](https://huggingface.co/models?filter=oneformer)**! 🚀 diff --git a/configs/ade20k/intern_image/oneformer_intern_image_huge_bs16_160k_896x896_1024.yaml b/configs/ade20k/intern_image/oneformer_intern_image_huge_bs16_160k_896x896_1024.yaml index 1c1317c..50d4e4a 100644 --- a/configs/ade20k/intern_image/oneformer_intern_image_huge_bs16_160k_896x896_1024.yaml +++ b/configs/ade20k/intern_image/oneformer_intern_image_huge_bs16_160k_896x896_1024.yaml @@ -3,10 +3,6 @@ MODEL: BACKBONE: NAME: "D2InternImage" SEM_SEG_HEAD: - NAME: "OneFormerHead" - IGNORE_VALUE: 255 - NUM_CLASSES: 150 - LOSS_WEIGHT: 1.0 CONVS_DIM: 1024 MASK_DIM: 1024 INTERNIMAGE: @@ -51,3 +47,6 @@ TEST: MIN_SIZES: [448, 678, 896, 1120, 1344, 1568] MAX_SIZE: 6272 FLIP: True +SOLVER: + IMS_PER_BATCH: 16 + BASE_LR: 0.00002 diff --git a/configs/cityscapes/intern_image/oneformer_intern_image_huge_bs16_90k_1024.yaml b/configs/cityscapes/intern_image/oneformer_intern_image_huge_bs16_90k_1024.yaml index 66013af..fe55ab5 100644 --- a/configs/cityscapes/intern_image/oneformer_intern_image_huge_bs16_90k_1024.yaml +++ b/configs/cityscapes/intern_image/oneformer_intern_image_huge_bs16_90k_1024.yaml @@ -3,10 +3,6 @@ MODEL: BACKBONE: NAME: "D2InternImage" SEM_SEG_HEAD: - NAME: "OneFormerHead" - IGNORE_VALUE: 255 - NUM_CLASSES: 150 - LOSS_WEIGHT: 1.0 CONVS_DIM: 1024 MASK_DIM: 1024 INTERNIMAGE: @@ -30,4 +26,29 @@ MODEL: CONTEXT_LENGTH: 77 N_CTX: 16 TEST: - DETECTIONS_PER_IMAGE: 250 \ No newline at end of file + DETECTIONS_PER_IMAGE: 250 +INPUT: + MIN_SIZE_TRAIN: !!python/object/apply:eval ["[int(x * 0.1 * 896) for x in range(5, 21)]"] + MIN_SIZE_TRAIN_SAMPLING: "choice" + MIN_SIZE_TEST: 896 + MAX_SIZE_TRAIN: 3584 + MAX_SIZE_TEST: 3584 + CROP: + ENABLED: True + TYPE: "absolute" + SIZE: (896, 896) + SINGLE_CATEGORY_MAX_AREA: 1.0 + COLOR_AUG_SSD: True + SIZE_DIVISIBILITY: 896 # used in dataset mapper + FORMAT: "RGB" +TEST: + DETECTIONS_PER_IMAGE: 250 + EVAL_PERIOD: 5000 + AUG: + ENABLED: False + MIN_SIZES: [448, 678, 896, 1120, 1344, 1568] + MAX_SIZE: 6272 + FLIP: True +SOLVER: + IMS_PER_BATCH: 16 + BASE_LR: 0.00002 \ No newline at end of file diff --git a/configs/coco/intern_image/oneformer_intern_image_huge_bs16_100ep_1024.yaml b/configs/coco/intern_image/oneformer_intern_image_huge_bs16_100ep_1024.yaml index e6c9ba1..748d71f 100644 --- a/configs/coco/intern_image/oneformer_intern_image_huge_bs16_100ep_1024.yaml +++ b/configs/coco/intern_image/oneformer_intern_image_huge_bs16_100ep_1024.yaml @@ -4,9 +4,6 @@ MODEL: NAME: "D2InternImage" SEM_SEG_HEAD: NAME: "OneFormerHead" - IGNORE_VALUE: 255 - NUM_CLASSES: 150 - LOSS_WEIGHT: 1.0 CONVS_DIM: 1024 MASK_DIM: 1024 INTERNIMAGE: @@ -30,6 +27,8 @@ MODEL: CONTEXT_LENGTH: 77 N_CTX: 16 SOLVER: + IMS_PER_BATCH: 16 + BASE_LR: 0.00002 STEPS: (655556, 735184) MAX_ITER: 737500 AMP: diff --git a/demo/predictor.py b/demo/predictor.py index 76e32bb..f012155 100644 --- a/demo/predictor.py +++ b/demo/predictor.py @@ -52,6 +52,8 @@ def run_on_image(self, image, task): # Convert image from OpenCV BGR format to Matplotlib RGB format. image = image[:, :, ::-1] vis_output = {} + + assert task in ['panoptic', 'semantic', 'instance'], "task should be one of 'panoptic', 'semantic', 'instance'" if task == 'panoptic': visualizer = Visualizer(image, metadata=self.metadata, instance_mode=ColorMode.IMAGE) @@ -61,14 +63,14 @@ def run_on_image(self, image, task): panoptic_seg.to(self.cpu_device), segments_info, alpha=0.7 ) - if task == 'panoptic' or task == 'semantic': + if task == 'semantic': visualizer = Visualizer(image, metadata=self.metadata, instance_mode=ColorMode.IMAGE_BW) predictions = self.predictor(image, task) vis_output['semantic_inference'] = visualizer.draw_sem_seg( predictions["sem_seg"].argmax(dim=0).to(self.cpu_device), alpha=0.7 ) - if task == 'panoptic' or task == 'instance': + if task == 'instance': visualizer = Visualizer(image, metadata=self.metadata, instance_mode=ColorMode.IMAGE_BW) predictions = self.predictor(image, task) instances = predictions["instances"].to(self.cpu_device) From 1b3efa1070d90a0c108182b7301dd2bec5187a48 Mon Sep 17 00:00:00 2001 From: Jitesh Jain Date: Fri, 16 Jun 2023 00:42:18 +0530 Subject: [PATCH 4/9] :zap: Add Mapillary Intern Image --- README.md | 1 + ...ern_image_huge_bs16_160k_896x896_1024.yaml | 2 +- ...ormer_intern_image_huge_bs16_90k_1024.yaml | 2 +- ...mer_intern_image_huge_bs16_100ep_1024.yaml | 2 +- ...oneformer_intern_image_huge_bs16_300k.yaml | 19 +++++++++++++++++++ tools/analyze_model.py | 2 -- 6 files changed, 23 insertions(+), 5 deletions(-) create mode 100644 configs/mapillary_vistas/intern_image/oneformer_intern_image_huge_bs16_300k.yaml diff --git a/README.md b/README.md index 2de87f0..3ec7baa 100644 --- a/README.md +++ b/README.md @@ -126,6 +126,7 @@ This repo contains the code for our paper **OneFormer: One Transformer to Rule U | OneFormer | Swin-L | 46.7 | 62.9 | 64.1 | 219M | [config](configs/mapillary_vistas/swin/oneformer_swin_large_bs16_300k.yaml) | [model](https://shi-labs.com/projects/oneformer/mapillary/250_16_swin_l_oneformer_mapillary_300k.pth) | | OneFormer | ConvNeXt-L | 47.9 | 63.2 | 63.8 | 220M | [config](configs/mapillary_vistas/convnext/oneformer_convnext_large_bs16_300k.yaml) | [model](https://shi-labs.com/projects/oneformer/mapillary/250_16_convnext_l_oneformer_mapillary_300k.pth) | | OneFormer | DiNAT-L | 47.8 | 64.0 | 64.9 | 223M | [config](configs/mapillary_vistas/dinat/oneformer_dinat_large_bs16_300k.yaml) | [model](https://shi-labs.com/projects/oneformer/mapillary/250_16_dinat_l_oneformer_mapillary_300k.pth) | +| OneFormer | InternImage-H | 51.7 | 65.5 | 66.5 | 1.10B | [config](configs/mapillary_vistas/intern_image/oneformer_intern_image_huge_bs16_300k.yaml) | [model](https://shi-labs.com/projects/oneformer/mapillary/250_16_intern_image_h_oneformer_mapillary_300k.pth) | ## Citation diff --git a/configs/ade20k/intern_image/oneformer_intern_image_huge_bs16_160k_896x896_1024.yaml b/configs/ade20k/intern_image/oneformer_intern_image_huge_bs16_160k_896x896_1024.yaml index 50d4e4a..43d62a5 100644 --- a/configs/ade20k/intern_image/oneformer_intern_image_huge_bs16_160k_896x896_1024.yaml +++ b/configs/ade20k/intern_image/oneformer_intern_image_huge_bs16_160k_896x896_1024.yaml @@ -49,4 +49,4 @@ TEST: FLIP: True SOLVER: IMS_PER_BATCH: 16 - BASE_LR: 0.00002 + BASE_LR: 0.00004 diff --git a/configs/cityscapes/intern_image/oneformer_intern_image_huge_bs16_90k_1024.yaml b/configs/cityscapes/intern_image/oneformer_intern_image_huge_bs16_90k_1024.yaml index fe55ab5..9ce3dc4 100644 --- a/configs/cityscapes/intern_image/oneformer_intern_image_huge_bs16_90k_1024.yaml +++ b/configs/cityscapes/intern_image/oneformer_intern_image_huge_bs16_90k_1024.yaml @@ -51,4 +51,4 @@ TEST: FLIP: True SOLVER: IMS_PER_BATCH: 16 - BASE_LR: 0.00002 \ No newline at end of file + BASE_LR: 0.00004 \ No newline at end of file diff --git a/configs/coco/intern_image/oneformer_intern_image_huge_bs16_100ep_1024.yaml b/configs/coco/intern_image/oneformer_intern_image_huge_bs16_100ep_1024.yaml index 748d71f..380360f 100644 --- a/configs/coco/intern_image/oneformer_intern_image_huge_bs16_100ep_1024.yaml +++ b/configs/coco/intern_image/oneformer_intern_image_huge_bs16_100ep_1024.yaml @@ -28,7 +28,7 @@ MODEL: N_CTX: 16 SOLVER: IMS_PER_BATCH: 16 - BASE_LR: 0.00002 + BASE_LR: 0.00005 STEPS: (655556, 735184) MAX_ITER: 737500 AMP: diff --git a/configs/mapillary_vistas/intern_image/oneformer_intern_image_huge_bs16_300k.yaml b/configs/mapillary_vistas/intern_image/oneformer_intern_image_huge_bs16_300k.yaml new file mode 100644 index 0000000..945b7ac --- /dev/null +++ b/configs/mapillary_vistas/intern_image/oneformer_intern_image_huge_bs16_300k.yaml @@ -0,0 +1,19 @@ +_BASE_: ../oneformer_R50_bs16_300k.yaml +MODEL: + BACKBONE: + NAME: "D2InternImage" + INTERNIMAGE: + CHANNELS: 320 + DEPTHS: [6, 6, 32, 6] + GROUPS: [10, 20, 40, 80] + WITH_CP: True + MLP_RATIO: 4.0 + DW_KERNEL_SIZE: 5 + LEVEL2_POST_NORM_BLOCK_IDS: [5, 11, 17, 23, 29] + WEIGHTS: "internimage_h_jointto22k_384.pkl" + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.120, 57.375] + ONE_FORMER: + NUM_OBJECT_QUERIES: 250 +TEST: + DETECTIONS_PER_IMAGE: 250 \ No newline at end of file diff --git a/tools/analyze_model.py b/tools/analyze_model.py index 4a57c20..b28daf0 100644 --- a/tools/analyze_model.py +++ b/tools/analyze_model.py @@ -29,7 +29,6 @@ add_common_config, add_swin_config, add_dinat_config, - add_beit_adapter_config, add_convnext_config, ) @@ -43,7 +42,6 @@ def setup(args): add_common_config(cfg) add_swin_config(cfg) add_dinat_config(cfg) - add_beit_adapter_config(cfg) add_oneformer_config(cfg) add_convnext_config(cfg) cfg.merge_from_file(args.config_file) From 8886a22e4aade03e01efad33cd27f1868788120e Mon Sep 17 00:00:00 2001 From: Jitesh Jain Date: Mon, 3 Jul 2023 13:28:04 +0530 Subject: [PATCH 5/9] :zap: Update Readme with COCO --- README.md | 4 ++- ...mer_intern_image_huge_bs16_100ep_1024.yaml | 2 +- ...rmer_intern_image_huge_bs16_300k_1024.yaml | 32 +++++++++++++++++++ demo/README.md | 2 +- 4 files changed, 37 insertions(+), 3 deletions(-) create mode 100644 configs/mapillary_vistas/intern_image/oneformer_intern_image_huge_bs16_300k_1024.yaml diff --git a/README.md b/README.md index 3ec7baa..879e80d 100644 --- a/README.md +++ b/README.md @@ -118,6 +118,7 @@ This repo contains the code for our paper **OneFormer: One Transformer to Rule U | :---:| :---: | :---: | :---: | :---: |:---:| :---:| :---: | :---: | :---: | | OneFormer | Swin-L | 57.9 | 64.4 | 48.0 | 49.0 | 67.4 | 219M | [config](configs/coco/swin/oneformer_swin_large_bs16_100ep.yaml) | [model](https://shi-labs.com/projects/oneformer/coco/150_16_swin_l_oneformer_coco_100ep.pth) | | OneFormer | DiNAT-L | 58.0 | 64.3 | 48.4 | 49.2 | 68.1 | 223M | [config](configs/coco/dinat/oneformer_dinat_large_bs16_100ep.yaml) | [model](https://shi-labs.com/projects/oneformer/coco/150_16_dinat_l_oneformer_coco_100ep.pth) | +| OneFormer | InternImage-H | 59.1 | 66.4 | 48.1 | 50.5 | 68.1 | 1.10B | [config](configs/coco/intern_image/oneformer_intern_image_huge_bs16_100ep.yaml) | [model](https://shi-labs.com/projects/oneformer/cityscapes/250_16_intern_image_h_oneformer_coco_100ep.pth) | ### Mapillary Vistas @@ -126,7 +127,8 @@ This repo contains the code for our paper **OneFormer: One Transformer to Rule U | OneFormer | Swin-L | 46.7 | 62.9 | 64.1 | 219M | [config](configs/mapillary_vistas/swin/oneformer_swin_large_bs16_300k.yaml) | [model](https://shi-labs.com/projects/oneformer/mapillary/250_16_swin_l_oneformer_mapillary_300k.pth) | | OneFormer | ConvNeXt-L | 47.9 | 63.2 | 63.8 | 220M | [config](configs/mapillary_vistas/convnext/oneformer_convnext_large_bs16_300k.yaml) | [model](https://shi-labs.com/projects/oneformer/mapillary/250_16_convnext_l_oneformer_mapillary_300k.pth) | | OneFormer | DiNAT-L | 47.8 | 64.0 | 64.9 | 223M | [config](configs/mapillary_vistas/dinat/oneformer_dinat_large_bs16_300k.yaml) | [model](https://shi-labs.com/projects/oneformer/mapillary/250_16_dinat_l_oneformer_mapillary_300k.pth) | -| OneFormer | InternImage-H | 51.7 | 65.5 | 66.5 | 1.10B | [config](configs/mapillary_vistas/intern_image/oneformer_intern_image_huge_bs16_300k.yaml) | [model](https://shi-labs.com/projects/oneformer/mapillary/250_16_intern_image_h_oneformer_mapillary_300k.pth) | +| OneFormer (emb_dim=256) | InternImage-H | 51.7 | 65.5 | 66.5 | 1.10B | [config](configs/mapillary_vistas/intern_image/oneformer_intern_image_huge_bs16_300k.yaml) | [model](https://shi-labs.com/projects/oneformer/mapillary/250_16_intern_image_h_oneformer_mapillary_300k.pth) | +| OneFormer (emb_dim=1024) | InternImage-H | 51.7 | 66.6 | 67.3 | 1.34B | [config](configs/mapillary_vistas/intern_image/oneformer_intern_image_huge_bs16_300k_1024.yaml) | [model](https://shi-labs.com/projects/oneformer/mapillary/250_16_intern_image_h_oneformer_mapillary_300k_1024.pth) | ## Citation diff --git a/configs/coco/intern_image/oneformer_intern_image_huge_bs16_100ep_1024.yaml b/configs/coco/intern_image/oneformer_intern_image_huge_bs16_100ep_1024.yaml index 380360f..68bdd1a 100644 --- a/configs/coco/intern_image/oneformer_intern_image_huge_bs16_100ep_1024.yaml +++ b/configs/coco/intern_image/oneformer_intern_image_huge_bs16_100ep_1024.yaml @@ -28,7 +28,7 @@ MODEL: N_CTX: 16 SOLVER: IMS_PER_BATCH: 16 - BASE_LR: 0.00005 + BASE_LR: 0.00004 STEPS: (655556, 735184) MAX_ITER: 737500 AMP: diff --git a/configs/mapillary_vistas/intern_image/oneformer_intern_image_huge_bs16_300k_1024.yaml b/configs/mapillary_vistas/intern_image/oneformer_intern_image_huge_bs16_300k_1024.yaml new file mode 100644 index 0000000..63867c4 --- /dev/null +++ b/configs/mapillary_vistas/intern_image/oneformer_intern_image_huge_bs16_300k_1024.yaml @@ -0,0 +1,32 @@ +_BASE_: ../oneformer_R50_bs16_300k.yaml +MODEL: + BACKBONE: + NAME: "D2InternImage" + SEM_SEG_HEAD: + CONVS_DIM: 1024 + MASK_DIM: 1024 + INTERNIMAGE: + CHANNELS: 320 + DEPTHS: [6, 6, 32, 6] + GROUPS: [10, 20, 40, 80] + WITH_CP: True + MLP_RATIO: 4.0 + DW_KERNEL_SIZE: 5 + LEVEL2_POST_NORM_BLOCK_IDS: [5, 11, 17, 23, 29] + WEIGHTS: "pretrain/internimage_h_jointto22k_384.pkl" + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.120, 57.375] + ONE_FORMER: + HIDDEN_DIM: 1024 + NUM_OBJECT_QUERIES: 250 + NHEADS: 32 + DIM_FEEDFORWARD: 4096 + TEXT_ENCODER: + WIDTH: 1024 + CONTEXT_LENGTH: 77 + N_CTX: 16 +TEST: + DETECTIONS_PER_IMAGE: 250 +SOLVER: + IMS_PER_BATCH: 16 + BASE_LR: 0.00002 \ No newline at end of file diff --git a/demo/README.md b/demo/README.md index dfa5f4e..c1fedcf 100644 --- a/demo/README.md +++ b/demo/README.md @@ -1,6 +1,6 @@ # OneFormer Demo -[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/SHI-Labs/FcF-Inpainting/blob/main/colab/FcF_Inpainting.ipynb) [![Huggingface space](https://img.shields.io/badge/🤗-Huggingface%20Space-cyan.svg)](https://huggingface.co/spaces/shi-labs/OneFormer) +[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/SHI-Labs/OneFormer/blob/main/colab/oneformer_colab.ipynb) [![Huggingface space](https://img.shields.io/badge/🤗-Huggingface%20Space-cyan.svg)](https://huggingface.co/spaces/shi-labs/OneFormer) - Pick a model and its config file from. For example, `configs/ade20k/swin/oneformer_swin_large_IN21k_384_bs16_160k.yaml`. - We provide `demo.py` that is able to demo builtin configs. From e38d63a7ff42bcd0f8c6428cec7035578d92eae0 Mon Sep 17 00:00:00 2001 From: Jitesh Jain Date: Thu, 6 Jul 2023 17:08:37 +0530 Subject: [PATCH 6/9] :zap: Add COCO checkpoint --- README.md | 7 +++--- ...rn_image_huge_bs16_160k_896x896_1024.yaml} | 4 +++- ...rmer_intern_image_huge_bs16_90k_1024.yaml} | 2 +- ...neformer_intern_image_huge_bs16_100ep.yaml | 24 ------------------- ...oneformer_intern_image_huge_bs16_300k.yaml | 19 --------------- 5 files changed, 7 insertions(+), 49 deletions(-) rename configs/ade20k/intern_image/{oneformer_intern_image_huge_bs16_160k_896x896_1024.yaml => coco_pretrain_oneformer_intern_image_huge_bs16_160k_896x896_1024.yaml} (92%) rename configs/cityscapes/intern_image/{oneformer_intern_image_huge_bs16_90k_1024.yaml => mapillary_pretrain_oneformer_intern_image_huge_bs16_90k_1024.yaml} (93%) delete mode 100644 configs/coco/intern_image/oneformer_intern_image_huge_bs16_100ep.yaml delete mode 100644 configs/mapillary_vistas/intern_image/oneformer_intern_image_huge_bs16_300k.yaml diff --git a/README.md b/README.md index 879e80d..19229e5 100644 --- a/README.md +++ b/README.md @@ -98,7 +98,7 @@ This repo contains the code for our paper **OneFormer: One Transformer to Rule U | OneFormer | DiNAT-L | 1280×1280 | 51.5 | 37.1 | 58.3 | 58.7 | 223M | [config](configs/ade20k/dinat/oneformer_dinat_large_bs16_160k_1280x1280.yaml) | [model](https://shi-labs.com/projects/oneformer/ade20k/1280x1280_250_16_dinat_l_oneformer_ade20k_160k.pth) | | OneFormer (COCO-Pretrained) | DiNAT-L | 1280×1280 | 53.4 | 40.2 | 58.4 | 58.8 | 223M | [config](configs/ade20k/dinat/coco_pretrain_oneformer_dinat_large_bs16_160k_1280x1280_coco_pretrain.yaml) | [model](https://shi-labs.com/projects/oneformer/ade20k/coco_pretrain_1280x1280_150_16_dinat_l_oneformer_ade20k_160k.pth) | [pretrained](https://shi-labs.com/projects/oneformer/coco/150_16_dinat_l_oneformer_coco_100ep.pth) | | OneFormer | ConvNeXt-XL | 640×640 | 50.1 | 36.3 | 57.4 | 58.8 | 372M | [config](configs/ade20k/convnext/oneformer_convnext_xlarge_bs16_160k.yaml) | [model](https://shi-labs.com/projects/oneformer/ade20k/250_16_convnext_xl_oneformer_ade20k_160k.pth) | -| OneFormer | InternImage-H | 896×896 | 54.5 | 40.2 | 60.4 | 60.8 | 1.10B | [config](configs/ade20k/intern_image/oneformer_intern_image_huge_bs16_160k_896x896.yaml) | [model](https://shi-labs.com/projects/oneformer/ade20k/896x896_250_16_intern_image_h_oneformer_ade20k_160k.pth) | +| OneFormer (emb_dim=256) | InternImage-H | 896×896 | 54.5 | 40.2 | 60.4 | 60.8 | 1.10B | [config](configs/ade20k/intern_image/oneformer_intern_image_huge_bs16_160k_896x896.yaml) | [model](https://shi-labs.com/projects/oneformer/ade20k/896x896_250_16_intern_image_h_oneformer_ade20k_160k.pth) | ### Cityscapes @@ -110,7 +110,7 @@ This repo contains the code for our paper **OneFormer: One Transformer to Rule U | OneFormer | DiNAT-L | 67.6 | 45.6 | 83.1 | 84.0 | 223M | [config](configs/cityscapes/dinat/oneformer_dinat_large_bs16_90k.yaml) | [model](https://shi-labs.com/projects/oneformer/cityscapes/250_16_dinat_l_oneformer_cityscapes_90k.pth) | | OneFormer | ConvNeXt-XL | 68.4 | 46.7 | 83.6 | 84.6 | 372M | [config](configs/cityscapes/convnext/oneformer_convnext_xlarge_bs16_90k.yaml) | [model](https://shi-labs.com/projects/oneformer/cityscapes/250_16_convnext_xl_oneformer_cityscapes_90k.pth) | | OneFormer (Mapillary Vistas-Pretrained) | ConvNeXt-XL | 69.7 | 48.9 | 84.5 | 85.8 | 372M | [config](configs/cityscapes/convnext/mapillary_pretrain_oneformer_convnext_xlarge_bs16_90k.yaml) | [model](https://shi-labs.com/projects/oneformer/cityscapes/mapillary_pretrain_250_16_convnext_xl_oneformer_cityscapes_90k.pth) | [pretrained](https://shi-labs.com/projects/oneformer/mapillary/mapillary_pretrain_250_16_convnext_xl_oneformer_mapillary_300k.pth) | -| OneFormer | InternImage-H | 70.6 | 50.6 | 85.1 | 85.7 | 1.10B | [config](configs/cityscapes/intern_image/oneformer_intern_image_huge_bs16_90k.yaml) | [model](https://shi-labs.com/projects/oneformer/cityscapes/250_16_intern_image_h_oneformer_cityscapes_90k.pth) | +| OneFormer (emb_dim=256) | InternImage-H | 70.6 | 50.6 | 85.1 | 85.7 | 1.10B | [config](configs/cityscapes/intern_image/oneformer_intern_image_huge_bs16_90k.yaml) | [model](https://shi-labs.com/projects/oneformer/cityscapes/250_16_intern_image_h_oneformer_cityscapes_90k.pth) | ### COCO @@ -118,7 +118,7 @@ This repo contains the code for our paper **OneFormer: One Transformer to Rule U | :---:| :---: | :---: | :---: | :---: |:---:| :---:| :---: | :---: | :---: | | OneFormer | Swin-L | 57.9 | 64.4 | 48.0 | 49.0 | 67.4 | 219M | [config](configs/coco/swin/oneformer_swin_large_bs16_100ep.yaml) | [model](https://shi-labs.com/projects/oneformer/coco/150_16_swin_l_oneformer_coco_100ep.pth) | | OneFormer | DiNAT-L | 58.0 | 64.3 | 48.4 | 49.2 | 68.1 | 223M | [config](configs/coco/dinat/oneformer_dinat_large_bs16_100ep.yaml) | [model](https://shi-labs.com/projects/oneformer/coco/150_16_dinat_l_oneformer_coco_100ep.pth) | -| OneFormer | InternImage-H | 59.1 | 66.4 | 48.1 | 50.5 | 68.1 | 1.10B | [config](configs/coco/intern_image/oneformer_intern_image_huge_bs16_100ep.yaml) | [model](https://shi-labs.com/projects/oneformer/cityscapes/250_16_intern_image_h_oneformer_coco_100ep.pth) | +| OneFormer (emb_dim=1024) | InternImage-H | 60.0 | 67.1 | 49.2 | 52.0 | 68.8 | 1.34B | [config](configs/coco/intern_image/oneformer_intern_image_huge_bs16_100ep_1024.yaml) | [model](https://shi-labs.com/projects/oneformer/coco/250_16_intern_image_h_oneformer_coco_100ep_1024.pth) | ### Mapillary Vistas @@ -127,7 +127,6 @@ This repo contains the code for our paper **OneFormer: One Transformer to Rule U | OneFormer | Swin-L | 46.7 | 62.9 | 64.1 | 219M | [config](configs/mapillary_vistas/swin/oneformer_swin_large_bs16_300k.yaml) | [model](https://shi-labs.com/projects/oneformer/mapillary/250_16_swin_l_oneformer_mapillary_300k.pth) | | OneFormer | ConvNeXt-L | 47.9 | 63.2 | 63.8 | 220M | [config](configs/mapillary_vistas/convnext/oneformer_convnext_large_bs16_300k.yaml) | [model](https://shi-labs.com/projects/oneformer/mapillary/250_16_convnext_l_oneformer_mapillary_300k.pth) | | OneFormer | DiNAT-L | 47.8 | 64.0 | 64.9 | 223M | [config](configs/mapillary_vistas/dinat/oneformer_dinat_large_bs16_300k.yaml) | [model](https://shi-labs.com/projects/oneformer/mapillary/250_16_dinat_l_oneformer_mapillary_300k.pth) | -| OneFormer (emb_dim=256) | InternImage-H | 51.7 | 65.5 | 66.5 | 1.10B | [config](configs/mapillary_vistas/intern_image/oneformer_intern_image_huge_bs16_300k.yaml) | [model](https://shi-labs.com/projects/oneformer/mapillary/250_16_intern_image_h_oneformer_mapillary_300k.pth) | | OneFormer (emb_dim=1024) | InternImage-H | 51.7 | 66.6 | 67.3 | 1.34B | [config](configs/mapillary_vistas/intern_image/oneformer_intern_image_huge_bs16_300k_1024.yaml) | [model](https://shi-labs.com/projects/oneformer/mapillary/250_16_intern_image_h_oneformer_mapillary_300k_1024.pth) | diff --git a/configs/ade20k/intern_image/oneformer_intern_image_huge_bs16_160k_896x896_1024.yaml b/configs/ade20k/intern_image/coco_pretrain_oneformer_intern_image_huge_bs16_160k_896x896_1024.yaml similarity index 92% rename from configs/ade20k/intern_image/oneformer_intern_image_huge_bs16_160k_896x896_1024.yaml rename to configs/ade20k/intern_image/coco_pretrain_oneformer_intern_image_huge_bs16_160k_896x896_1024.yaml index 43d62a5..5722101 100644 --- a/configs/ade20k/intern_image/oneformer_intern_image_huge_bs16_160k_896x896_1024.yaml +++ b/configs/ade20k/intern_image/coco_pretrain_oneformer_intern_image_huge_bs16_160k_896x896_1024.yaml @@ -13,7 +13,7 @@ MODEL: MLP_RATIO: 4.0 DW_KERNEL_SIZE: 5 LEVEL2_POST_NORM_BLOCK_IDS: [5, 11, 17, 23, 29] - WEIGHTS: "internimage_h_jointto22k_384.pkl" + WEIGHTS: "250_16_intern_image_h_oneformer_coco_100ep_1024.pth" PIXEL_MEAN: [123.675, 116.280, 103.530] PIXEL_STD: [58.395, 57.120, 57.375] ONE_FORMER: @@ -50,3 +50,5 @@ TEST: SOLVER: IMS_PER_BATCH: 16 BASE_LR: 0.00004 + AMP: + ENABLED: False diff --git a/configs/cityscapes/intern_image/oneformer_intern_image_huge_bs16_90k_1024.yaml b/configs/cityscapes/intern_image/mapillary_pretrain_oneformer_intern_image_huge_bs16_90k_1024.yaml similarity index 93% rename from configs/cityscapes/intern_image/oneformer_intern_image_huge_bs16_90k_1024.yaml rename to configs/cityscapes/intern_image/mapillary_pretrain_oneformer_intern_image_huge_bs16_90k_1024.yaml index 9ce3dc4..e848f43 100644 --- a/configs/cityscapes/intern_image/oneformer_intern_image_huge_bs16_90k_1024.yaml +++ b/configs/cityscapes/intern_image/mapillary_pretrain_oneformer_intern_image_huge_bs16_90k_1024.yaml @@ -13,7 +13,7 @@ MODEL: MLP_RATIO: 4.0 DW_KERNEL_SIZE: 5 LEVEL2_POST_NORM_BLOCK_IDS: [5, 11, 17, 23, 29] - WEIGHTS: "internimage_h_jointto22k_384.pkl" + WEIGHTS: "mapillary_pretrain_250_16_intern_image_h_oneformer_mapillary_300k_1024.pth" PIXEL_MEAN: [123.675, 116.280, 103.530] PIXEL_STD: [58.395, 57.120, 57.375] ONE_FORMER: diff --git a/configs/coco/intern_image/oneformer_intern_image_huge_bs16_100ep.yaml b/configs/coco/intern_image/oneformer_intern_image_huge_bs16_100ep.yaml deleted file mode 100644 index c7e00e7..0000000 --- a/configs/coco/intern_image/oneformer_intern_image_huge_bs16_100ep.yaml +++ /dev/null @@ -1,24 +0,0 @@ -_BASE_: ../oneformer_R50_bs16_50ep.yaml -MODEL: - BACKBONE: - NAME: "D2InternImage" - INTERNIMAGE: - CHANNELS: 320 - DEPTHS: [6, 6, 32, 6] - GROUPS: [10, 20, 40, 80] - WITH_CP: True - MLP_RATIO: 4.0 - DW_KERNEL_SIZE: 5 - LEVEL2_POST_NORM_BLOCK_IDS: [5, 11, 17, 23, 29] - WEIGHTS: "internimage_h_jointto22k_384.pkl" - PIXEL_MEAN: [123.675, 116.280, 103.530] - PIXEL_STD: [58.395, 57.120, 57.375] - ONE_FORMER: - NUM_OBJECT_QUERIES: 250 -SOLVER: - STEPS: (655556, 735184) - MAX_ITER: 737500 - AMP: - ENABLED: False -TEST: - DETECTIONS_PER_IMAGE: 250 diff --git a/configs/mapillary_vistas/intern_image/oneformer_intern_image_huge_bs16_300k.yaml b/configs/mapillary_vistas/intern_image/oneformer_intern_image_huge_bs16_300k.yaml deleted file mode 100644 index 945b7ac..0000000 --- a/configs/mapillary_vistas/intern_image/oneformer_intern_image_huge_bs16_300k.yaml +++ /dev/null @@ -1,19 +0,0 @@ -_BASE_: ../oneformer_R50_bs16_300k.yaml -MODEL: - BACKBONE: - NAME: "D2InternImage" - INTERNIMAGE: - CHANNELS: 320 - DEPTHS: [6, 6, 32, 6] - GROUPS: [10, 20, 40, 80] - WITH_CP: True - MLP_RATIO: 4.0 - DW_KERNEL_SIZE: 5 - LEVEL2_POST_NORM_BLOCK_IDS: [5, 11, 17, 23, 29] - WEIGHTS: "internimage_h_jointto22k_384.pkl" - PIXEL_MEAN: [123.675, 116.280, 103.530] - PIXEL_STD: [58.395, 57.120, 57.375] - ONE_FORMER: - NUM_OBJECT_QUERIES: 250 -TEST: - DETECTIONS_PER_IMAGE: 250 \ No newline at end of file From 2d86dc238670ba0f114e88e513dd40c1a59b53ac Mon Sep 17 00:00:00 2001 From: Jitesh Jain Date: Thu, 6 Jul 2023 17:15:07 +0530 Subject: [PATCH 7/9] :zap: Update News --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 19229e5..a9ffe60 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,7 @@ This repo contains the code for our paper **OneFormer: One Transformer to Rule U ## News -- **[June 10, 2023]**: OneFormer achieves SOTA performance on ADE20K panoptic segmentation with **54.5 PQ** and on Cityscapes instance segmentation with **50.6 AP** scores. We release the corresponding models with InternImage-H backbone publicly! +- **[July 6, 2023]**: OneFormer achieves SOTA performance on COCO panoptic segmentation with **60.0 PQ**, on ADE20K panoptic segmentation with **54.5 PQ** and on Cityscapes instance segmentation with **50.6 AP** scores. We release the corresponding models with InternImage-H backbone publicly! - **[February 27, 2023]**: OneFormer is accepted to CVPR 2023! - **[January 26, 2023]**: OneFormer sets new SOTA performance on the the Mapillary Vistas val (both panoptic & semantic segmentation) and Cityscapes test (panoptic segmentation) sets. We’ve released the checkpoints too! - **[January 19, 2023]**: OneFormer is now available as a part of the 🤗 **HuggingFace [transformers](https://huggingface.co/docs/transformers/main/en/model_doc/oneformer) library** and **[model hub](https://huggingface.co/models?filter=oneformer)**! 🚀 From c33912f32f74a8a6b67b6418fb83797467591075 Mon Sep 17 00:00:00 2001 From: Jitesh Jain Date: Wed, 30 Aug 2023 17:58:51 -0400 Subject: [PATCH 8/9] :hammer: Fix mapillary 2024 model --- README.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index a9ffe60..ee1f4dc 100644 --- a/README.md +++ b/README.md @@ -99,6 +99,7 @@ This repo contains the code for our paper **OneFormer: One Transformer to Rule U | OneFormer (COCO-Pretrained) | DiNAT-L | 1280×1280 | 53.4 | 40.2 | 58.4 | 58.8 | 223M | [config](configs/ade20k/dinat/coco_pretrain_oneformer_dinat_large_bs16_160k_1280x1280_coco_pretrain.yaml) | [model](https://shi-labs.com/projects/oneformer/ade20k/coco_pretrain_1280x1280_150_16_dinat_l_oneformer_ade20k_160k.pth) | [pretrained](https://shi-labs.com/projects/oneformer/coco/150_16_dinat_l_oneformer_coco_100ep.pth) | | OneFormer | ConvNeXt-XL | 640×640 | 50.1 | 36.3 | 57.4 | 58.8 | 372M | [config](configs/ade20k/convnext/oneformer_convnext_xlarge_bs16_160k.yaml) | [model](https://shi-labs.com/projects/oneformer/ade20k/250_16_convnext_xl_oneformer_ade20k_160k.pth) | | OneFormer (emb_dim=256) | InternImage-H | 896×896 | 54.5 | 40.2 | 60.4 | 60.8 | 1.10B | [config](configs/ade20k/intern_image/oneformer_intern_image_huge_bs16_160k_896x896.yaml) | [model](https://shi-labs.com/projects/oneformer/ade20k/896x896_250_16_intern_image_h_oneformer_ade20k_160k.pth) | +| OneFormer (emb_dim=1024, COCO-Pretrained) | InternImage-H | 896×896 | 55.5 | 44.2 | 60.7 | 60.7 | 1.35B | [config](configs/ade20k/coco_pretrain_intern_image/oneformer_intern_image_huge_bs16_160k_896x896.yaml) | [model](https://shi-labs.com/projects/oneformer/ade20k/coco_pretrain_896x896_250_16_intern_image_h_oneformer_ade20k_160k.pth) | ### Cityscapes @@ -118,7 +119,7 @@ This repo contains the code for our paper **OneFormer: One Transformer to Rule U | :---:| :---: | :---: | :---: | :---: |:---:| :---:| :---: | :---: | :---: | | OneFormer | Swin-L | 57.9 | 64.4 | 48.0 | 49.0 | 67.4 | 219M | [config](configs/coco/swin/oneformer_swin_large_bs16_100ep.yaml) | [model](https://shi-labs.com/projects/oneformer/coco/150_16_swin_l_oneformer_coco_100ep.pth) | | OneFormer | DiNAT-L | 58.0 | 64.3 | 48.4 | 49.2 | 68.1 | 223M | [config](configs/coco/dinat/oneformer_dinat_large_bs16_100ep.yaml) | [model](https://shi-labs.com/projects/oneformer/coco/150_16_dinat_l_oneformer_coco_100ep.pth) | -| OneFormer (emb_dim=1024) | InternImage-H | 60.0 | 67.1 | 49.2 | 52.0 | 68.8 | 1.34B | [config](configs/coco/intern_image/oneformer_intern_image_huge_bs16_100ep_1024.yaml) | [model](https://shi-labs.com/projects/oneformer/coco/250_16_intern_image_h_oneformer_coco_100ep_1024.pth) | +| OneFormer (emb_dim=1024) | InternImage-H | 60.0 | 67.1 | 49.2 | 52.0 | 68.8 | 1.35B | [config](configs/coco/intern_image/oneformer_intern_image_huge_bs16_100ep_1024.yaml) | [model](https://shi-labs.com/projects/oneformer/coco/250_16_intern_image_h_oneformer_coco_100ep_1024.pth) | ### Mapillary Vistas @@ -127,7 +128,7 @@ This repo contains the code for our paper **OneFormer: One Transformer to Rule U | OneFormer | Swin-L | 46.7 | 62.9 | 64.1 | 219M | [config](configs/mapillary_vistas/swin/oneformer_swin_large_bs16_300k.yaml) | [model](https://shi-labs.com/projects/oneformer/mapillary/250_16_swin_l_oneformer_mapillary_300k.pth) | | OneFormer | ConvNeXt-L | 47.9 | 63.2 | 63.8 | 220M | [config](configs/mapillary_vistas/convnext/oneformer_convnext_large_bs16_300k.yaml) | [model](https://shi-labs.com/projects/oneformer/mapillary/250_16_convnext_l_oneformer_mapillary_300k.pth) | | OneFormer | DiNAT-L | 47.8 | 64.0 | 64.9 | 223M | [config](configs/mapillary_vistas/dinat/oneformer_dinat_large_bs16_300k.yaml) | [model](https://shi-labs.com/projects/oneformer/mapillary/250_16_dinat_l_oneformer_mapillary_300k.pth) | -| OneFormer (emb_dim=1024) | InternImage-H | 51.7 | 66.6 | 67.3 | 1.34B | [config](configs/mapillary_vistas/intern_image/oneformer_intern_image_huge_bs16_300k_1024.yaml) | [model](https://shi-labs.com/projects/oneformer/mapillary/250_16_intern_image_h_oneformer_mapillary_300k_1024.pth) | +| OneFormer (emb_dim=1024) | InternImage-H | 52.9 | 67.3 | 67.5 | 1.35B | [config](configs/mapillary_vistas/intern_image/oneformer_intern_image_huge_bs16_300k_1024.yaml) | [model](https://shi-labs.com/projects/oneformer/mapillary/250_16_intern_image_h_oneformer_mapillary_300k_1024.pth) | ## Citation From 458dec8de10b65b6dde11e0ebe82375fdd18c20c Mon Sep 17 00:00:00 2001 From: Jitesh Jain Date: Wed, 2 Oct 2024 23:19:15 -0400 Subject: [PATCH 9/9] :hammer: Fix config in demo --- demo/demo.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/demo/demo.py b/demo/demo.py index 3c4d0b0..9dc728e 100644 --- a/demo/demo.py +++ b/demo/demo.py @@ -29,6 +29,7 @@ add_swin_config, add_dinat_config, add_convnext_config, + add_internimage_config ) from predictor import VisualizationDemo @@ -44,6 +45,7 @@ def setup_cfg(args): add_dinat_config(cfg) add_convnext_config(cfg) add_oneformer_config(cfg) + add_internimage_config(cfg) cfg.merge_from_file(args.config_file) cfg.merge_from_list(args.opts) cfg.freeze()