-
Notifications
You must be signed in to change notification settings - Fork 0
/
make_grid.py
86 lines (75 loc) · 3.18 KB
/
make_grid.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import pathlib
import paddle
import warnings
import math
import numpy as np
from PIL import Image
from typing import Union, Optional, List, Tuple, Text, BinaryIO
@paddle.no_grad()
def make_grid(tensor: Union[paddle.Tensor, List[paddle.Tensor]],
              nrow: int = 8,
              padding: int = 2,
              normalize: bool = False,
              value_range: Optional[Tuple[int, int]] = None,
              scale_each: bool = False,
              pad_value: int = 0,
              **kwargs) -> paddle.Tensor:
    """Tile a mini-batch of images into a single padded grid image.

    Args:
        tensor: A 4-D batch ``(B, C, H, W)``, a 3-D single image
            ``(C, H, W)``, a 2-D single image ``(H, W)``, or a list of
            tensors that are stacked along a new axis 0. Single-channel
            inputs are replicated to 3 channels.
        nrow: Maximum number of images placed in one row of the grid.
        padding: Number of padding pixels between images and around the
            border of the grid.
        normalize: If True, clip to ``value_range`` (when given) and rescale
            to ``[0, 1]`` using ``value_range`` or the data's own min/max.
        value_range: Optional ``(min, max)`` tuple used for normalization.
        scale_each: If True (and ``normalize``), each image is normalized
            with its own min/max instead of the whole batch's.
        pad_value: Fill value for the padding pixels.
        **kwargs: Accepts the deprecated ``range`` alias of ``value_range``
            (emits a warning); other keys are ignored.

    Returns:
        A float32 tensor ``(C, H_grid, W_grid)``. If the batch holds a
        single image, that image ``(C, H, W)`` is returned directly without
        padding (float32 if normalized, otherwise in its input dtype).

    Raises:
        TypeError: If ``tensor`` is neither a Tensor nor a list of Tensors.
        AssertionError: If ``normalize`` is set and ``value_range`` is given
            but is not a tuple.
    """
    if not (isinstance(tensor, paddle.Tensor) or
            (isinstance(tensor, list) and
             all(isinstance(t, paddle.Tensor) for t in tensor))):
        raise TypeError(
            f'tensor or list of tensors expected, got {type(tensor)}')
    if "range" in kwargs:
        warning = "range will be deprecated, please use value_range instead."
        warnings.warn(warning)
        value_range = kwargs["range"]

    # Bring every accepted input form to a 4-D (B, C, H, W) batch.
    if isinstance(tensor, list):
        tensor = paddle.stack(tensor, axis=0)
    if len(tensor.shape) == 2:  # single image H x W
        tensor = tensor.unsqueeze(0)
    if len(tensor.shape) == 3:  # single image C x H x W
        if tensor.shape[0] == 1:  # if single-channel, convert to 3-channel
            tensor = paddle.concat((tensor, tensor, tensor), 0)
        tensor = tensor.unsqueeze(0)
    if len(tensor.shape) == 4 and tensor.shape[1] == 1:  # single-channel images
        tensor = paddle.concat((tensor, tensor, tensor), 1)

    if normalize is True:
        if value_range is not None:
            assert isinstance(value_range, tuple), \
                "value_range has to be a tuple (min, max) if specified. min and max are numbers"
        # Do the shift/scale in float32 (the grid's dtype) so it is not
        # carried out in an integer dtype.
        tensor = tensor.astype('float32')

        def norm_ip(img, low, high):
            # These Paddle ops are out-of-place: the normalized tensor must
            # be returned and re-assigned, otherwise the input is unchanged.
            img = img.clip(min=low, max=high)
            img = img - low
            return img / max(high - low, 1e-5)

        def norm_range(t, value_range):
            if value_range is not None:
                return norm_ip(t, value_range[0], value_range[1])
            return norm_ip(t, float(t.min()), float(t.max()))

        if scale_each is True:
            # Normalize each image of the mini-batch independently.
            tensor = paddle.stack(
                [norm_range(tensor[i], value_range)
                 for i in range(tensor.shape[0])],
                axis=0)
        else:
            tensor = norm_range(tensor, value_range)

    if tensor.shape[0] == 1:
        return tensor.squeeze(0)

    # make the mini-batch of images into a grid
    nmaps = tensor.shape[0]
    xmaps = min(nrow, nmaps)
    ymaps = int(math.ceil(float(nmaps) / xmaps))
    # Each cell is one image plus one leading padding strip per axis.
    height = int(tensor.shape[2] + padding)
    width = int(tensor.shape[3] + padding)
    num_channels = tensor.shape[1]
    grid = paddle.full((num_channels, height * ymaps + padding,
                        width * xmaps + padding), pad_value, dtype='float32')
    tensor = tensor.astype('float32')
    k = 0
    for y in range(ymaps):
        for x in range(xmaps):
            if k >= nmaps:
                break
            grid[:, y * height + padding:(y + 1) * height,
                 x * width + padding:(x + 1) * width] = tensor[k]
            k = k + 1
    return grid