# run_omniglot.py
# Forked from openai/supervised-reptile.
"""
Train a model on Omniglot.
"""
import random
import tensorflow as tf
from supervised_reptile.args import argument_parser, model_kwargs, train_kwargs, evaluate_kwargs
from supervised_reptile.eval import evaluate
from supervised_reptile.models import OmniglotModel
from supervised_reptile.omniglot import read_dataset, split_dataset, augment_dataset
from supervised_reptile.train import train
# Relative path handed to read_dataset() in main() to load the Omniglot data.
DATA_DIR = 'data/omniglot'
def main():
    """
    Entry point: prepare Omniglot data, obtain a model, and report accuracy.

    Steps, in order:
      1. Parse the command-line arguments and seed Python's `random` module
         with `args.seed`.
      2. Read the dataset from DATA_DIR, split it into train/test parts,
         augment the training part, and materialize both parts as lists.
      3. Build an OmniglotModel for `args.classes` classes.
      4. Inside a TensorFlow session, either restore the latest checkpoint
         found under `args.checkpoint` (when `args.pretrained` is set) or
         call `train` with that checkpoint location.
      5. Print the accuracy returned by `evaluate` for the train and test
         sets.
    """
    cli = argument_parser().parse_args()
    random.seed(cli.seed)

    # Data preparation: both splits are turned into lists up front.
    raw_train, raw_test = split_dataset(read_dataset(DATA_DIR))
    training_data = list(augment_dataset(raw_train))
    testing_data = list(raw_test)

    net = OmniglotModel(cli.classes, **model_kwargs(cli))

    with tf.Session() as session:
        if cli.pretrained:
            print('Restoring from checkpoint...')
            saver = tf.train.Saver()
            newest = tf.train.latest_checkpoint(cli.checkpoint)
            saver.restore(session, newest)
        else:
            print('Training...')
            training_options = train_kwargs(cli)
            train(session, net, training_data, testing_data, cli.checkpoint,
                  **training_options)

        # Evaluation must stay inside the session so the variables are live.
        print('Evaluating...')
        scoring_options = evaluate_kwargs(cli)
        reports = (
            ('Train accuracy: ', training_data),
            ('Test accuracy: ', testing_data),
        )
        for label, samples in reports:
            accuracy = evaluate(session, net, samples, **scoring_options)
            print(label + str(accuracy))
# Only run main() when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()