-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
105 lines (92 loc) · 3.57 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# General imports
import os
import sys
import cv2
import time
import glob
import json
import yaml
import scipy
import random
import warnings
warnings.filterwarnings('ignore')
import datetime
import argparse
import numpy as np
import pandas as pd
import seaborn as sns
from tqdm import tqdm
import matplotlib.pyplot as plt
from collections import defaultdict
from src.data import prepare_dataset
from src.models import run_kfold
from xgboost import XGBRegressor
from catboost import CatBoostRegressor
from lightgbm import LGBMRegressor
from src.utils import *
from loguru import logger
# Run identifier with minute resolution, computed once when the module loads.
# The main section uses it to name the output directory and the submission file.
TIMESTAMP = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M")

# Arguments
# NOTE: parsing happens at module load, so `args` is available to the main
# section below.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--config',
    type=str,
    default='configs/config.yml',
    dest='config',
    help='Path to the config file',
)
args = parser.parse_args()
if __name__ == '__main__':
    # Script entry point. Steps, in order:
    #   1. load the YAML config given by --config,
    #   2. create a timestamped output directory and attach a log file to it,
    #   3. build train/test feature matrices via prepare_dataset,
    #   4. train the configured gradient-boosting regressor via run_kfold,
    #   5. log train / out-of-fold metrics, write the submission CSV and a
    #      feature-importance bar chart.

    # Config
    with open(args.config, "r") as f:
        config = AttrDict(yaml.safe_load(f))

    # Each run writes into its own sub-directory named after the start time.
    config.OUTPUT_PATH = os.path.join(config.OUTPUT_PATH, TIMESTAMP)
    # exist_ok=True replaces the former exists()-then-makedirs() pair, which
    # could fail if the directory appeared between the check and the creation.
    os.makedirs(config.OUTPUT_PATH, exist_ok=True)

    logger.add(os.path.join(config.OUTPUT_PATH, 'logs.log'))
    logger.info(f"Config:{str(config)}")

    train_df, test_df = prepare_dataset(
        config.DATA_PRODUCTS,
        config.TRAIN_METAFILE, config.TEST_METAFILE, config.GRID_METAFILE
    )

    # Optional diagnostic, currently disabled: feature-correlation heatmap.
    # dataplot = sns.heatmap(
    #     train_df[['maiac_AOD_Uncertainty_mean', 'maiac_AOD_Uncertainty_var',
    #     'maiac_Column_WV_mean', 'maiac_Column_WV_var', 'maiac_AOD_QA_mean',
    #     'maiac_AOD_QA_var', 'maiac_AOD_MODEL_mean', 'maiac_AOD_MODEL_var',
    #     'misr_Aerosol_Optical_Depth_mean', 'row_nan_count',
    #     'mean_value', 'elevation_mean', 'elevation_var', 'month', 'day',
    #     'label']].corr(), cmap="YlGnBu"
    # )
    # plt.savefig(os.path.join(config.OUTPUT_PATH, 'feature_correlation.png'))
    # plt.show()

    # Split the target off the training frame; the 'label' column is dropped
    # from both frames so that only feature columns remain.
    train_labels = train_df['label'].to_numpy()
    train_df = train_df.drop(['label'], axis=1)
    test_df = test_df.drop(['label'], axis=1)

    train_features = train_df.to_numpy()
    test_features = test_df.to_numpy()

    logger.info(f"Using following features: {list(train_df.columns)}")
    logger.info(f"Found {len(train_features)} training instances")

    # Select the estimator class (not an instance) and its parameter set;
    # both are handed to run_kfold below.
    if config.MODEL == 'xgboost':
        model = XGBRegressor
        model_params = config.XGB_PARAMS
    elif config.MODEL == 'catboost':
        model = CatBoostRegressor
        model_params = config.CATB_PARAMS
    elif config.MODEL == 'lightgbm':
        model = LGBMRegressor
        model_params = config.LGBM_PARAMS
    else:
        raise ValueError(f"Model {config.MODEL} not supported")

    logger.info(f"Training {config.MODEL} model")
    # NOTE(review): "oof" presumably means out-of-fold predictions/labels
    # collected across the folds -- confirm against src.models.run_kfold.
    train_preds, test_preds, oof_preds, oof_labels, feat_importances = run_kfold(
        train_features, train_labels, test_features, config.N_FOLDS,
        model, model_params, config.OUTPUT_PATH, name=config.MODEL
    )

    # compute_metrics returns a mapping of metric name -> value(s); each entry
    # is averaged with np.mean before being logged.
    metrics = compute_metrics(train_preds, train_labels)
    for k, v in metrics.items():
        logger.info(f"Average train_{k}: {np.mean(v)}")

    metrics = compute_metrics(np.array(oof_preds), np.array(oof_labels))
    for k, v in metrics.items():
        logger.info(f"Average eval_{k}: {np.mean(v)}")

    # The submission reuses the test metafile and stores the test predictions
    # in its 'value' column.
    submission = pd.read_csv(config.TEST_METAFILE)
    submission['value'] = test_preds
    submission.to_csv(os.path.join(config.OUTPUT_PATH, f'submission_{TIMESTAMP}.csv'), index=False)

    # Horizontal bar chart with one bar per feature column.
    plt.barh(train_df.columns, feat_importances)
    plt.yticks(fontsize='xx-small')
    plt.savefig(os.path.join(config.OUTPUT_PATH, 'feature_importance.png'))
    plt.show()