-
Notifications
You must be signed in to change notification settings - Fork 0
/
clf_model.py
71 lines (49 loc) · 1.87 KB
/
clf_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import pandas as pd
import numpy as np
import pickle
from sklearn.svm import SVC
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import KFold
class ClassifyLabel:
    """Train, cross-validate and apply an SVC classifier.

    The instance keeps the training/validation data and the test features
    handed to the constructor, plus a (not yet fitted) ``SVC`` model.
    Typical use is ``validation()`` -> ``fit()`` -> ``predict()``.
    """

    def __init__(self, input, setting=None):
        """Store the data and build the classifier.

        Args:
            input: Iterable of exactly three items, unpacked as
                ``(X_train_val, y_train_val, X_test)``.
            setting: Optional dict of keyword arguments forwarded to
                ``SVC``. When omitted or empty, ``default_setting`` is used.
                ``random_state`` defaults to 42 unless given here.
        """
        # Save input as instance variables
        self.X_train_val, self.y_train_val, self.X_test = input

        # Filled in by predict()
        self.y_pred = None

        # Default SVC hyper-parameters: non-linear
        self.default_setting = {'kernel': 'rbf', 'gamma': 0.7}

        # Fall back to the defaults when no (or an empty) setting is given
        chosen = setting if setting else self.default_setting

        # Merge instead of passing random_state as a second keyword, so a
        # caller-supplied 'random_state' overrides the default rather than
        # raising "got multiple values for keyword argument".
        params = {'random_state': 42, **chosen}

        # Classification model
        self.clf_model = SVC(**params)

    def validation(self):
        """Evaluate the model with shuffled 5-fold cross-validation.

        Prints the accuracy of every fold and the mean accuracy.

        Returns:
            Tuple ``(cv_results, mean_accuracy)``: the per-fold accuracy
            scores as returned by ``cross_val_score`` and their mean.
        """
        # Fixed seed so the fold assignment is reproducible between runs
        kf = KFold(n_splits=5, shuffle=True, random_state=42)

        # 5-fold cross-validation
        cv_results = cross_val_score(
            self.clf_model,
            self.X_train_val,
            self.y_train_val,
            cv=kf,
            scoring='accuracy',
        )

        # Accuracy of each fold
        for i, accuracy in enumerate(cv_results, start=1):
            print(f'Fold {i}: Accuracy = {accuracy:.4f}')

        # Result of whole cross-validation
        mean_accuracy = cv_results.mean()
        print(f'Mean Accuracy: {mean_accuracy:.4f}')

        return cv_results, mean_accuracy

    # Backward-compatible alias: the method was originally published under
    # this misspelled name, so existing callers keep working.
    valdiation = validation

    def fit(self) -> None:
        """Fit the classifier on the full training/validation data."""
        self.clf_model.fit(self.X_train_val, self.y_train_val)

    def predict(self) -> np.ndarray:
        """Predict labels for ``X_test`` and persist them.

        The predictions are stored in ``self.y_pred`` and pickled to
        ``results/predicted_labels.pickle`` (relative to the current working
        directory). The ``results`` directory is created when missing.

        Returns:
            The labels returned by the model's ``predict`` for ``X_test``.
        """
        import os

        # Predict pattern labels
        y_pred = self.clf_model.predict(self.X_test)
        self.y_pred = y_pred

        # Save as pickle; make sure the target directory exists first so a
        # fresh checkout does not fail with FileNotFoundError.
        out_path = 'results/predicted_labels.pickle'
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        with open(out_path, 'wb') as file:
            pickle.dump(y_pred, file)

        return y_pred