# train.py — GeoMol training entry point (scraped from a web view: 109 lines, 88 loc).
from argparse import ArgumentParser
import math
import os
import yaml
import torch
import numpy as np
import random
from model.model import GeoMol
from model.training import train, test, NoamLR
from utils import create_logger, dict_to_str, plot_train_val_loss, save_yaml_file, get_optimizer_and_scheduler
from model.featurization import construct_loader
from model.parsing import parse_train_args, set_hyperparams
from deepergcn.DeeperGCN import DeeperGCN
from torch.utils.tensorboard import SummaryWriter
import resource
# Raise the soft open-file limit so multi-worker DataLoaders don't exhaust
# file descriptors. Clamp to the hard limit: the original unconditional 4096
# raises ValueError on systems whose hard limit is lower.
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
soft = 4096 if rlimit[1] == resource.RLIM_INFINITY else min(4096, rlimit[1])
resource.setrlimit(resource.RLIMIT_NOFILE, (soft, rlimit[1]))
# torch.multiprocessing.set_sharing_strategy('file_system')

# Parse training args and log every one of them for reproducibility.
args = parse_train_args()
logger = create_logger('train', args.log_dir)
logger.info('Arguments are...')
for arg in vars(args):
    logger.info(f'{arg}: {getattr(args, arg)}')

# Seed every RNG source. The original seeded only CPU torch / random / numpy;
# CUDA generators are seeded too since training prefers a CUDA device below.
torch.manual_seed(args.seed)
random.seed(args.seed)
np.random.seed(args.seed)
torch.cuda.manual_seed_all(args.seed)  # no-op on CPU-only machines

# Data loaders and compute device.
train_loader, val_loader = construct_loader(args)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Build the model: either resume from a checkpoint directory or construct a
# fresh GeoMol from CLI-derived hyperparameters and the dataset's feature dims.
embed_network = None
deepergcn = None
if args.restart_dir:
    # Resume: rebuild from the saved hyperparameters, then load best weights.
    with open(f'{args.restart_dir}/model_parameters.yml') as f:
        model_parameters = yaml.full_load(f)
    model = GeoMol(**model_parameters).to(device)
    state_dict = torch.load(f'{args.restart_dir}/best_model.pt', map_location=torch.device('cpu'))
    model.load_state_dict(state_dict, strict=True)
else:
    # Fresh run: hyperparameters from args, feature sizes from the dataset.
    hyperparams = set_hyperparams(args)
    model_parameters = {'hyperparams': hyperparams,
                        'num_node_features': train_loader.dataset.num_node_features,
                        'num_edge_features': train_loader.dataset.num_edge_features}
    model = GeoMol(**model_parameters).to(device)

# NOTE(review): the two optional components below are assumed to apply to both
# the restart and fresh-run paths (consistent with the None initialization
# above) — confirm against the original repository's indentation.
if args.MolR_emb:
    # Optional pretrained MolR embedding network attached to the model.
    embed_network = model.load_pretrained_emb(args.embed_path)
if args.utilize_deepergcn:
    deepergcn = DeeperGCN(input_dim=train_loader.dataset.num_node_features + args.random_vec_dim,
                          output_dim=1,
                          edge_attr_dim=4 + args.random_vec_dim,
                          hidden_channels=args.embeddings_dim,
                          num_layers=28).to(device)
    # Load the DeeperGCN checkpoint for the requested molecular property.
    path = 'deepergcn/saved/' + args.dg_molecular_property + '_' + str(args.embeddings_dim) + '.pt'
    logger.info(f'Loading DeeperGCN checkpoint from {path}')
    # map_location keeps the load working on CPU-only machines, matching the
    # restart branch above; load_state_dict copies onto the model's device.
    checkpoint = torch.load(path, map_location=torch.device('cpu'))
    deepergcn.load_state_dict(checkpoint['model_state_dict'])
# Optimizer and LR scheduler, sized by the number of training examples.
optimizer, scheduler = get_optimizer_and_scheduler(args, model, len(train_loader.dataset))

# Persist the model configuration next to the logs so runs can be restarted.
logger.info(f'\nModel parameters are:\n{dict_to_str(model_parameters)}\n')
save_yaml_file(os.path.join(args.log_dir, 'model_parameters.yml'), model_parameters)

# TensorBoard writer plus trackers for the best validation checkpoint.
writer = SummaryWriter(args.log_dir)
best_val_loss = math.inf
best_epoch = 0
logger.info("Starting training...")
# Epochs are 1-indexed. The original `range(1, args.n_epochs)` ran only
# n_epochs - 1 epochs; the +1 makes the count match the argument.
for epoch in range(1, args.n_epochs + 1):
    train_loss = train(model, train_loader, optimizer, device, scheduler,
                       logger if args.verbose else None, epoch, writer,
                       embed_network=embed_network, deepergcn=deepergcn)
    logger.info("Epoch {}: Training Loss {}".format(epoch, train_loss))

    val_loss = test(model, val_loader, device, epoch, writer,
                    embed_network=embed_network, deepergcn=deepergcn)
    logger.info("Epoch {}: Validation Loss {}".format(epoch, val_loss))

    # NoamLR is stepped elsewhere (presumably per-batch inside train() —
    # confirm); only non-Noam schedulers step here on the validation loss.
    if scheduler and not isinstance(scheduler, NoamLR):
        scheduler.step(val_loss)

    # Track and save the best-so-far model by validation loss.
    if val_loss <= best_val_loss:
        best_val_loss = val_loss
        best_epoch = epoch
        torch.save(model.state_dict(), os.path.join(args.log_dir, 'best_model.pt'))

    # Always checkpoint the latest state so interrupted runs can resume.
    torch.save({
        'epoch': epoch,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        # scheduler may be None (see the truthiness guard above); the original
        # unconditional scheduler.state_dict() would crash in that case.
        'scheduler': scheduler.state_dict() if scheduler else None,
    }, os.path.join(args.log_dir, 'last_model.pt'))

logger.info("Best Validation Loss {} on Epoch {}".format(best_val_loss, best_epoch))

# Render train/val loss curves from the log file written by this run.
log_file = os.path.join(args.log_dir, 'train.log')
plot_train_val_loss(log_file)
# Sentinel file marking successful completion (for external tooling to poll).
with open('train_result.txt', 'w') as marker:
    marker.write('Done')