# train.py -- training entry point: defines train() and runs it when executed as a script.
from tqdm import tqdm
from options.train_options import TrainOptions
from datasets import CreateDataLoader
from models import create_model
from utils.visualizer import Visualizer
from utils import util
def get_state_message(epoch, n_iter, param_dict):
    """Build a human-readable status message for the current training step.

    Args:
        epoch: Epoch number printed in the header line.
        n_iter: Iteration counter printed in the header line.
        param_dict: Mapping of name -> numeric value. Each value is
            formatted with five decimal places.

    Returns:
        A string made of a header line ("Epoch: <epoch>, n_iter: <n_iter>"
        followed by a newline) and then one "<name>: <value>, " entry per
        item of ``param_dict``, in the mapping's iteration order. The
        trailing ", " after the last entry is intentionally preserved so
        the output matches what existing log consumers already receive.
    """
    header = "Epoch: {}, n_iter: {}\n".format(epoch, n_iter)
    # Collect the entries and join once instead of growing the string
    # with repeated += inside the loop.
    entries = [
        "{}: {:.5f}, ".format(name, value)
        for name, value in param_dict.items()
    ]
    return header + "".join(entries)
def train():
    """Run the training loop.

    Parses the training options, builds the data loader, model and
    visualizer, then iterates over epochs and batches. Within the loop it
    optimizes the model on each batch and, at the frequencies read from
    the parsed options, pushes images and scalar losses to the visualizer,
    writes a status message, and saves the networks.

    NOTE(review): the indentation of this function was lost in the copy
    this was reconstructed from. Statement order is unchanged, but the
    nesting below is a best-effort reconstruction -- confirm it against
    the upstream file, in particular the spots marked NOTE(review).
    """
    opt = TrainOptions().parse()
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    # Size is taken from the loader object, not from `dataset`.
    dataset_size = len(data_loader)
    model = create_model(opt)
    model.setup(opt)
    # Make sure the directory the model reports as its save_dir exists.
    util.mkdir(model.save_dir)
    visualizer = Visualizer(opt)
    # Start the iteration counter as if (epoch_count - 1) full passes over
    # the data had already been done.
    n_iter = (opt.epoch_count - 1) * dataset_size
    # Exclusive upper bound for the epoch range below.
    total_epoch = opt.load_epoch + opt.epoch_decay + 1
    # Hand the starting epoch / iteration counters to the model.
    model.update_epoch(opt.load_epoch)
    model.update_niter(n_iter)
    for epoch in range(opt.epoch_count, total_epoch):
        # ascii=True makes tqdm draw the bar with plain ASCII characters.
        for data in tqdm(dataset, total=dataset_size, ascii=True):
            model.set_input(data)
            model.optimize_parameters()
            n_iter += 1
            # Every display_freq iterations: send current visuals to the visualizer.
            if n_iter % opt.display_freq == 0:
                for name, image in model.get_current_visuals().items():
                    visualizer.add_image(name, image, n_iter)
            # Every print_freq iterations: record scalar losses and print a
            # status message. tqdm.write is used so the message does not
            # break the progress bar.
            if n_iter % opt.print_freq == 0:
                for name, value in model.get_current_losses().items():
                    visualizer.add_scalar(name, value, n_iter)
                log = get_state_message(epoch, n_iter, model.get_current_losses())
                tqdm.write(log)
            # Every save_lastest_freq iterations (option name spelled as in
            # the options object): save the networks under the 'latest' tag.
            if n_iter % opt.save_lastest_freq == 0:
                model.save_networks('latest')
            # NOTE(review): placed at batch-loop level (runs every
            # iteration); it may instead belong inside the `if` above --
            # confirm against the upstream file.
            model.update_niter(n_iter)
        # NOTE(review): assumed to run once per epoch, after the batch
        # loop; confirm against the upstream file.
        log = get_state_message(epoch, n_iter, model.get_current_losses())
        visualizer.add_log(log)
        # Every save_epoch_freq epochs: save the networks tagged with the epoch number.
        if epoch % opt.save_epoch_freq == 0:
            model.save_networks(epoch)
        model.update_learning_rate()
        model.update_epoch(epoch)
# Start training only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    train()