process_data.py
import pickle
import platform
from collections import Counter

import numpy
from keras.preprocessing.sequence import pad_sequences


def load_data():
    train = _parse_data(open('data/train_data.data', 'rb'))
    test = _parse_data(open('data/test_data.data', 'rb'))

    # Build the vocabulary from words that occur at least twice in training.
    word_counts = Counter(row[0].lower() for sample in train for row in sample)
    vocab = [w for w, f in word_counts.items() if f >= 2]
    chunk_tags = ['O', 'B-PER', 'I-PER', 'B-LOC', 'I-LOC', 'B-ORG', 'I-ORG']

    # Save the initial config data so later runs can rebuild the same mapping.
    with open('model/config.pkl', 'wb') as outp:
        pickle.dump((vocab, chunk_tags), outp)

    train = _process_data(train, vocab, chunk_tags)
    test = _process_data(test, vocab, chunk_tags)
    return train, test, (vocab, chunk_tags)
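
# A minimal sketch of how inference code is expected to restore the saved
# config (assuming the same 'model/config.pkl' path used above):
#
#   with open('model/config.pkl', 'rb') as inp:
#       vocab, chunk_tags = pickle.load(inp)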
def _parse_data(fh):
    # Samples are separated by a blank line. On Windows the line separator
    # in the data files is '\r\n' (so samples are split on '\r\n\r\n');
    # elsewhere it is '\n' (and samples are split on '\n\n').
    if platform.system() == 'Windows':
        split_text = '\r\n'
    else:
        split_text = '\n'

    string = fh.read().decode('utf-8')
    data = [[row.split() for row in sample.split(split_text)]
            for sample in string.strip().split(split_text + split_text)]
    fh.close()
    return data
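
# The parser expects one "character TAG" pair per line and a blank line
# between sentences. An assumed, illustrative snippet of a data file:
#
#   中 B-ORG
#   国 I-ORG
#   很 O
#   大 O
#
#   句 O
#
# which _parse_data turns into
#   [[['中', 'B-ORG'], ['国', 'I-ORG'], ['很', 'O'], ['大', 'O']], [['句', 'O']]]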
def _process_data(data, vocab, chunk_tags, maxlen=None, onehot=False):
    if maxlen is None:
        maxlen = max(len(s) for s in data)
    word2idx = dict((w, i) for i, w in enumerate(vocab))

    # Map each word to its vocab index; words not in the vocab fall back to
    # index 1, which this project treats as <unk>.
    x = [[word2idx.get(w[0].lower(), 1) for w in s] for s in data]
    y_chunk = [[chunk_tags.index(w[1]) for w in s] for s in data]

    x = pad_sequences(x, maxlen)  # left (pre) padding with 0
    y_chunk = pad_sequences(y_chunk, maxlen, value=-1)

    if onehot:
        # Note: padded positions hold -1, so numpy's negative indexing
        # one-hots them as the last tag in chunk_tags.
        y_chunk = numpy.eye(len(chunk_tags), dtype='float32')[y_chunk]
    else:
        y_chunk = numpy.expand_dims(y_chunk, 2)
    return x, y_chunk
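
# A minimal sketch of the resulting shapes (assuming N samples and the
# seven chunk_tags above):
#   x:       (N, maxlen)        int32, zero-padded on the left
#   y_chunk: (N, maxlen, 1)     tag indices with -1 padding, or
#            (N, maxlen, 7)     float32 one-hot rows when onehot=True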
def process_data(data, vocab, maxlen=100):
    # Prediction-time variant: `data` is a single character sequence.
    word2idx = dict((w, i) for i, w in enumerate(vocab))
    x = [word2idx.get(w[0].lower(), 1) for w in data]  # <unk> -> index 1
    length = len(x)
    x = pad_sequences([x], maxlen)  # left padding
    return x, length
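

if __name__ == '__main__':
    # A minimal usage sketch, assuming the data files referenced in
    # load_data() exist and the 'model' directory is writable.
    (train_x, train_y), (test_x, test_y), (vocab, chunk_tags) = load_data()
    print('train:', train_x.shape, train_y.shape)
    print('test: ', test_x.shape, test_y.shape)

    # Prepare a single sentence for prediction with the saved vocab;
    # the example sentence is illustrative.
    sent_x, sent_len = process_data('中国很大', vocab)
    print('sentence:', sent_x.shape, 'original length:', sent_len)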