-
Notifications
You must be signed in to change notification settings - Fork 1
/
regression.py
35 lines (27 loc) · 1.05 KB
/
regression.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import os
import sklearn
import logging
import csv
import argparse
import pandas as pd
import numpy as np
from scipy.io import loadmat
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import RepeatedStratifiedKFold
from sklearn.feature_selection import VarianceThreshold
from sklearn.feature_selection import RFECV
from sklearn import svm
from sklearn.calibration import CalibratedClassifierCV
import paths
import utils
from classify_using_tractographic_feature import get_weighted_connectivity_feature_vectors_train
# --- Logging -----------------------------------------------------------------
# Every INFO-or-higher record is written both to `log_regression.txt` in the
# current working directory (via basicConfig's file handler) and to a console
# StreamHandler attached to the root logger. Both outputs use the same
# "timestamp message" layout.
fmt = '%(asctime)s %(message)s'
log = os.path.join(os.getcwd(), 'log_regression.txt')
logging.basicConfig(filename=log, format=fmt, level=logging.INFO)

console = logging.StreamHandler()  # no stream given -> stdlib default (stderr)
console.setFormatter(logging.Formatter(fmt))
console.setLevel(logging.INFO)
logging.getLogger().addHandler(console)  # root logger, same as getLogger('')

# --- Training data -----------------------------------------------------------
# The project helper returns an 8-tuple that is unpacked into module-level names
# for the rest of the script. Meanings below are inferred from names only:
# presumably patient identifiers, ground-truth targets, and six weighted
# connectivity feature sets (dsi / nrm / bin variants for "pass" and "end").
# TODO(review): confirm against classify_using_tractographic_feature.
logging.info('loading training set...')
(
    pat_names_train,
    gt,
    W_dsi_pass, W_nrm_pass, W_bin_pass,
    W_dsi_end, W_nrm_end, W_bin_end,
) = get_weighted_connectivity_feature_vectors_train(mode='gt', region='roi')