-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
69 lines (54 loc) · 1.96 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import distutils
import json
from types import SimpleNamespace as Namespace
def getattr_recursive(obj, s):
    """Follow a '/'-separated attribute path from *obj* and return the final value.

    *s* is either a string such as ``"a/b/c"`` or a list of attribute names
    that has already been split. Missing attributes raise AttributeError.
    """
    names = s if isinstance(s, list) else s.split('/')
    target = obj
    # Walk every intermediate attribute; the last name is read separately.
    for name in names[:-1]:
        target = getattr(target, name)
    return getattr(target, names[-1])
def setattr_recursive(obj, s, val):
    """Assign *val* to the attribute at the '/'-separated path *s* below *obj*.

    *s* is either a string such as ``"a/b/c"`` or a list of attribute names
    that has already been split. Every component except the last must already
    exist; the last one is created or overwritten. Returns None.
    """
    names = s if isinstance(s, list) else s.split('/')
    target = obj
    # Descend to the object that owns the final attribute.
    for name in names[:-1]:
        target = getattr(target, name)
    setattr(target, names[-1], val)
def generate_config(params, file_path):
    """Write *params* to *file_path* as 4-space-indented JSON.

    The attributes of *params* (its ``__dict__``) form the top-level JSON
    object. Nested values that ``json`` cannot encode natively are serialized
    through their own ``__dict__``.

    Raises:
        AttributeError: if a nested value is not JSON-encodable and has no
            ``__dict__``. The target file is left untouched in that case.
        OSError: if the file cannot be opened or written.
    """
    print("Saving Configs")
    # Serialize before opening the file. Opening with "w" truncates it, so a
    # serialization failure would otherwise destroy the existing config.
    json_data = json.dumps(params.__dict__, default=lambda o: o.__dict__, indent=4)
    with open(file_path, "w", encoding="utf-8") as f:
        f.write(json_data)
def read_config(config_path):
    """Load the JSON file at *config_path* and return it as nested Namespaces.

    Every JSON object, including nested ones, becomes a ``Namespace``
    (``types.SimpleNamespace``), so values are read as attributes, e.g.
    ``cfg.section.key``. The file is decoded as UTF-8, the encoding JSON
    mandates, instead of the platform default.
    """
    print('Parse Params file here from ', config_path, ' and pass into main')
    # The with-block makes sure the handle is closed even if parsing fails.
    with open(config_path, "r", encoding="utf-8") as f:
        return json.load(f, object_hook=lambda d: Namespace(**d))
# Truth-value spellings previously accepted via distutils.util.strtobool.
_OVERRIDE_TRUE_STRINGS = frozenset(('y', 'yes', 't', 'true', 'on', '1'))
_OVERRIDE_FALSE_STRINGS = frozenset(('n', 'no', 'f', 'false', 'off', '0'))


def _parse_bool_override(text):
    """Parse *text* as a boolean the way ``distutils.util.strtobool`` did.

    This replaces the distutils call. ``import distutils`` alone does not load
    the ``distutils.util`` submodule, and distutils was removed in Python 3.12.

    Raises:
        ValueError: if *text* is not a recognised truth value.
    """
    lowered = text.lower()
    if lowered in _OVERRIDE_TRUE_STRINGS:
        return True
    if lowered in _OVERRIDE_FALSE_STRINGS:
        return False
    raise ValueError("invalid truth value {!r}".format(text))


def override_params(params, overrides):
    """Apply path/value string overrides to *params* in place and return it.

    *overrides* is a flat sequence that alternates attribute path and new value,
    e.g. ``["train/lr", "0.01", "use_cuda", "false"]``. For each pair:

    * the current value is read with ``getattr_recursive``;
    * the string is converted to the current value's type. Booleans are parsed
      case-insensitively from y/yes/t/true/on/1 or n/no/f/false/off/0;
    * the result is stored with ``setattr_recursive`` and the change is printed.

    Raises:
        ValueError: if *overrides* has an odd number of items, if a boolean
            override is not a recognised truth value, or if the type
            conversion itself raises ValueError (e.g. ``int("abc")``).
    """
    if len(overrides) % 2 != 0:
        # Explicit check instead of ``assert`` so it still runs under ``python -O``.
        raise ValueError(
            "overrides must alternate path and value; got {} items".format(len(overrides)))
    for path, raw_value in zip(overrides[0::2], overrides[1::2]):
        oldval = getattr_recursive(params, path)
        if isinstance(oldval, bool):
            # bool("false") is True, so booleans need real parsing, not a cast.
            to_val = _parse_bool_override(raw_value)
        else:
            # Coerce to the existing value's type, e.g. int("3"), float("0.1").
            to_val = type(oldval)(raw_value)
        setattr_recursive(params, path, to_val)
        print("Overriding param", path, "from", oldval, "to", to_val)
    return params
def get_bool_user(message, default: bool, *, input_fn=input):
    """Ask the user a yes/no question and return the answer as a bool.

    The prompt ends with ``[Y/n]`` or ``[y/N]`` depending on *default*. The
    reply is matched case-insensitively, ignoring surrounding whitespace,
    against y/yes/t/true/on/1 (True) and n/no/f/false/off/0 (False). These are
    the spellings ``distutils.util.strtobool`` accepted. Any other reply,
    including an empty one, returns *default*.

    Args:
        message: Question text shown to the user.
        default: Value returned when the reply is empty or unrecognised.
        input_fn: Callable that shows the prompt and returns the reply.
            Defaults to the built-in ``input``, so the function can also be
            driven without an interactive console.
    """
    default_string = '[Y/n]' if default else '[y/N]'
    resp = input_fn('{} {}\n'.format(message, default_string))
    # Parse inline instead of via distutils.util.strtobool. ``import distutils``
    # does not load the ``util`` submodule, and distutils is gone in Python 3.12.
    answer = resp.strip().lower()
    if answer in ('y', 'yes', 't', 'true', 'on', '1'):
        return True
    if answer in ('n', 'no', 'f', 'false', 'off', '0'):
        return False
    return default
if __name__ == "__main__":
    # Manual smoke check: load the bundled sample config and print the namespace.
    print(read_config(config_path="config/manhattan32_cpp.json"))