# https://github.com/pratogab/batch-transforms
import torch


class ToTensor:
    """Applies the :class:`~torchvision.transforms.ToTensor` transform to a batch of images.
    """

    def __init__(self):
        self.max = 255

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor of size (N, C, H, W) to be tensorized.

        Returns:
            Tensor: Tensorized Tensor.
        """
        return tensor.float().div_(self.max)


class Normalize:
    """Applies the :class:`~torchvision.transforms.Normalize` transform to a batch of images.

    .. note::
        This transform acts out of place by default, i.e., it does not mutate the input tensor.

    Args:
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
        inplace (bool, optional): Bool to make this operation in-place.
        dtype (torch.dtype, optional): The data type of tensors to which the transform will be applied.
        device (torch.device, optional): The device of tensors to which the transform will be applied.
    """

    def __init__(self, mean, std, inplace=False, dtype=torch.float, device='cpu'):
        self.mean = torch.as_tensor(mean, dtype=dtype, device=device)[None, :, None, None]
        self.std = torch.as_tensor(std, dtype=dtype, device=device)[None, :, None, None]
        self.inplace = inplace

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor of size (N, C, H, W) to be normalized.

        Returns:
            Tensor: Normalized Tensor.
        """
        if not self.inplace:
            tensor = tensor.clone()
        tensor.sub_(self.mean).div_(self.std)
        return tensor


class RandomHorizontalFlip:
    """Applies the :class:`~torchvision.transforms.RandomHorizontalFlip` transform to a batch of images.

    .. note::
        This transform acts out of place by default, i.e., it does not mutate the input tensor.

    Args:
        p (float): probability of an image being flipped.
        inplace (bool, optional): Bool to make this operation in-place.
    """

    def __init__(self, p=0.5, inplace=False):
        self.p = p
        self.inplace = inplace

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor of size (N, C, H, W) to be flipped.

        Returns:
            Tensor: Randomly flipped Tensor.
        """
        if not self.inplace:
            tensor = tensor.clone()
        # Draw one Bernoulli(p) decision per sample, then flip only the selected
        # samples along the width dimension (dim 3).
        flipped = torch.rand(tensor.size(0)) < self.p
        tensor[flipped] = torch.flip(tensor[flipped], [3])
        return tensor


class RandomCrop:
    """Applies the :class:`~torchvision.transforms.RandomCrop` transform to a batch of images.

    Args:
        size (int): Desired output size of the crop.
        padding (int, optional): Optional padding on each border of the image.
            Default is None, i.e. no padding.
        device (torch.device, optional): The device of tensors to which the transform will be applied.
    """

    def __init__(self, size, padding=None, device='cpu'):
        self.size = size
        self.padding = padding
        self.device = device

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor of size (N, C, H, W) to be cropped.

        Returns:
            Tensor: Randomly cropped Tensor.
        """
        if self.padding is not None:
            # Zero-pad every image in the batch by `padding` pixels on each border.
            padded = torch.zeros((tensor.size(0), tensor.size(1), tensor.size(2) + self.padding * 2,
                                  tensor.size(3) + self.padding * 2), dtype=tensor.dtype, device=self.device)
            padded[:, :, self.padding:-self.padding, self.padding:-self.padding] = tensor
        else:
            padded = tensor

        h, w = padded.size(2), padded.size(3)
        th, tw = self.size, self.size
        if w == tw and h == th:
            # Nothing to crop: use a zero offset for every sample (kept as tensors
            # so the advanced indexing below works in both branches).
            i = torch.zeros(tensor.size(0), dtype=torch.long, device=self.device)
            j = torch.zeros(tensor.size(0), dtype=torch.long, device=self.device)
        else:
            # Draw one random top-left corner (i, j) per sample in the batch.
            i = torch.randint(0, h - th + 1, (tensor.size(0),), device=self.device)
            j = torch.randint(0, w - tw + 1, (tensor.size(0),), device=self.device)

        # Build per-sample row/column index grids and gather all crops with a single
        # advanced-indexing operation; channels are moved to the front so the batch,
        # row, and column indices broadcast together, then moved back afterwards.
        rows = torch.arange(th, dtype=torch.long, device=self.device) + i[:, None]
        columns = torch.arange(tw, dtype=torch.long, device=self.device) + j[:, None]
        padded = padded.permute(1, 0, 2, 3)
        padded = padded[:, torch.arange(tensor.size(0))[:, None, None], rows[:, torch.arange(th)[:, None]], columns[:, None]]
        return padded.permute(1, 0, 2, 3)
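

if __name__ == "__main__":
    # Minimal usage sketch (not part of the upstream module): the transforms above
    # operate on whole (N, C, H, W) batches, so they are applied after the
    # DataLoader has collated a batch rather than per sample. The batch shape and
    # the CIFAR-10-style mean/std values below are illustrative assumptions.
    batch = torch.randint(0, 256, (8, 3, 32, 32), dtype=torch.uint8)

    to_tensor = ToTensor()
    normalize = Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])
    flip = RandomHorizontalFlip(p=0.5)
    crop = RandomCrop(32, padding=4)

    out = crop(flip(normalize(to_tensor(batch))))
    print(out.shape)  # torch.Size([8, 3, 32, 32])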