Skip to content

Commit

Permalink
Fix for issue awentzonline#39. Transform weights from shape (64, 3, 3…
Browse files Browse the repository at this point in the history
…, 3) to (3, 3, 3, 64) for keras versions 2.0.0 and above.
  • Loading branch information
zkneupper committed Oct 17, 2017
1 parent 0666261 commit a28f01e
Showing 1 changed file with 18 additions and 0 deletions.
18 changes: 18 additions & 0 deletions image_analogy/vgg16.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
AveragePooling2D, Convolution2D, MaxPooling2D, ZeroPadding2D)
from keras.models import Sequential

from keras import __version__ as keras_version
from distutils.version import StrictVersion

def img_from_vgg(x):
'''Decondition an image from the VGG16 model.'''
Expand Down Expand Up @@ -83,6 +85,22 @@ def get_model(img_width, img_height, weights_path='vgg16_weights.h5', pool_mode=
break
g = f['layer_{}'.format(k)]
weights = [g['param_{}'.format(p)] for p in range(g.attrs['nb_params'])]

# Check if your version of keras version '2.0.0' or above.
if StrictVersion(keras_version) >= StrictVersion('2.0.0'):
# If your version of keras version '2.0.0' or above,
# then
# 1. convert each element x of the list 'weights'
# into a numpy array,
# 2. transpose each of those arrays, and
# 3. save the list of transposed arrays back to
# the 'weights' list.
weights_T = [np.array(x).T for x in weights]
weights = weights_T

# else, leave the 'weights' list unchanged
# leave the elements of 'weights' untransposed.

layer = model.layers[k]
if isinstance(layer, Convolution2D):
weights[0] = np.array(weights[0])[:, :, ::-1, ::-1]
Expand Down

0 comments on commit a28f01e

Please sign in to comment.