forked from THUNLP-AIPoet/StylisticPoetry
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Yun.py
152 lines (136 loc) · 5 KB
/
Yun.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
# -*- coding: utf-8 -*-
import numpy as np
import pickle
class Yun:
    """Lookup of rhyme ("yun") ids for poem lines.

    The ids appear to be Pingshui rhyme categories, given the
    ``pingshui_amb.pkl`` source; treat that as an assumption. Tables, all
    filled in ``__init__``:

    * ``yun_dic``     -- char -> list of rhyme ids read from ``YunList.txt``.
    * ``word_map``    -- char -> list of rhyme ids (first object in the pkl).
    * ``mulword_map`` -- 2-/3-char line ending -> list of rhyme ids
      (second object in the pkl; extended by :meth:`updateyun`).
    * ``poemyun``     -- whole line -> list of rhyme ids (third object in the
      pkl; rebuilt by :meth:`totalpoemlist`).

    Rhyme ids are compared as strings here (e.g. ``'0'``, ``"-1"``) and
    converted with ``int`` in :meth:`getBatchYun`.
    """

    # Largest rhyme id; getBatchYun emits NUM_YUN + 1 columns, with column 0
    # used for unknown ('0') and negative ids.
    NUM_YUN = 30

    def __init__(self):
        """Load ``yun_dic`` from YunList.txt and the three pickled tables."""
        print("loading YunList.txt")
        self.yun_dic = {}
        with open("data/poemcorpus/YunList.txt", 'r', encoding='utf-8') as f:
            for raw in f:
                fields = raw.strip().split(' ')
                # Blank or malformed lines have no second field; skip them
                # instead of raising IndexError.
                if len(fields) < 2:
                    continue
                char, yun_id = fields[0], fields[1]
                ids = self.yun_dic.setdefault(char, [])
                if yun_id not in ids:
                    ids.append(yun_id)
        # please download the pingshui_amb.pkl file from
        # https://github.com/THUNLP-AIPoet/Datasets/tree/master/CRRD
        # and then move it into the data/other/ dir
        # NOTE: unpickling can execute arbitrary code -- only load a
        # pingshui_amb.pkl obtained from the trusted source above.
        with open("data/other/pingshui_amb.pkl", "rb") as fyun:
            # encoding='utf8' decodes 8-bit strings pickled by Python 2.
            self.word_map = pickle.load(fyun, encoding='utf8')
            self.mulword_map = pickle.load(fyun, encoding='utf8')
            self.poemyun = pickle.load(fyun, encoding='utf8')

    def getBatchYun(self, batchSen, ivocab, PAD_ID):
        """Return a one-hot rhyme matrix of shape ``(batch, NUM_YUN + 1)``.

        Args:
            batchSen: sequence of ``numLen`` time steps, indexed as
                ``batchSen[j][i]`` (step ``j``, batch item ``i``).
                ``batchSen[0].shape[0]`` is taken as the batch size.
            ivocab: maps a token id to its character (indexable by id).
            PAD_ID: ids ``>= PAD_ID`` are skipped as special tokens.

        The first and last steps (the <go> / </s> positions) are ignored.
        Each row gets a single 1.0 at the int value of the first rhyme id
        from :meth:`getYun`. Negative ids and empty results go to column 0.
        """
        numLen = len(batchSen)
        numBatch = batchSen[0].shape[0]
        batchYun = np.zeros((numBatch, self.NUM_YUN + 1), dtype='float32')
        for i in range(numBatch):
            chars = [ivocab[batchSen[j][i]]
                     for j in range(1, numLen - 1)
                     if batchSen[j][i] < PAD_ID]
            yun = self.getYun(''.join(chars))
            yun_id = int(yun[0]) if yun else 0
            batchYun[i, max(yun_id, 0)] = 1.0
        return batchYun

    def getYun(self, sen):
        """Return the list of candidate rhyme ids for line ``sen``.

        Lookup order:
        1. The whole line in ``poemyun``.
        2. If the last char is in ``word_map``: first its 2-char ending, then
           its 3-char ending in ``mulword_map``, else ``word_map[last char]``.
        3. Otherwise ``yun_dic[last char]``.
        4. Otherwise ``['0']``. This also covers the empty string.
        Endings longer than ``sen`` are skipped rather than raising IndexError.
        The returned list is the stored table entry itself, not a copy.
        """
        if sen in self.poemyun:
            return self.poemyun[sen]
        if not sen:
            return ['0']
        last_word = sen[-1]
        if last_word in self.word_map:
            for n in (2, 3):
                if len(sen) >= n and sen[-n:] in self.mulword_map:
                    return self.mulword_map[sen[-n:]]
            return self.word_map[last_word]
        if last_word in self.yun_dic:
            return self.yun_dic[last_word]
        return ['0']

    def updateyun(self, sen, yun):
        """Record rhyme ``yun[0]`` for the 2- and 3-char endings of ``sen``.

        Each ending present in ``sen`` gets ``yun[0]`` appended to its
        ``mulword_map`` list, unless it is already there. Empty ``sen`` or
        ``yun`` is a no-op.
        """
        if not sen or not yun:
            return
        yun_id = yun[0]
        # ``yun_map`` is not created anywhere in this class. Only bump its
        # counts if a caller has attached one, instead of raising AttributeError.
        yun_map = getattr(self, 'yun_map', None)
        if yun_map is not None and sen[-1] in yun_map:
            yun_map[sen[-1]][yun_id] += 1
        for n in (2, 3):
            if len(sen) >= n:
                ids = self.mulword_map.setdefault(sen[-n:], [])
                if yun_id not in ids:
                    ids.append(yun_id)

    def _assign_poem_yun(self, sen_list, poemyun):
        """Resolve one poem's rhyme and store each line's rhyme ids in poemyun.

        The candidates of lines at indices 1, 3, 5, ... are intersected. If
        more than one survives, the result is narrowed by line 0's candidates.
        When exactly one id remains and it is not "-1", it is assigned to
        those lines (and to line 0 if line 0 shares it). Each such line is
        fed to :meth:`updateyun`.
        """
        L = len(sen_list)
        yun_list = [self.getYun(sen) for sen in sen_list]
        if L >= 2:
            tmp = list(yun_list[1])
            for i in range(1, L, 2):
                tmp = [val for val in tmp if val in yun_list[i]]
        else:
            tmp = []
        if len(tmp) > 1:
            tmp = [val for val in tmp if val in yun_list[0]]
        # Original author's counts: []:50366 ["-1"]: 7104 25967
        if len(tmp) == 1 and tmp[0] != "-1":
            for i in range(1, L, 2):
                yun_list[i] = tmp
                self.updateyun(sen_list[i], tmp)
            tmp = [val for val in tmp if val in yun_list[0]]
            if tmp:
                yun_list[0] = tmp
                self.updateyun(sen_list[0], tmp)
        for sen, yun in zip(sen_list, yun_list):
            poemyun[sen] = yun

    def totalpoemlist(self):
        """Rebuild ``self.poemyun`` from the full poem corpus file.

        A line whose first space-separated field is "Title" ends the current
        poem. Every other non-blank line contributes its first field as a poem
        line. The final poem is flushed at end of file as well.
        """
        # load the corpus to make sure the model won't generate existing sentences
        with open("data/poemcorpus/totaljiantipoems_change.txt", 'r',
                  encoding='utf-8') as f:
            lines = f.readlines()
        poemyun = {}
        sen_list = []
        for line in lines:
            first = line.strip().split(" ")[0]
            if first == "Title":
                self._assign_poem_yun(sen_list, poemyun)
                sen_list = []
            elif first:
                sen_list.append(first)
        # The last poem is not followed by a "Title" line; process it too.
        self._assign_poem_yun(sen_list, poemyun)
        self.poemyun = poemyun
if __name__ == "__main__":
    # Running the module directly only builds the rhyme tables once, which
    # doubles as a check that the data files are where __init__ expects them.
    yun = Yun()