You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
# Train for 6 epochs; each epoch runs n_batch mini-batch updates of train_step.
# NOTE(review): tf, mnist, n_batch, batch_size, train_step, x and y are defined
# outside this excerpt.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(6):
        # The batch index is unused; the separate `n` counter the original kept
        # only mirrored `epoch`, so it has been dropped.
        for _ in range(n_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys})
        print("epoch:", epoch)
# Input layer applied before the RNN: flatten to 2-D so a single matmul
# transforms the features of every time step at once.
# NOTE(review): X, input_size, w_in, b_in, time_step, rnn_unit, cell and
# init_state are defined outside this excerpt; assumes X's last dim is
# input_size and w_in is [input_size, rnn_unit] — confirm.
# NOTE(review): `input` shadows the Python builtin of the same name.
input=tf.reshape(X,[-1,input_size])
# Affine projection of each step's features with the hand-written w_in / b_in.
input_rnn=tf.matmul(input,w_in)+b_in
# Restore the time axis: [batch, time_step, rnn_unit] is what dynamic_rnn gets.
input_rnn=tf.reshape(input_rnn,[-1,time_step,rnn_unit])
# Unroll `cell` over time starting from init_state; dtype is only strictly
# required by dynamic_rnn when no initial_state is supplied.
output_rnn,final_states=tf.nn.dynamic_rnn(cell, input_rnn,initial_state=init_state, dtype=tf.float32)
综合网上教程,我觉得 dynamic_rnn 中的 input_rnn 维度应该是 [-1,time_step,input_size],tensorflow 中是封装好的(参考:https://www.cnblogs.com/zyly/p/9029591.html),但作者您自己写了 input 的 w 和 b,将 input_rnn 的维度改成了 [-1,time_step,rnn_unit],感觉有点奇怪。
我自己写的代码如下:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
print("start")
# tf.version is a module; the version string is tf.__version__.
print(tf.__version__)
print(mnist)

n_inputs = 28    # input_size: features per time step (one 28-pixel image row)
max_time = 28    # time_step: number of time steps (the 28 image rows)
lstm_size = 100  # num_units of the LSTM cell
n_classes = 10
batch_size = 50
n_batch = mnist.train.num_examples // batch_size

x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])

# Output layer: maps the LSTM's final hidden state to class scores.
w = tf.Variable(tf.truncated_normal([lstm_size, n_classes], stddev=0.1))
b = tf.Variable(tf.constant(0.1, shape=[n_classes]))


def RNN(X, w, b, return_logits=False):
    """Read each image row by row with an LSTM and classify its last state.

    Args:
        X: float tensor of shape [batch, 784] (flattened 28x28 images).
        w: [lstm_size, n_classes] output weight matrix.
        b: [n_classes] output bias.
        return_logits: if True, return the unnormalised class scores instead
            of softmax probabilities. Use this for the loss, because
            tf.nn.softmax_cross_entropy_with_logits applies softmax itself.

    Returns:
        [batch, n_classes] softmax probabilities (default) or logits.
    """
    # dynamic_rnn accepts inputs of any depth (here n_inputs); the LSTM cell's
    # own weights map it to lstm_size units, so no separate input layer is
    # needed.
    inputs = tf.reshape(X, [-1, max_time, n_inputs])
    lstm_cell = tf.contrib.rnn.BasicLSTMCell(lstm_size)
    outputs, final_state = tf.nn.dynamic_rnn(lstm_cell, inputs, dtype=tf.float32)
    # final_state is an LSTMStateTuple(c, h): c = cell state, h = hidden state.
    logits = tf.matmul(final_state.h, w) + b
    if return_logits:
        return logits
    return tf.nn.softmax(logits)


logits = RNN(x, w, b, return_logits=True)
prediction = tf.nn.softmax(logits)
# The loss must see raw logits: feeding the softmax output here (as before)
# applies softmax twice, squashing the scores and weakening the gradients.
cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y))
# 1e-3 is tf.train.AdamOptimizer's default initial learning rate; this script
# deliberately uses a smaller 1e-4.
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(6):
        for _ in range(n_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys})
        # Report held-out accuracy so training progress is actually visible.
        acc = sess.run(accuracy,
                       feed_dict={x: mnist.test.images, y: mnist.test.labels})
        print("epoch:", epoch, "test accuracy:", acc)
print("over")
The text was updated successfully, but these errors were encountered: