-
Notifications
You must be signed in to change notification settings - Fork 2
/
main.py
62 lines (43 loc) · 1.41 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
import hydra
from omegaconf import OmegaConf
from source.helper.EvalHelper import EvalHelper
from source.helper.FitHelper import FitHelper
from source.helper.PredictHelper import PredictHelper
from source.helper.SeparabilityHelper import SeparabilityHelper
from source.helper.TSNEHelper import TSNEHelper
def fit(params):
    """Run the "fit" task.

    Constructs a ``FitHelper`` from ``params`` and invokes its
    ``perform_fit`` method; nothing is returned.

    Args:
        params: configuration object handed unchanged to ``FitHelper``.
            ``perform_tasks`` passes its already-resolved config here.
    """
    FitHelper(params).perform_fit()
def predict(params):
    """Run the "predict" task.

    Constructs a ``PredictHelper`` from ``params`` and invokes its
    ``perform_predict`` method; nothing is returned.

    Args:
        params: configuration object handed unchanged to ``PredictHelper``.
            ``perform_tasks`` passes its already-resolved config here.
    """
    PredictHelper(params).perform_predict()
def eval(params):
    """Run the "eval" task.

    Constructs an ``EvalHelper`` from ``params`` and invokes its
    ``perform_eval`` method; nothing is returned.

    NOTE: this module-level name shadows Python's builtin ``eval()`` for the
    rest of this module. It is kept as-is so existing callers keep working.

    Args:
        params: configuration object handed unchanged to ``EvalHelper``.
            ``perform_tasks`` passes its already-resolved config here.
    """
    EvalHelper(params).perform_eval()
def z_shot_cls(params):
    """Placeholder for the "z-shot-cls" task.

    Args:
        params: accepted for signature parity with the other task runners;
            currently unused.

    Raises:
        NotImplementedError: always, since the task has no implementation.
    """
    message = "Not yet implemented."
    raise NotImplementedError(message)
def tsne(params):
    """Run the "tsne" task.

    Constructs a ``TSNEHelper`` from ``params`` and invokes its
    ``perform_tsne`` method; nothing is returned.

    Args:
        params: configuration object handed unchanged to ``TSNEHelper``.
            ``perform_tasks`` passes its already-resolved config here.
    """
    TSNEHelper(params).perform_tsne()
def separability(params):
    """Run the "separability" task.

    Constructs a ``SeparabilityHelper`` from ``params`` and invokes its
    ``perform_eval`` method; nothing is returned.

    NOTE(review): the helper entry point is ``perform_eval`` rather than a
    separability-specific name; confirm against ``SeparabilityHelper``.

    Args:
        params: configuration object handed unchanged to
            ``SeparabilityHelper``. ``perform_tasks`` passes its
            already-resolved config here.
    """
    SeparabilityHelper(params).perform_eval()
@hydra.main(config_path="settings/", config_name="settings.yaml")
def perform_tasks(params):
    """Hydra entry point: run every task named in ``params.tasks``.

    Supported task names are "fit", "predict", "eval", "z-shot-cls", "tsne"
    and "separability". Matching tasks always run in that fixed order,
    independent of their order in ``params.tasks``. Names that match no
    supported task are reported with a warning instead of being silently
    skipped.

    Args:
        params: config composed by ``@hydra.main`` from
            ``settings/settings.yaml``; it must expose a ``tasks`` attribute
            supporting ``in`` membership tests (presumably a list of task
            names — confirm against the settings file).

    Raises:
        NotImplementedError: if "z-shot-cls" is requested, since
            ``z_shot_cls`` is not implemented yet.
    """
    import logging

    # The working directory may have been changed by Hydra; go back to the
    # directory the program was launched from so relative paths resolve there.
    os.chdir(hydra.utils.get_original_cwd())
    # Resolve interpolations in the config once, before any helper reads it.
    OmegaConf.resolve(params)

    # (task name, runner) pairs in execution order.
    runners = (
        ("fit", fit),
        ("predict", predict),
        ("eval", eval),
        ("z-shot-cls", z_shot_cls),
        ("tsne", tsne),
        ("separability", separability),
    )
    known = tuple(name for name, _ in runners)

    requested = params.tasks
    # A misspelled name (e.g. "z_shot_cls" instead of "z-shot-cls") would
    # otherwise do nothing at all. Only inspect element-wise when tasks is a
    # collection; iterating a plain string would yield single characters.
    if not isinstance(requested, str):
        unknown = [task for task in requested if task not in known]
        if unknown:
            logging.getLogger(__name__).warning(
                "Ignoring unknown task(s): %s (supported: %s)",
                unknown,
                ", ".join(known),
            )

    for name, runner in runners:
        if name in requested:
            runner(params)
if __name__ == "__main__":
    # No argument is passed: the @hydra.main decorator on perform_tasks
    # supplies the config.
    perform_tasks()