-
Notifications
You must be signed in to change notification settings - Fork 3
/
create_subset.py
73 lines (56 loc) · 1.51 KB
/
create_subset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
"""Build random train / test / validation subsets of the PDF abstracts.

Reads ``data_extraction/abstracts_pdf.tsv`` (tab-separated rows) and
``data_extraction/abstract_pdf_anns.csv`` (comma-separated label rows),
shuffles the data rows and takes the first 3000, the next 1000 and the next
1000 of them as the train, test and validation subsets. Each subset is written
to ``<split>_data.tsv``. The label rows whose first field equals the first
field of one of that subset's rows are written to ``<split>_labels.tsv``. All
six outputs are tab-separated and go to the current working directory.

Run as a script; importing the module has no side effects.
"""
import random
import warnings

DATA_PATH = "data_extraction/abstracts_pdf.tsv"
LABELS_PATH = "data_extraction/abstract_pdf_anns.csv"
# NOTE(review): the inputs' encoding is not recorded here. UTF-8 is assumed
# rather than the platform default (which differs on Windows) - confirm.
ENCODING = "utf-8"


def _read_rows(path, sep, encoding=ENCODING):
    """Return the non-blank lines of *path*, each split on *sep*.

    The line terminator is stripped so that every row can be re-terminated
    uniformly on output. Otherwise a final line lacking a trailing newline
    would be glued onto its successor once the rows are shuffled. Blank
    lines are dropped so they are not treated as data or label rows.
    """
    with open(path, encoding=encoding) as f:
        stripped = (line.rstrip("\n") for line in f)
        return [line.split(sep) for line in stripped if line]


def _write_rows(path, rows, encoding=ENCODING):
    """Write *rows* to *path*, one tab-joined, newline-terminated line each."""
    with open(path, "w", encoding=encoding) as f:
        f.writelines("\t".join(row) + "\n" for row in rows)


def create_subsets(data, labels, train_size=3000, test_size=1000,
                   validation_size=1000, seed=None):
    """Randomly split *data* into train/test/validation and route *labels*.

    Args:
        data: rows (lists of fields); the first field is used as the row id.
        labels: label rows; the first field is matched against the data ids.
        train_size, test_size, validation_size: number of shuffled rows
            taken, in that order, for each subset.
        seed: optional seed for the shuffle. ``None`` seeds from the OS, so
            every call gives a fresh split (the original unseeded behavior).

    Returns:
        A dict mapping ``"train"``, ``"test"`` and ``"validation"`` (in that
        order) to ``(subset_rows, subset_labels)``. Label rows keep their
        input order. *data* itself is not reordered.

    If *data* holds fewer rows than requested, a ``UserWarning`` is issued
    and the later subsets come out short (possibly empty).
    """
    rows = list(data)
    random.Random(seed).shuffle(rows)
    needed = train_size + test_size + validation_size
    if len(rows) < needed:
        warnings.warn(
            f"only {len(rows)} rows available but {needed} were requested; "
            "the later subsets will be short",
            stacklevel=2,
        )
    bounds = {
        "train": (0, train_size),
        "test": (train_size, train_size + test_size),
        "validation": (train_size + test_size, needed),
    }
    splits = {}
    for name, (start, stop) in bounds.items():
        subset = rows[start:stop]
        ids = {row[0] for row in subset}
        # A label goes to every split that contains its id. This only
        # happens with duplicate data ids and mirrors the original script's
        # three independent `if` checks.
        splits[name] = (subset, [label for label in labels if label[0] in ids])
    return splits


def main(seed=None):
    """Read the inputs, split them and write the six output files."""
    data = _read_rows(DATA_PATH, "\t")
    # NOTE(review): label rows are split on every comma, as before; switch to
    # the csv module if any field can contain a quoted comma.
    labels = _read_rows(LABELS_PATH, ",")
    splits = create_subsets(data, labels, seed=seed)
    for name, (rows, split_labels) in splits.items():
        _write_rows(f"{name}_data.tsv", rows)
        _write_rows(f"{name}_labels.tsv", split_labels)


if __name__ == "__main__":
    main()