Skip to content

Commit

Permalink
add livecell dino
Browse files Browse the repository at this point in the history
  • Loading branch information
okotaku committed Feb 9, 2023
1 parent f6a629f commit ff0e6fe
Show file tree
Hide file tree
Showing 6 changed files with 189 additions and 0 deletions.
114 changes: 114 additions & 0 deletions configs/_base_/datasets/livecell/livecell_detection_dino.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/livecell/'
metainfo = dict(
classes=[
'shsy5y', 'a172', 'bt474', 'bv2', 'huh7', 'mcf7', 'skov3', 'skbr3'
],
palette=[(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230),
(106, 0, 228), (0, 60, 100), (0, 80, 100), (0, 0, 70)])
file_client_args = dict(backend='disk')

multi = 1536 / 1333
plus = 400

train_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='RandomFlip', prob=0.5),
dict(
type='RandomChoice',
transforms=[
[
dict(
type='RandomChoiceResize',
scales=[
(int((s[0] + plus) * multi), int(s[1] * multi))
for s in [(480, 1333), (512, 1333), (
544, 1333), (576, 1333), (608, 1333), (
640, 1333), (672, 1333), (
704, 1333), (736, 1333), (768,
1333), (800,
1333)]
],
keep_ratio=True)
],
[
dict(
type='RandomChoiceResize',
# The radio of all image in train dataset < 7
# follow the original implement
scales=[(int((s[0] + plus) * multi), int(s[1] * multi))
for s in [(400, 4200), (500, 4200), (600, 4200)]],
keep_ratio=True),
dict(
type='RandomCrop',
crop_type='absolute_range',
crop_size=(int(
(384 + plus * 600 / 1333) * multi), int(600 * multi)),
allow_negative_crop=True),
dict(
type='RandomChoiceResize',
scales=[
(int((s[0] + plus) * multi), int(s[1] * multi))
for s in [(480, 1333), (512, 1333), (
544, 1333), (576, 1333), (608, 1333), (
640, 1333), (672, 1333), (
704, 1333), (736, 1333), (768,
1333), (800,
1333)]
],
keep_ratio=True)
]
]),
dict(type='PackDetInputs')
]
test_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(
type='Resize',
scale=(int(1333 * multi), int((800 + plus) * multi)),
keep_ratio=True),
# If you don't have a gt annotation, delete the pipeline
dict(type='LoadAnnotations', with_bbox=True),
dict(
type='PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor'))
]
train_dataloader = dict(
batch_size=1,
num_workers=0,
persistent_workers=False,
sampler=dict(type='DefaultSampler', shuffle=True),
batch_sampler=dict(type='AspectRatioBatchSampler'),
dataset=dict(
type=dataset_type,
data_root=data_root,
metainfo=metainfo,
ann_file='livecell_coco_train_8class.json',
data_prefix=dict(img='images/livecell_train_val_images/'),
filter_cfg=dict(filter_empty_gt=True, min_size=32),
pipeline=train_pipeline))
val_dataloader = dict(
batch_size=1,
num_workers=0,
persistent_workers=False,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
metainfo=metainfo,
data_root=data_root,
ann_file='livecell_coco_val_8class.json',
data_prefix=dict(img='images/livecell_train_val_images/'),
test_mode=True,
pipeline=test_pipeline))
test_dataloader = val_dataloader

val_evaluator = dict(
type='CocoFastMetric',
ann_file=data_root + 'livecell_coco_val_8class.json',
metric='bbox',
proposal_nums=(100, 300, 3000))
test_evaluator = val_evaluator
33 changes: 33 additions & 0 deletions configs/_base_/schedules/dino/dino_36e.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# optimizer
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(
type='AdamW',
lr=0.0001, # 0.0002 for DeformDETR
weight_decay=0.0001),
clip_grad=dict(max_norm=0.1, norm_type=2),
paramwise_cfg=dict(custom_keys={'backbone': dict(lr_mult=0.1)})
) # custom_keys contains sampling_offsets and reference_points in DeformDETR # noqa

# learning policy
max_epochs = 36
train_cfg = dict(
type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=1)

val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

param_scheduler = [
dict(
type='MultiStepLR',
begin=0,
end=max_epochs,
by_epoch=True,
milestones=[30],
gamma=0.1)
]

# NOTE: `auto_scale_lr` is for automatically scaling LR,
# USER SHOULD NOT CHANGE ITS VALUES.
# base_batch_size = (8 GPUs) x (2 samples per GPU)
auto_scale_lr = dict(enable=True, base_batch_size=16)
11 changes: 11 additions & 0 deletions configs/projects/livecell/dino/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# DINO

[model page](https://github.com/open-mmlab/mmdetection/blob/dev-3.x/configs/dino/README.md)

## Results and Models

#### Box Results

| Backbone | box AP | Config | Download |
| :----------------------: | :----: | :-----------------------------------: | :----------------------------------------------------------------------------------------------------------------------------: |
| dino-4scale_r50_livecell | 52.3 | [config](dino-4scale_r50_livecell.py) | [model](https://github.com/okotaku/dethub-weights/releases/download/v0.1.1dino-livecell/dino-4scale_r50_livecell-535173b5.pth) |
17 changes: 17 additions & 0 deletions configs/projects/livecell/dino/dino-4scale_r50_livecell.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
_base_ = [
'mmdet::_base_/default_runtime.py',
'../../../_base_/models/dino/dino-4scale_r50.py',
'../../../_base_/datasets/livecell/livecell_detection_dino.py',
'../../../_base_/schedules/dino/dino_36e.py'
]
custom_imports = dict(imports=['dethub'], allow_failed_imports=False)
fp16 = dict(loss_scale=512.)

# model settings
num_classes = 8
model = dict(
num_queries=3000,
bbox_head=dict(num_classes=num_classes),
test_cfg=dict(max_per_img=3000))

load_from = 'https://download.openmmlab.com/mmdetection/v3.0/dino/dino-4scale_r50_8xb2-12e_coco/dino-4scale_r50_8xb2-12e_coco_20221202_182705-55b2bba2.pth' # noqa
13 changes: 13 additions & 0 deletions configs/projects/livecell/dino/metafile.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
Collections:
- Name: DINO

Models:
- Name: dino-4scale_r50_livecell
In Collection: DINO
Config: configs/projects/livecell/dino/dino-4scale_r50_livecell.py
Results:
- Task: Object Detection
Dataset: livecell
Metrics:
box AP: 52.3
Weights: https://github.com/okotaku/dethub-weights/releases/download/v0.1.1dino-livecell/dino-4scale_r50_livecell-535173b5.pth
1 change: 1 addition & 0 deletions model-index.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Import:
- configs/projects/findfallenpeople/yolox/metafile.yml
- configs/projects/gbr_cots/yolox/metafile.yml
- configs/projects/le2i/yolox/metafile.yml
- configs/projects/livecell/dino/metafile.yml
- configs/projects/livecell/yolox/metafile.yml
- configs/projects/lvis/dino/metafile.yml
- configs/projects/lvis/yolox/metafile.yml
Expand Down

0 comments on commit ff0e6fe

Please sign in to comment.