-
Notifications
You must be signed in to change notification settings - Fork 7
/
ablstm.py
70 lines (61 loc) · 3.03 KB
/
ablstm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# -*- coding: utf-8 -*-
"""
Created on Wed Jul 10 22:02:40 2019
@author: Chonghua Xue (Kolachalama's Lab, BU)
"""
import argparse
from utils_miscellany import load_config_xml
from model import ModelLSTM
def build_parser():
    """Build the command-line parser with the ``fit`` and ``eval`` subcommands.

    The chosen subcommand name is stored in the ``cmd`` attribute of the
    parsed namespace. A subcommand is mandatory: ``-c`` (configuration file)
    and ``-d`` (device) are only defined on the subcommands, so a namespace
    parsed without one would lack the attributes ``main`` reads.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    # main parser
    parser = argparse.ArgumentParser(
        description='Quantifying the nativeness of antibody sequences using long short-term memory network.')
    subparsers = parser.add_subparsers(dest='cmd')
    # Without this, running the script with no subcommand used to crash with
    # AttributeError on ``args.c`` instead of printing a usage error.
    subparsers.required = True

    # fit cmd parser
    parser_fit = subparsers.add_parser('fit')
    parser_fit.add_argument('TRN_FN', help='training data file')
    parser_fit.add_argument('VLD_FN', help='validation data file')
    parser_fit.add_argument('SAVE_FP', help='model save path')
    parser_fit.add_argument('-l', default='', help='model file to load (default: "")')
    parser_fit.add_argument('-c', default='./ablstm.config',
                            help='configuration XML file (default: "./ablstm.config")')
    parser_fit.add_argument('-d', default='cpu', help='device (default: "cpu")')

    # eval cmd parser
    parser_eval = subparsers.add_parser('eval')
    parser_eval.add_argument('TST_FN', help='evaluation data file')
    parser_eval.add_argument('MDL_FN', help='model file to load')
    parser_eval.add_argument('SCR_FN', help='file to save scores')
    parser_eval.add_argument('-c', default='./ablstm.config',
                             help='configuration XML file (default: "./ablstm.config")')
    parser_eval.add_argument('-d', default='cpu', help='device (default: "cpu")')
    return parser


def main(argv=None):
    """Run the ``fit`` or ``eval`` subcommand selected on the command line.

    Args:
        argv: Argument list to parse; ``None`` lets argparse read
            ``sys.argv[1:]``.

    ``fit`` builds a ``ModelLSTM`` (from the ``__init__`` section of the
    configuration, or by loading the file given with ``-l``) and calls its
    ``fit`` method with the ``fit`` section of the configuration.
    ``eval`` loads a model file, calls ``eval`` on the evaluation data and
    writes the returned scores, comma-separated, to ``SCR_FN``.
    """
    # args is stored in Namespace obj and configuration in dict
    args = build_parser().parse_args(argv)
    conf = load_config_xml(args.c)

    if args.cmd == 'fit':
        # initialize model
        if not args.l:
            # fresh model: hyper-parameters come from the configuration file
            param_ini = {'embedding_dim': conf['__init__']['embedding_dim'],
                         'hidden_dim': conf['__init__']['hidden_dim'],
                         'gapped': conf['__init__']['gapped'],
                         'fixed_len': conf['__init__']['fixed_len'],
                         'device': args.d}
            model = ModelLSTM(**param_ini)
        else:
            # existing model: only the device is passed, the rest is loaded
            param_ini = {'device': args.d}
            model = ModelLSTM(**param_ini)
            model.load(args.l)

        # fit model
        param_fit = {'trn_fn': args.TRN_FN,
                     'vld_fn': args.VLD_FN,
                     'n_epoch': conf['fit']['n_epoch'],
                     'trn_batch_size': conf['fit']['trn_batch_size'],
                     'vld_batch_size': conf['fit']['vld_batch_size'],
                     'lr': conf['fit']['lr'],
                     'save_fp': args.SAVE_FP}
        model.fit(**param_fit)

    elif args.cmd == 'eval':
        param_ini = {'device': args.d}
        model = ModelLSTM(**param_ini)
        model.load(args.MDL_FN)

        param_eval = {'fn': args.TST_FN, 'batch_size': conf['eval']['batch_size']}
        scores = model.eval(**param_eval)

        # scores are written as one comma-separated record, no trailing newline
        scores = [str(s) for s in scores]
        with open(args.SCR_FN, 'w') as f:
            f.write(','.join(scores))


if __name__ == '__main__':
    main()