-
Notifications
You must be signed in to change notification settings - Fork 24
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
b47cd11
commit c73f647
Showing
2 changed files
with
72,900 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,386 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2018-09-28T03:18:43.218666Z", | ||
"start_time": "2018-09-28T03:18:42.501974Z" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"import torch\n", | ||
"import torch.nn as nn\n", | ||
"import torch.nn.functional as F\n", | ||
"import torch.optim as optim\n", | ||
"import torch.autograd as autograd\n", | ||
"import torch.utils.data.dataloader as dataloader\n", | ||
"import random\n", | ||
"import tqdm\n", | ||
"import gensim\n", | ||
"import time\n", | ||
"import math\n", | ||
"import numpy as np" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2018-09-28T03:18:43.236171Z", | ||
"start_time": "2018-09-28T03:18:43.221016Z" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"with open('poetry.txt') as f:\n", | ||
"# corpus = f.readlines()\n", | ||
" corpus = f.read()\n", | ||
"\n", | ||
"# corpus = list(map(lambda x: x.replace('\\n', ''), corpus))\n", | ||
"corpus = corpus.replace('\\n', ' ').replace('\\r', ' ')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2018-09-28T03:18:43.511261Z", | ||
"start_time": "2018-09-28T03:18:43.237889Z" | ||
}, | ||
"scrolled": true | ||
}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"5387" | ||
] | ||
}, | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"idx_to_char = set()\n", | ||
"idx_to_char.update(*corpus)\n", | ||
"idx_to_char = list(idx_to_char)\n", | ||
"char_to_idx = dict([(char, i) for i, char in enumerate(idx_to_char)])\n", | ||
"vocab_size = len(char_to_idx)\n", | ||
"vocab_size" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2018-09-28T03:18:44.771741Z", | ||
"start_time": "2018-09-28T03:18:43.513184Z" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"corpus_indice = list(map(lambda x: [char_to_idx[c] for c in x], corpus))\n", | ||
"corpus_indice.sort(key=len, reverse=True)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2018-09-28T03:18:44.778084Z", | ||
"start_time": "2018-09-28T03:18:44.774161Z" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"def data_iter_consecutive(corpus_indices, batch_size, num_steps):\n", | ||
" corpus_indices = torch.tensor(corpus_indices)\n", | ||
" data_len = len(corpus_indices)\n", | ||
" batch_len = data_len // batch_size\n", | ||
" indices = corpus_indices[0: batch_size*batch_len].reshape((\n", | ||
" batch_size, batch_len))\n", | ||
" epoch_size = (batch_len - 1) // num_steps\n", | ||
" for i in range(epoch_size):\n", | ||
" i = i * num_steps\n", | ||
" X = indices[:, i: i + num_steps]\n", | ||
" Y = indices[:, i + 1: i + num_steps + 1]\n", | ||
" yield X, Y" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2018-09-28T03:18:44.922322Z", | ||
"start_time": "2018-09-28T03:18:44.850504Z" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# weight = torch.nn.Embedding(vocab_size, 400).weight.data\n", | ||
"weight = torch.diag(torch.ones(vocab_size))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 8, | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2018-09-28T03:18:44.958982Z", | ||
"start_time": "2018-09-28T03:18:44.925456Z" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"class lyricNet(nn.Module):\n", | ||
" def __init__(self, hidden_dim, embed_dim, num_layers, weight,\n", | ||
" num_labels, bidirectional, dropout=0.5, **kwargs):\n", | ||
" super(lyricNet, self).__init__(**kwargs)\n", | ||
" self.hidden_dim = hidden_dim\n", | ||
" self.embed_dim = embed_dim\n", | ||
" self.num_layers = num_layers\n", | ||
" self.num_labels = num_labels\n", | ||
" self.bidirectional = bidirectional\n", | ||
" if num_layers <= 1:\n", | ||
" self.dropout = 0\n", | ||
" else:\n", | ||
" self.dropout = dropout\n", | ||
" self.embedding = nn.Embedding.from_pretrained(weight)\n", | ||
" self.embedding.weight.requires_grad = False\n", | ||
"# self.embedding = nn.Embedding(num_labels, self.embed_dim)\n", | ||
" self.rnn = nn.GRU(input_size=self.embed_dim, hidden_size=self.hidden_dim,\n", | ||
" num_layers=self.num_layers, bidirectional=self.bidirectional,\n", | ||
" dropout=self.dropout)\n", | ||
" if self.bidirectional:\n", | ||
" self.decoder = nn.Linear(hidden_dim * 2, self.num_labels)\n", | ||
" else:\n", | ||
" self.decoder = nn.Linear(hidden_dim, self.num_labels)\n", | ||
" \n", | ||
" def forward(self, inputs, hidden=None):\n", | ||
" embeddings = self.embedding(inputs)\n", | ||
" states, hidden = self.rnn(embeddings.permute([1, 0, 2]), hidden)\n", | ||
" outputs = self.decoder(states.reshape((-1, states.shape[-1])))\n", | ||
" return(outputs, hidden)\n", | ||
" \n", | ||
" def init_hidden(self, num_layers, batch_size, hidden_dim, **kwargs):\n", | ||
" hidden = torch.zeros(num_layers, batch_size, hidden_dim)\n", | ||
" return hidden" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 9, | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2018-09-28T03:18:45.027897Z", | ||
"start_time": "2018-09-28T03:18:44.961491Z" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"def predict_rnn(prefix, num_chars, model, device, idx_to_char, char_to_idx):\n", | ||
" output = [char_to_idx[prefix[0]]]\n", | ||
" hidden = torch.zeros(num_layers, 1, hidden_dim)\n", | ||
" if use_gpu:\n", | ||
" hidden = hidden.to(device)\n", | ||
" for t in range(num_chars + len(prefix) - 1):\n", | ||
" X = torch.tensor([output[-1]]).reshape((1, 1))\n", | ||
" if use_gpu:\n", | ||
" X = X.to(device)\n", | ||
" pred, hidden = model(X, hidden)\n", | ||
" if t < len(prefix) - 1:\n", | ||
" output.append(char_to_idx[prefix[t + 1]])\n", | ||
" elif pred.argmax(dim=1) == vocab_size:\n", | ||
" break\n", | ||
" else:\n", | ||
" output.append(int(pred.argmax(dim=1)))\n", | ||
" return(''.join([idx_to_char[i] for i in output]))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 10, | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2018-09-28T03:18:45.110903Z", | ||
"start_time": "2018-09-28T03:18:45.030487Z" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"embedding_dim = 400\n", | ||
"hidden_dim = 256\n", | ||
"lr = 1e2\n", | ||
"momentum = 0.0\n", | ||
"num_epoch = 100\n", | ||
"use_gpu = True\n", | ||
"num_layers = 1\n", | ||
"bidirectional = False\n", | ||
"batch_size = 32\n", | ||
"device = torch.device('cuda:0')\n", | ||
"loss_function = nn.CrossEntropyLoss()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 11, | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2018-09-28T03:18:47.799222Z", | ||
"start_time": "2018-09-28T03:18:45.113296Z" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"model = lyricNet(hidden_dim=hidden_dim, embed_dim=vocab_size, num_layers=num_layers,\n", | ||
" num_labels=vocab_size, weight=weight, bidirectional=bidirectional)\n", | ||
"optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)\n", | ||
"if use_gpu:\n", | ||
"# model = nn.DataParallel(model)\n", | ||
" model.to(device)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 12, | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2018-09-28T03:18:47.819097Z", | ||
"start_time": "2018-09-28T03:18:47.801577Z" | ||
} | ||
}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"'春燋瀰洎表燋恃恃瀰洎燋恃'" | ||
] | ||
}, | ||
"execution_count": 12, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"predict_rnn('春', 11, model, device, idx_to_char, char_to_idx)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 13, | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2018-09-28T03:32:45.552517Z", | ||
"start_time": "2018-09-28T03:18:47.820970Z" | ||
}, | ||
"scrolled": true | ||
}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"epoch 1/100, loss 5.7401, norm 0.1910, time 7.659s, since 0h 0m 7s\n", | ||
"春,风人人人。 不日不不远,相人有不人。 不日不不远,相人有不人。 不日不不远,相人有不人。 不\n", | ||
"夏,风人人人。 不日不不远,相人有不人。 不日不不远,相人有不人。 不日不不远,相人有不人。 不\n", | ||
"epoch 11/100, loss 4.1388, norm 0.3108, time 8.270s, since 0h 1m 31s\n", | ||
"春风吹早晚,人语不相见。 何事不相见,空山有所思。 不知何处去,不是有时情。 不是愁人意,相逢不\n", | ||
"夏云中路远,何处是愁人。 不见南山色,相逢白发生。 不知今日日,不是故人情。 不见东风水,相逢不\n", | ||
"epoch 21/100, loss 3.7344, norm 0.4362, time 8.435s, since 0h 2m 55s\n", | ||
"春风吹早晚,独有清风起。 不见君王孙,年年不堪把。 春风生,有人意不如何人。 一点玉壶中,一枝枝\n", | ||
"夏云无处所。 一径入幽林,千峰挂石床。 白云生白发,青鸟入窗枝。 清景不可得,清风何处归。 今日\n", | ||
"epoch 31/100, loss 3.5179, norm 0.4746, time 8.451s, since 0h 4m 18s\n", | ||
"春风吹满衣。 一点白头翁,数茎红叶红。 夜久不相见,空歌声满长。 病起见庭花,今年独自迟。 不知\n", | ||
"夏云生故国。 日暮天下寒,风波上林晚。 烟波生处尽,风景澹凄凄。 此地堪相望,知君道不平。 万里\n", | ||
"epoch 41/100, loss 3.3755, norm 0.5808, time 8.465s, since 0h 5m 42s\n", | ||
"春风雨夜长。 不知楼上月,时得似君王。 白发今如此,青云自可亲。 不知名利利,何事更伤心。 自古\n", | ||
"夏云无处所。 一为一瓢酒,还似江上月。 何人知此来,独向何处去。 何事东归来,空山有秋色。 心知\n", | ||
"epoch 51/100, loss 3.2706, norm 0.6426, time 8.457s, since 0h 7m 6s\n", | ||
"春风雨夜行。 旧业今朝夕,清江上苑阳。 寒山对孤屿,秋水正秋风。 何必沧浪水,相看楚客还。 一身\n", | ||
"夏云无处所。 何人知此去,见古今古今。 白日照我心,白云自相逢。 不知何所有,唯有旧时心。 一片\n", | ||
"epoch 61/100, loss 3.1858, norm 0.6814, time 8.304s, since 0h 8m 30s\n", | ||
"春风雨夜多时歇,林下独行迟。 上苑春风满,郎衣早晚来。 江流无限迹,云外有阳春。 自有归心去,频\n", | ||
"夏云满路无。 粉壁空悬影,红蕉满长枝。 粉壁开青镜,罗裙织白头。 白头梳未织,香发自相逢。 不是\n", | ||
"epoch 71/100, loss 3.1145, norm 0.7103, time 8.374s, since 0h 9m 54s\n", | ||
"春风雨夜行人,西风吹玉树秋。 我来头上去,莫使日光催。 何事来相问,同游不自愁。 征人望不同,落\n", | ||
"夏云无限情。 白头梳白头,白头梳头梳。 自怜无定日,何必有风尘。 不是长安隐,同年不得归。 春风\n", | ||
"epoch 81/100, loss 3.0533, norm 0.6437, time 8.329s, since 0h 11m 18s\n", | ||
"春风雨夜行。 一枝如有酒,一夜不闻蝉。 莫怪频相识,空馀道者情。 地僻春草长,路转不堪思。 只是\n", | ||
"夏云无处所。 何人知此路,不是旧山色。 岁晏若可见。 何事不相见,见君白发新。 君看白发口,已是\n", | ||
"epoch 91/100, loss 3.0009, norm 0.6567, time 8.325s, since 0h 12m 42s\n", | ||
"春风雨夜多。 独有高楼月,偏宜落照尘。 不知何处士,本是自相逢。 泉石生苔藓,禅僧挂竹间。 此时\n", | ||
"夏云生石棱,何人更相见。 月出照关山,风光满城阙。 幽人闭门巷,静语对幽谷。 采药去幽人,空来此\n", | ||
"epoch 100/100, loss 2.9597, norm 0.8618, time 8.424s, since 0h 13m 57s\n", | ||
"春风雨夜清。 二毛松上月,一日此中天。 白发今还此,清秋日暮春。 寒光照不极,鸟影自无闲。 何处\n", | ||
"夏云无限情,江月正相和。 柳色临流水,花阴暗暮天。 何人知此景,几处谢公卿。 白发生涯尽,青山一\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"since = time.time()\n", | ||
"for epoch in range(num_epoch):\n", | ||
" start = time.time()\n", | ||
" num, total_loss = 0, 0\n", | ||
"# if epoch == 5000:\n", | ||
"# optimizer.param_groups[0]['lr'] = lr * 0.1\n", | ||
" data = data_iter_consecutive(corpus_indice, batch_size, 35)\n", | ||
" hidden = model.init_hidden(num_layers, batch_size, hidden_dim)\n", | ||
" for X, Y in data:\n", | ||
" num += 1\n", | ||
" hidden.detach_()\n", | ||
" if use_gpu:\n", | ||
" X = X.to(device)\n", | ||
" Y = Y.to(device)\n", | ||
" hidden = hidden.to(device)\n", | ||
" optimizer.zero_grad()\n", | ||
" output, hidden = model(X, hidden)\n", | ||
" l = loss_function(output, Y.t().reshape((-1,)))\n", | ||
" l.backward()\n", | ||
" norm = nn.utils.clip_grad_norm_(model.parameters(), 1e-2)\n", | ||
" optimizer.step()\n", | ||
" total_loss += l.item()\n", | ||
" end = time.time()\n", | ||
" s = end - since\n", | ||
" h = math.floor(s / 3600)\n", | ||
" m = s - h * 3600\n", | ||
" m = math.floor(m / 60)\n", | ||
" s -= m * 60\n", | ||
" if (epoch % 10 == 0) or (epoch == (num_epoch - 1)):\n", | ||
" print('epoch %d/%d, loss %.4f, norm %.4f, time %.3fs, since %dh %dm %ds'\n", | ||
" %(epoch+1, num_epoch, total_loss / num, norm, end-start, h, m, s))\n", | ||
" print(predict_rnn('春', 47, model, device, idx_to_char, char_to_idx))\n", | ||
" print(predict_rnn('夏', 47, model, device, idx_to_char, char_to_idx))" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.6.6" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Oops, something went wrong.