model_ensemble.py
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os

import mmcv
import numpy as np
import torch
from mmcv.parallel import MMDataParallel
from mmcv.parallel.scatter_gather import scatter_kwargs
from mmcv.runner import load_checkpoint, wrap_fp16_model
from PIL import Image

from mmseg.datasets import build_dataloader, build_dataset
from mmseg.models import build_segmentor
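
# Ensemble inference: each --config/--checkpoint pair is loaded as a separate
# segmentor, the per-image logits of all models are summed, and the argmax of
# the summed logits is written out as a binary PNG mask per test image.
# Example invocation (the paths here are illustrative placeholders):
#   python model_ensemble.py \
#       --config cfg_a.py cfg_b.py \
#       --checkpoint ckpt_a.pth ckpt_b.pth \
#       --out results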


@torch.no_grad()
def main(args):
    models = []
    gpu_ids = args.gpus
    configs = args.config
    ckpts = args.checkpoint
    weights = args.weights

    cfg = mmcv.Config.fromfile(configs[0])
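
    # With --aug-test, multi-scale (0.8x / 1.0x / 1.25x) and horizontal-flip
    # test-time augmentation is enabled in the shared test pipeline; otherwise
    # plain single-scale inference is used.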
    if args.aug_test:
        cfg.data.test.pipeline[1].img_ratios = [
            0.8, 1.0, 1.25
        ]
        cfg.data.test.pipeline[1].flip = True
    else:
        cfg.data.test.pipeline[1].img_ratios = [1.0]
        cfg.data.test.pipeline[1].flip = False
    cfg.data.test.test_mode = True

    torch.backends.cudnn.benchmark = True

    # build the dataloader
    dataset = build_dataset(cfg.data.test)
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=1,
        workers_per_gpu=4,
        dist=False,
        shuffle=False,
    )
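
    # Build every segmentor from its own config/checkpoint pair and spread the
    # models across the GPUs listed in --gpus (round-robin by model index).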
    for idx, (config, ckpt) in enumerate(zip(configs, ckpts)):
        cfg = mmcv.Config.fromfile(config)
        cfg.model.pretrained = None
        cfg.data.test.test_mode = True
        model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg'))
        if cfg.get('fp16', None):
            wrap_fp16_model(model)
        load_checkpoint(model, ckpt, map_location='cpu')
        torch.cuda.empty_cache()
        tmpdir = args.out
        mmcv.mkdir_or_exist(tmpdir)
        model = MMDataParallel(model, device_ids=[gpu_ids[idx % len(gpu_ids)]])
        model.eval()
        models.append(model)
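
    # Iterate over the test set once; every model runs on the same sample, the
    # logits are summed, and the argmax of the sum is binarized (any foreground
    # class -> 255) and saved as a PNG named after the input image.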
    dataset = data_loader.dataset
    prog_bar = mmcv.ProgressBar(len(dataset))
    loader_indices = data_loader.batch_sampler
    for batch_indices, data in zip(loader_indices, data_loader):
        result = []
        for model in models:
            x, _ = scatter_kwargs(
                inputs=data, kwargs=None, target_gpus=model.device_ids)
            if args.aug_test:
                logits = model.module.aug_test_logits(**x[0])
            else:
                logits = model.module.simple_test_logits(**x[0])
            result.append(logits)
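
        # Note: args.weights is parsed but not applied here, so all models are
        # summed with equal weight. A weighted ensemble (a sketch, assuming one
        # float weight per model) would accumulate
        # `result_logits += weight * logit` instead.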
        result_logits = 0
        for logit in result:
            result_logits += logit
        pred = result_logits.argmax(axis=1).squeeze()
        img_info = dataset.img_infos[batch_indices[0]]
        file_name = os.path.join(tmpdir, img_info['filename'])
        pred[pred > 0] = 255
        Image.fromarray(pred.astype(np.uint8)).save(file_name)
        prog_bar.update()


def parse_args():
    parser = argparse.ArgumentParser(
        description='Model Ensemble with logits result')
    parser.add_argument(
        '--config', type=str, nargs='+', help='ensemble config files path',
        default=[
            'configs/knet/knet_s3_upernet_convnext-small_woodscape_960_1.py',
        ] * 6)
    parser.add_argument(
        '--checkpoint',
        type=str,
        nargs='+',
        help='ensemble checkpoint files path',
        default=[
            'work_dirs/knet_s3_upernet_convnext-small_woodscape_960_1/iter_80000.pth',
            'work_dirs/knet_s3_upernet_convnext-small_woodscape_960_2/iter_80000.pth',
            'work_dirs/knet_s3_upernet_convnext-small_woodscape_960_3/iter_80000.pth',
            'work_dirs/knet_s3_upernet_convnext-small_woodscape_syn_960_1/iter_80000.pth',
            'work_dirs/knet_s3_upernet_convnext-small_woodscape_syn_960_2/iter_80000.pth',
            'work_dirs/knet_s3_upernet_convnext-small_woodscape_syn_960_3/iter_80000.pth'
        ])
    parser.add_argument(
        '--weights',
        type=float,
        nargs='+',
        help='ensemble weight of each checkpoint',
        default=[0.75, 0.25])
    parser.add_argument(
        '--aug-test',
        action='store_true',
        default=True,
        help='enable multi-scale and flip test-time augmentation')
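    # Note: because --aug-test is a store_true flag with default=True, it is
    # effectively always on and cannot be disabled from the command line;
    # change the default to False to make single-scale testing opt-in again.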
    parser.add_argument(
        '--out', type=str, default='results',
        help='directory to save the ensemble results')
    parser.add_argument(
        '--gpus', type=int, nargs='+', default=[0, 1, 2, 4, 5, 6],
        help='ids of the GPUs to use')
    args = parser.parse_args()
    assert len(args.config) == len(args.checkpoint), \
        f'len(config) must equal len(checkpoint), ' \
        f'but len(config) = {len(args.config)} and ' \
        f'len(checkpoint) = {len(args.checkpoint)}'
    assert args.out, "ensemble result out-dir can't be None"
    return args


if __name__ == '__main__':
    args = parse_args()
    main(args)