-
Notifications
You must be signed in to change notification settings - Fork 0
/
CaptchaDataset.py
29 lines (25 loc) · 1.13 KB
/
CaptchaDataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch, random
from torchvision.transforms.functional import to_tensor
from torch.utils.data.dataset import Dataset
from ImageCaptchaEnhanced import ImageCaptchaEnhanced
class CaptchaDataset(Dataset):
    """Map-style PyTorch dataset that renders random captchas on demand.

    No samples are stored: every ``__getitem__`` call draws a fresh random
    label and renders it with ``ImageCaptchaEnhanced``. ``length`` therefore
    only fixes how many samples one pass over the dataset yields.

    Args:
        characters: Sequence of character classes. Entry ``i`` is a string
            holding every spelling that counts as class ``i`` (e.g. ``"aA"``
            or just ``"7"``). One spelling is picked at random each time the
            class is drawn.
        length: Number of samples reported by ``len()``.
        width: Image width passed to the captcha generator.
        height: Image height passed to the captcha generator.
        label_length: Number of characters in every generated captcha.
    """

    def __init__(self, characters, length, width, height, label_length):
        super().__init__()
        self.characters = characters
        self.length = length
        self.width = width
        self.height = height
        self.label_length = label_length
        self.n_class = len(characters)
        self.generator = ImageCaptchaEnhanced(width=width, height=height)

    def __len__(self):
        """Return the nominal dataset size given at construction."""
        return self.length

    def _sample_label(self):
        """Draw random class indices and the captcha text that spells them.

        Returns:
            ``(indices, text)``. ``indices`` is a list of ``label_length``
            class indices in ``range(n_class)``. ``text[k]`` is a randomly
            chosen spelling taken from ``characters[indices[k]]``.
        """
        indices = [random.randrange(self.n_class) for _ in range(self.label_length)]
        # random.choice accepts entries of any length, so classes with a single
        # spelling (e.g. digits) work and every listed spelling can be drawn.
        text = "".join(random.choice(self.characters[i]) for i in indices)
        return indices, text

    def __getitem__(self, index):
        """Generate one random captcha sample.

        ``index`` is ignored; each call produces a new random captcha, so
        repeated lookups of the same index return different samples.

        Returns:
            ``(image, target)``. ``image`` is ``to_tensor`` applied to the
            generator's output (presumably a PIL image; confirm against
            ``ImageCaptchaEnhanced``). ``target`` is an int64 tensor of
            ``label_length`` class indices.
        """
        indices, text = self._sample_label()
        image = to_tensor(self.generator.generate_image(text))
        target = torch.tensor(indices, dtype=torch.int64)
        return image, target