-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
139 lines (112 loc) · 3.56 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import inspect
import pdb
import argparse
import glog
import logging
import sys
import os
import shlex
from chdrft.config.base import is_python2
import chdrft.config.env
import numpy as np
import random
# contextlib.ExitStack does not exist on Python 2, so it is only imported on
# Python 3; App guards every ExitStack use with the same is_python2 check.
if not is_python2:
    from contextlib import ExitStack
class App:
    """Entry-point helper shared by chdrft scripts and notebooks.

    Calling the instance inspects the *calling* module's globals: it builds an
    argparse parser (extended by the caller's ``args(parser)`` hook and by
    ``parser_funcs``), parses the command line, publishes the parsed flags
    (and the optional file cache) back into the caller's globals, and runs the
    caller's ``main()`` inside an ``ExitStack``.
    """

    def __init__(self):
        self.flags = None  # namespace produced by the last __call__
        self.stack: ExitStack = None  # ExitStack of the run in progress (set by run())
        self.override_flags = {}  # attribute name -> value forced onto parsed flags
        self.setup = False  # once True, later non-forced __call__s return early
        self.cache = None  # FileCacheDB loaded when caching is enabled
        if not is_python2:
            # Long-lived context: every run() stack enters it (so it is exited
            # whenever that stack closes) and exit_jup() closes it explicitly.
            self.global_context = ExitStack()
        self.env = chdrft.config.env.g_env
        self.env.setup(self)

    def setup_jup(self, cmdline='', argv=None, **kwargs):
        """Initialise the app from an interactive (e.g. Jupyter) session.

        Args:
            cmdline: Shell-style argument string, split with ``shlex.split``.
                Only used when ``argv`` is not given.
            argv: Explicit argument list; takes precedence over ``cmdline``.
            **kwargs: Forwarded to ``__call__`` (e.g. ``parser_funcs``).

        Returns:
            An ``Attr`` built from the parsed flags' attribute dict.

        The run is started with ``keep_open_context=1``, so its ExitStack is
        not closed on return; ``exit_jup()`` closes the global context.
        """
        from chdrft.utils.misc import Attr
        # An explicitly passed argv used to be silently overwritten here.
        if argv is None:
            argv = shlex.split(cmdline)
        # NOTE(review): __call__ inspects its caller's frame, which is this
        # method, so the main/args/flags hooks are looked up in this module's
        # globals (not the notebook's), and 'cache' is written here as well.
        self(force=1, argv=argv, **kwargs, keep_open_context=1)
        self.setup = True
        return Attr(vars(self.flags))

    def exit_jup(self):
        """Close the shared global context used by interactive sessions."""
        self.global_context.close()

    @staticmethod
    def _build_parser(caller_globals, parser_funcs, want_cache):
        """Create the argument parser used by one ``__call__``.

        Adds, in order: the built-in flags, the cache flags (if
        ``want_cache``), the caller's ``args(parser)`` hook, each callable in
        ``parser_funcs``, and finally the trailing ``other_args`` catch-all.
        """
        parser = argparse.ArgumentParser(
            formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        parser.add_argument('--verbosity', type=str, default='ERROR')
        parser.add_argument('--pdb', action='store_true')
        parser.add_argument('--log_file', type=str)
        parser.add_argument('--runid', type=str, default='default')
        if want_cache:
            from chdrft.utils.cache import cache_argparse
            cache_argparse(parser)
        if 'args' in caller_globals:
            caller_globals['args'](parser)
        for add_args in parser_funcs:
            add_args(parser)
        parser.add_argument('other_args', nargs=argparse.REMAINDER, default=['--'])
        return parser

    def __call__(self, force=False, argv=None, parser_funcs=(), keep_open_context=0):
        """Parse flags for, and run, the calling module's ``main``.

        The caller is identified via the previous stack frame, so this must be
        invoked directly from the script (``app()``), not through a wrapper.
        Unless ``force`` is set, the call returns early when the app was
        already set up, when the caller's ``__name__`` is not ``'__main__'``,
        or when the caller defines no ``main``.

        Globals read from / written to the caller's module:
            ``args``  -- optional ``args(parser)`` hook adding arguments.
            ``main``  -- the entry point to run (may be absent when forced).
            ``flags`` -- replaced by the parsed namespace, if already defined.
            ``cache`` -- its presence (Python 3) enables the file cache; it is
                         set to the loaded cache whenever caching is enabled.

        Args:
            force: Bypass the early-return guards (caching is also enabled).
            argv: Argument list to parse; ``None`` lets argparse use
                ``sys.argv[1:]``.
            parser_funcs: Extra callables, each invoked with the parser.
            keep_open_context: If true, the per-run ExitStack is left open
                after ``main`` returns.
        """
        # Must be computed in this method: f_back is the frame that called app().
        caller_globals = inspect.currentframe().f_back.f_globals
        if not force and self.setup:
            return
        if not force and caller_globals['__name__'] != '__main__':
            return
        self.setup = True
        if 'main' not in caller_globals and not force:
            return

        want_cache = force or ('cache' in caller_globals and not is_python2)
        parser = self._build_parser(caller_globals, parser_funcs, want_cache)

        # Fixed seeds so every run of a script starts from the same RNG state.
        random.seed(0)
        np.random.seed(0)

        flags = parser.parse_args(args=argv)
        # REMAINDER values can begin with the '--' separator (and the default
        # is ['--']); drop it so other_args only holds forwarded arguments.
        if flags.other_args and flags.other_args[0] == '--':
            flags.other_args = flags.other_args[1:]
        self.flags = flags
        for name, value in self.override_flags.items():
            setattr(self.flags, name, value)

        glog.setLevel(flags.verbosity)
        if flags.log_file:
            glog.logger.addHandler(logging.FileHandler(flags.log_file))

        # Publish results back into the calling module.
        if 'flags' in caller_globals:
            caller_globals['flags'] = flags
        if want_cache:
            from chdrft.utils.cache import FileCacheDB
            self.cache = FileCacheDB.load_from_argparse(flags)
            caller_globals['cache'] = self.cache

        # self.stack is reset to None after every completed run, so this only
        # fires when a previous run raised before reaching that reset.
        if self.stack is not None:
            self.stack.close()

        main_func = caller_globals.get('main')

        def go():
            try:
                if is_python2:
                    main_func()
                elif keep_open_context:
                    # Deliberately never closed here (interactive sessions).
                    self.run(ExitStack(), main_func)
                else:
                    with ExitStack() as stack:
                        self.run(stack, main_func)
            except Exception:
                if flags.pdb:
                    pdb.post_mortem()
                raise

        if flags.pdb:
            pdb.runcall(go)
        else:
            go()
        # NOTE(review): with keep_open_context the run stack (plog file, cache)
        # is dropped here without being closed; exit_jup() only closes
        # global_context.
        self.stack = None

    def run(self, stack, main_func):
        """Run ``main_func`` with per-run resources registered on ``stack``.

        Enters the shared ``global_context``, opens the per-run log file
        ``/tmp/opa_plog_<script basename>_<runid>.log`` (kept as
        ``self.plog_file``), and enters ``self.cache`` when set. All of them are
        released when ``stack`` is closed. ``main_func`` may be ``None``.
        """
        self.stack = stack
        stack.enter_context(self.global_context)
        plog_filename = '/tmp/opa_plog_{}_{}.log'.format(
            os.path.basename(sys.argv[0]), self.flags.runid)
        # A file's __enter__ returns the file itself.
        self.plog_file = stack.enter_context(open(plog_filename, 'w'))
        # NOTE(review): truthiness test; if FileCacheDB defines __len__/__bool__
        # an empty cache would be skipped. `is not None` may be intended.
        if self.cache:
            stack.enter_context(self.cache)
        if main_func is not None:
            main_func()
# Module-level singleton. Constructing it runs App.__init__, which calls
# chdrft.config.env.g_env.setup(app) at import time. Presumably scripts invoke
# app() directly from their own module so __call__ can find their main/args.
app = App()