-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
150 lines (120 loc) · 5.21 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
# -*- coding: utf-8 -*-
import numpy as np
import tensorflow as tf
from scipy import signal
class EpochBuffer:
    """Accumulate streamed raw samples, down-sampled samples and predicted labels.

    Attributes:
        buffer: filtered raw samples appended by ``set`` / ``set_raw_data``.
        labels: prediction results appended by ``set`` / ``set_label``.
        tmp_data: the most recent chunk passed to ``set``.
        buffer_2: down-sampled samples appended by ``set_data``.
        sample_rate: last sample rate passed to ``set`` (initially 100).
        fixed_list: working window of ``fixed_length + 10`` values that
            ``get_data`` fills one segment at a time.
        blank: ten zeros written right after each segment by ``get_data``.
    """

    def __init__(self, fixed_length=400):
        self.buffer = []  # filtered raw data
        self.labels = []  # prediction results
        self.tmp_data = []  # original note: 3000 (presumably the expected chunk length)
        self.buffer_2 = []  # down-sampled data
        self.sample_rate = 100
        self.fixed_list = [0] * (fixed_length + 10)
        self.blank = [0] * 10

    def set(self, data, label, sample_rate):
        """Store a chunk of raw data, its labels and its sample rate.

        The first chunk is appended to ``buffer`` in full; for every later
        chunk only its last 100 samples are appended (presumably because
        consecutive chunks overlap -- TODO confirm with the caller).
        ``data`` and ``label`` are concatenated with ``+``, so they must be lists.
        """
        if not self.buffer:
            self.buffer = self.buffer + data
        else:
            self.buffer = self.buffer + data[-100:]
        self.labels = self.labels + label
        self.tmp_data = data
        self.sample_rate = sample_rate
        return None

    def set_raw_data(self, data):
        """Append raw samples (a list) to ``buffer``."""
        self.buffer = self.buffer + data
        return None

    def set_data(self, data):
        """Append down-sampled samples (a list) to ``buffer_2``."""
        self.buffer_2 = self.buffer_2 + data
        return None

    def set_label(self, label):
        """Append prediction results (a list) to ``labels``."""
        self.labels = self.labels + label

    def get_raw_data(self, index):
        """Return the 7500-sample raw window for step ``index``.

        Windows advance by 250 samples per step and skip the first 2500
        samples (original note: drop the first 10 s to avoid start-up
        interference).
        """
        start = index * 250 + 2500
        return self.buffer[start:start + 7500]

    def get_raw_data_state(self, index):
        """Return True when ``buffer`` already holds the full window of ``get_raw_data(index)``."""
        return len(self.buffer) >= index * 250 + 10000

    def get_data(self, fixed_length, moved_length, f_m_ratio, i):
        """Copy the segment for step ``i`` from ``buffer_2`` into ``fixed_list``.

        With ``seg = i % f_m_ratio``, this step reads ``moved_length`` samples of
        ``buffer_2`` starting at ``(i // f_m_ratio) * fixed_length + seg * moved_length``.
        It writes them, followed by ten zeros, at offset ``seg * moved_length``
        of ``fixed_list``. If ``buffer_2`` holds fewer samples than requested,
        the segment is zero-padded. Without the padding, the slice assignment
        would shrink ``fixed_list`` and misalign every later segment.

        Returns:
            (fixed_list, sample_rate). ``fixed_list`` is the internal list, not a copy.
        """
        f_m_ratio = int(f_m_ratio)
        seg = i % f_m_ratio
        start = i // f_m_ratio * fixed_length + moved_length * seg
        window = self.buffer_2[start:start + moved_length]
        # Keep the written segment exactly moved_length long (see docstring).
        window = window + [0] * (moved_length - len(window))
        offset = moved_length * seg
        self.fixed_list[offset:offset + moved_length + 10] = window + self.blank
        return self.fixed_list, self.sample_rate

    def get_data_state(self, index, moved_length):
        """Return True when ``buffer_2`` holds more than ``int(index * moved_length)`` samples.

        This does not guarantee that a full ``moved_length`` segment is available.
        """
        return len(self.buffer_2) > int(index * moved_length)

    def get_label(self, index):
        """Return the stored label at ``index``."""
        return self.labels[index]

    def get_label_state(self, index):
        """Return True when a label exists at ``index``."""
        return len(self.labels) > index

    def print_information(self):
        """Print raw-buffer length, number of labels and the current sample rate."""
        print(len(self.buffer), len(self.labels), self.sample_rate)

    def get_filter_rhythm(self, l_f=4, h_f=8, ):
        """Zero-phase band-pass ``fixed_list`` between ``l_f`` and ``h_f`` Hz.

        The band edges are normalised by a fixed Nyquist of 50 Hz (original
        note: 1/2 * sample_rate = 50). This assumes 100 Hz data regardless of
        ``self.sample_rate``.

        Returns:
            list of filtered values, same length as ``fixed_list``.
        """
        # An 8th-order band-pass is a 16th-order IIR. Its (b, a) polynomial
        # form is numerically fragile, so design and apply it as
        # second-order sections.
        sos = signal.butter(8, [l_f / 50.0, h_f / 50.0], btype='bandpass',
                            analog=False, output='sos')
        filter_data = signal.sosfiltfilt(sos, np.array(self.fixed_list))
        return filter_data.tolist()
