-
Notifications
You must be signed in to change notification settings - Fork 83
/
run_rnn.py
35 lines (24 loc) · 996 Bytes
/
run_rnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import numpy as np
from deepnet.nnet import RNN
from deepnet.solver import sgd_rnn
def text_to_inputs(path, *, encoding="utf-8"):
    """
    Convert a text file into character-level input/target index vectors.

    Every distinct character in the file is assigned an integer index. The
    character '.' is always included, because it is the target of the final
    character. Indices follow sorted character order, so the mapping is
    reproducible across runs: iterating a ``set`` of strings depends on hash
    randomization.

    Parameters
    ----------
    path : str or os.PathLike
        Path of the text file to read.
    encoding : str, optional
        Encoding used to decode the file (default ``"utf-8"``).

    Returns
    -------
    X : numpy.ndarray
        1-D integer array; ``X[i]`` is the vocabulary index of the i-th
        character of the text.
    y : numpy.ndarray
        1-D integer array, same length as ``X``; ``y[i]`` is the index of the
        character that follows ``X[i]``. The last entry is the index of '.'.
    vocab_size : int
        Number of entries in the vocabulary.
    char_to_idx : dict
        Maps each vocabulary character to its index.
    idx_to_char : dict
        Inverse of ``char_to_idx``.

    Raises
    ------
    ValueError
        If the file contains no text.
    """
    with open(path, encoding=encoding) as f:
        txt = f.read()
    if not txt:
        raise ValueError(f"no text to build inputs from in {path!r}")

    # '.' terminates the target sequence, so it must be in the vocabulary even
    # when the text itself contains no period (previously a KeyError).
    vocab = sorted(set(txt) | {'.'})
    char_to_idx = {char: i for i, char in enumerate(vocab)}
    # Invert the forward map so the two dictionaries are guaranteed to agree.
    idx_to_char = {i: char for char, i in char_to_idx.items()}

    X = np.array([char_to_idx[char] for char in txt])
    # Targets are the inputs shifted left by one position, ending with '.'.
    y = np.empty_like(X)
    y[:-1] = X[1:]
    y[-1] = char_to_idx['.']

    vocab_size = len(char_to_idx)
    return X, y, vocab_size, char_to_idx, idx_to_char
if __name__ == "__main__":
    # Turn the training text into index vectors plus the vocabulary maps.
    inputs, targets, n_vocab, char2idx, idx2char = text_to_inputs('data/Rnn.txt')

    # The vocabulary size is passed as both of the first two RNN arguments
    # (presumably input and output dimensions -- confirm in deepnet.nnet).
    model = RNN(n_vocab, n_vocab, char2idx, idx2char)

    # NOTE(review): 10, 10 and 0.1 are positional hyperparameters whose meaning
    # is defined by deepnet.solver.sgd_rnn; confirm there before tuning them.
    model = sgd_rnn(model, inputs, targets, 10, 10, 0.1)