-
Notifications
You must be signed in to change notification settings - Fork 0
/
parti_train1.py
34 lines (27 loc) · 921 Bytes
/
parti_train1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from parti_pytorch import VitVQGanVAE, VQGanVAETrainer
import os
import configparser
# Train a ViT-VQGAN VAE (parti_pytorch) on a folder of images.
#
# Settings come from 'configuration.txt', resolved against the current
# working directory. Its [COMMON] section must define `datasetPathVideo`.
# Two optional keys in the same section override the trainer's paths:
#   vaeTrainFolder   - image folder to train on
#   vaeResultsFolder - where the trainer writes results / checkpoints
# When either key is absent, the previously hard-coded path is used.

# Loading configurations
configParser = configparser.RawConfigParser()
configFilePath = r'configuration.txt'
# read() silently skips files it cannot open and returns the list of files it
# did parse. Fail loudly here rather than later with a confusing NoSectionError.
if not configParser.read(configFilePath):
    raise FileNotFoundError(
        f"Configuration file not found or unreadable: "
        f"{os.path.abspath(configFilePath)}"
    )
# NOTE(review): this key is required (get() raises if it is missing) but is not
# used by the training below. Confirm whether it is still needed here.
datasetPathVideo = configParser.get('COMMON', 'datasetPathVideo')

trainFolder = configParser.get(
    'COMMON', 'vaeTrainFolder',
    fallback='/media/gamal/Passport/Datasets/VoxCeleb2Test/Voxceleb2TestFaces',
)
resultsFolder = configParser.get(
    'COMMON', 'vaeResultsFolder',
    fallback='/media/gamal/Passport/parti/vae',
)

IMAGE_SIZE = 128  # target image size
# Side length of the square patches the ViT attends over. A patch must fit
# evenly inside the image. The previous value (256) exceeded the 128-px image.
# 16 gives an 8 x 8 patch grid and matches the parti_pytorch README example.
PATCH_SIZE = 16
if IMAGE_SIZE % PATCH_SIZE != 0:
    raise ValueError(
        f"image_size ({IMAGE_SIZE}) must be divisible by patch_size ({PATCH_SIZE})"
    )

vit_vae = VitVQGanVAE(
    dim = 256,                # dimensions
    image_size = IMAGE_SIZE,  # target image size
    patch_size = PATCH_SIZE,  # size of the patches in the image attending to each other
    num_layers = 3            # number of layers
).cuda()

trainer = VQGanVAETrainer(
    vit_vae,
    folder = trainFolder,
    num_train_steps = 100000,
    lr = 3e-4,
    batch_size = 4,
    grad_accum_every = 8,  # effective batch size = batch_size * grad_accum_every
    amp = True,
    results_folder = resultsFolder,
    save_results_every = 10,
    save_model_every = 50,
)
trainer.train()