Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
SamaelChen committed Mar 4, 2017
1 parent b6540a2 commit abac400
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions mxnet/ex02/handwritten_digit_recogonition.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def to4d(img):
shape = {'data': (batch_size, 1, 28, 28)}
mx.viz.plot_network(symbol=mlp, shape=shape)
logging.getLogger().setLevel(logging.DEBUG)
model = mx.mod.Module(symbol=mlp, context=mx.cpu(), data_names=[
model = mx.mod.Module(symbol=mlp, context=mx.gpu(), data_names=[
'data'], label_names=['softmax_label'])
model.fit(train_data=train_iter, eval_data=val_iter, optimizer='sgd',
optimizer_params={'learning_rate': 0.1}, eval_metric='acc', num_epoch=10)
Expand Down Expand Up @@ -93,7 +93,7 @@ def to4d(img):
mx.viz.plot_network(symbol=lenet, shape=shape)
logging.getLogger().setLevel(logging.DEBUG)
conv_mod = mx.mod.Module(symbol=lenet, data_names=['data'], label_names=[
'softmax_label'], context=mx.cpu())
'softmax_label'], context=mx.gpu())
conv_mod.fit(train_data=train_iter, eval_data=val_iter, optimizer='sgd',
optimizer_params={'learning_rate': 0.1}, eval_metric='acc', num_epoch=10)
prob = conv_mod.predict(val_iter).asnumpy()[0]

0 comments on commit abac400

Please sign in to comment.