# Module-level EpochBuffer instance with a 400-sample working window.
epoch_buffer = EpochBuffer(fixed_length=400)
def batch_data(x, batch_size):
    """Yield shuffled mini-batches of exactly ``batch_size`` rows from ``x``.

    The rows of ``x`` are permuted once with ``np.random.permutation``. They
    are then cut into consecutive batches; a trailing remainder smaller than
    ``batch_size`` is dropped. Only ``x`` is batched (there is no label array).

    Args:
        x: array-like, indexed along its first axis (converted with ``np.asarray``
            so plain lists work too).
        batch_size: positive number of rows per batch.

    Yields:
        numpy arrays holding ``batch_size`` rows each.

    Raises:
        ValueError: if ``batch_size`` is not positive. Such a value would
            otherwise loop forever.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive, got {}".format(batch_size))
    x = np.asarray(x)
    print("the length of x and batch_size: ", len(x), batch_size)
    x = x[np.random.permutation(len(x))]
    for start in range(0, len(x) - batch_size + 1, batch_size):
        yield x[start:start + batch_size]
def flatten(name, input_var):
    """Reshape ``input_var`` into a 2-D tensor of shape ``(-1, D)``.

    ``D`` is the product of every static dimension after the first, so the
    leading (batch) axis is kept and all remaining axes are merged into one.
    Every dimension after the first must be known; a ``None`` makes the
    product fail. The reshape op is created under ``name``.
    """
    trailing_dims = input_var.get_shape()[1:].as_list()
    flat_size = 1
    for size in trailing_dims:
        flat_size = flat_size * size
    # -1 lets tf.reshape infer the leading dimension from the element count.
    return tf.reshape(input_var, shape=[-1, flat_size], name=name)
def sample_arr(sample, channel=1):
    """Extract one channel from a sequence of multi-channel samples.

    Args:
        sample: sequence of rows; each row is indexable by channel position.
        channel: 1-based channel number (``channel=1`` selects ``row[0]``).

    Returns:
        list containing ``row[channel - 1]`` for every row, in order.
    """
    column = channel - 1
    return [row[column] for row in sample]
def filter(data):
    """Remove baseline drift and mains hum, then band-pass the signal.

    Steps (all zero-phase):
      1. linear detrend to remove baseline drift;
      2. IIR notch at normalised frequency 0.4 with Q = 30;
      3. 8th-order Butterworth band-pass over normalised band [0.004, 0.4].

    Frequencies are fractions of the Nyquist rate. The original note labels
    the band 0.5-50 Hz, which implies 250 Hz sampling. At that rate the notch
    sits at 50 Hz (assumption -- TODO confirm the sampling rate with callers).

    Args:
        data: 1-D sequence of samples. It must be longer than scipy's default
            filtfilt padding (a few dozen samples), or scipy raises ValueError.

    Returns:
        list of filtered samples, same length as ``data``.
    """
    detrended = signal.detrend(np.asarray(data))  # baseline drift
    notch_b, notch_a = signal.iirnotch(0.4, 30.0)
    notched = signal.filtfilt(notch_b, notch_a, detrended)
    # An 8th-order band-pass is a 16th-order IIR. With an edge as low as
    # 0.004, its (b, a) polynomial form is ill-conditioned, so use
    # second-order sections instead.
    sos = signal.butter(8, [0.004, 0.4], btype='bandpass', analog=False, output='sos')
    return signal.sosfiltfilt(sos, notched).tolist()
def down_sample(data_list):
    """Down-sample one 7500-sample (30 s) epoch to 3000 samples.

    Kept for reference: the original docstring marks this helper as abandoned.

    Every group of five consecutive samples ``x0..x4`` becomes two outputs::

        y0 = (x0 + x1) / 3 + x2 / 6
        y1 = x2 / 6 + (x3 + x4) / 3

    NOTE(review): each output's weights sum to 5/6, so a constant input is
    scaled by 5/6. Confirm whether that gain is intended.

    Inputs shorter than 7500 samples are padded by repeating the last sample.
    Samples beyond the first 7500 are ignored. Outputs are rounded to three
    decimal places.

    Args:
        data_list: 1-D list or numpy array of samples.

    Returns:
        numpy array of shape (1, 3000).

    Raises:
        ValueError: if ``data_list`` is empty (there is no sample to pad with).
    """
    data_arr = np.asarray(data_list, dtype=float)
    if data_arr.size == 0:
        raise ValueError("down_sample needs at least one sample")
    data_arr = data_arr[:7500]
    if data_arr.shape[0] < 7500:
        data_arr = np.pad(data_arr, (0, 7500 - data_arr.shape[0]), mode='edge')
    matrix = np.array([[1.0 / 3, 0], [1.0 / 3, 0], [1.0 / 6, 1.0 / 6], [0, 1.0 / 3], [0, 1.0 / 3]])
    # (1500, 5) @ (5, 2) -> (1500, 2); row-major reshape interleaves y0, y1 per group.
    pairs = data_arr.reshape(1500, 5).dot(matrix)
    return np.around(pairs, 3).reshape(1, 3000)
def print_n_samples_each_class(labels, classes):
    """Print how many samples carry each label present in ``labels``.

    Args:
        labels: array-like of class indices. Plain Python lists are accepted as
            well as numpy arrays.
        classes: sequence of class names; ``classes[k]`` names label ``k``.

    Prints one ``"<name>: <count>"`` line per distinct label, in ascending
    label order. A label outside ``range(len(classes))`` raises KeyError.
    """
    class_dict = dict(zip(range(len(classes)), classes))
    # np.asarray makes list input work. return_counts tallies every label in
    # one pass instead of one full comparison per class.
    unique_labels, counts = np.unique(np.asarray(labels), return_counts=True)
    for c, n_samples in zip(unique_labels, counts):
        print("{}: {}".format(class_dict[c], int(n_samples)))