-
Notifications
You must be signed in to change notification settings - Fork 1
/
hubconf.py
142 lines (104 loc) · 6.09 KB
/
hubconf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
PyTorch Hub models https://pytorch.org/hub/ultralytics_yolov5/
Usage:
import torch
model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
"""
import torch
def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
"""Creates a specified YOLOv5 model
Arguments:
name (str): name of model, i.e. 'yolov5s'
pretrained (bool): load pretrained weights into the model
channels (int): number of input channels
classes (int): number of model classes
autoshape (bool): apply YOLOv5 .autoshape() wrapper to model
verbose (bool): print all information to screen
device (str, torch.device, None): device to use for model parameters
Returns:
YOLOv5 pytorch model
"""
from pathlib import Path
from models.yolo import Model
from models.experimental import attempt_load
from utils.general import check_requirements, set_logging
from utils.downloads import attempt_download
from utils.torch_utils import select_device
file = Path(__file__).resolve()
check_requirements(exclude=('tensorboard', 'thop', 'opencv-python'))
set_logging(verbose=verbose)
save_dir = Path('') if str(name).endswith('.pt') else file.parent
path = (save_dir / name).with_suffix('.pt') # checkpoint path
try:
device = select_device(('0' if torch.cuda.is_available() else 'cpu') if device is None else device)
if pretrained and channels == 3 and classes == 80:
model = attempt_load(path, map_location=device) # download/load FP32 model
else:
cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0] # model.yaml path
model = Model(cfg, channels, classes) # create model
if pretrained:
ckpt = torch.load(attempt_download(path), map_location=device) # load
msd = model.state_dict() # model state_dict
csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter
model.load_state_dict(csd, strict=False) # load
if len(ckpt['model'].names) == classes:
model.names = ckpt['model'].names # set class names attribute
if autoshape:
model = model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS
return model.to(device)
except Exception as e:
help_url = 'https://github.com/ultralytics/yolov5/issues/36'
s = 'Cache may be out of date, try `force_reload=True`. See %s for help.' % help_url
raise Exception(s) from e
def custom(path='path/to/model.pt', autoshape=True, verbose=True, device=None):
# YOLOv5 custom or local model
return _create(path, autoshape=autoshape, verbose=verbose, device=device)
def yolov5n(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
# YOLOv5-nano model https://github.com/ultralytics/yolov5
return _create('yolov5n', pretrained, channels, classes, autoshape, verbose, device)
def yolov5s(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
# YOLOv5-small model https://github.com/ultralytics/yolov5
return _create('yolov5s', pretrained, channels, classes, autoshape, verbose, device)
def yolov5m(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
# YOLOv5-medium model https://github.com/ultralytics/yolov5
return _create('yolov5m', pretrained, channels, classes, autoshape, verbose, device)
def yolov5l(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
# YOLOv5-large model https://github.com/ultralytics/yolov5
return _create('yolov5l', pretrained, channels, classes, autoshape, verbose, device)
def yolov5x(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
# YOLOv5-xlarge model https://github.com/ultralytics/yolov5
return _create('yolov5x', pretrained, channels, classes, autoshape, verbose, device)
def yolov5n6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
# YOLOv5-nano-P6 model https://github.com/ultralytics/yolov5
return _create('yolov5n6', pretrained, channels, classes, autoshape, verbose, device)
def yolov5s6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
# YOLOv5-small-P6 model https://github.com/ultralytics/yolov5
return _create('yolov5s6', pretrained, channels, classes, autoshape, verbose, device)
def yolov5m6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
# YOLOv5-medium-P6 model https://github.com/ultralytics/yolov5
return _create('yolov5m6', pretrained, channels, classes, autoshape, verbose, device)
def yolov5l6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
# YOLOv5-large-P6 model https://github.com/ultralytics/yolov5
return _create('yolov5l6', pretrained, channels, classes, autoshape, verbose, device)
def yolov5x6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
# YOLOv5-xlarge-P6 model https://github.com/ultralytics/yolov5
return _create('yolov5x6', pretrained, channels, classes, autoshape, verbose, device)
if __name__ == '__main__':
model = _create(name='yolov5s', pretrained=True, channels=3, classes=80, autoshape=True, verbose=True) # pretrained
# model = custom(path='path/to/model.pt') # custom
# Verify inference
import cv2
import numpy as np
from PIL import Image
from pathlib import Path
imgs = ['data/images/zidane.jpg', # filename
Path('data/images/zidane.jpg'), # Path
'https://ultralytics.com/images/zidane.jpg', # URI
cv2.imread('data/images/bus.jpg')[:, :, ::-1], # OpenCV
Image.open('data/images/bus.jpg'), # PIL
np.zeros((320, 640, 3))] # numpy
results = model(imgs) # batched inference
results.print()
results.save()