Skip to content

Commit

Permalink
Apply fix to issue awentzonline#39
Browse files Browse the repository at this point in the history
  • Loading branch information
rharriso committed Oct 23, 2018
1 parent 0666261 commit 40d0193
Showing 1 changed file with 19 additions and 1 deletion.
20 changes: 19 additions & 1 deletion image_analogy/vgg16.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
AveragePooling2D, Convolution2D, MaxPooling2D, ZeroPadding2D)
from keras.models import Sequential

from keras import __version__ as keras_version
from distutils.version import StrictVersion

def img_from_vgg(x):
'''Decondition an image from the VGG16 model.'''
Expand Down Expand Up @@ -83,10 +85,26 @@ def get_model(img_width, img_height, weights_path='vgg16_weights.h5', pool_mode=
break
g = f['layer_{}'.format(k)]
weights = [g['param_{}'.format(p)] for p in range(g.attrs['nb_params'])]

# Check if your version of keras is version '2.0.0' or above.
if StrictVersion(keras_version) >= StrictVersion('2.0.0'):
# If your version of keras is version '2.0.0' or above,
# then
# 1. convert each element x of the list 'weights'
# into a numpy array,
# 2. transpose each of those arrays, and
# 3. save the list of transposed arrays back to
# the 'weights' list.
weights_T = [np.array(x).T for x in weights]
weights = weights_T

# else, leave the 'weights' list unchanged
# leave the elements of 'weights' untransposed.

layer = model.layers[k]
if isinstance(layer, Convolution2D):
weights[0] = np.array(weights[0])[:, :, ::-1, ::-1]
layer.set_weights(weights)

f.close()
return model
return model

0 comments on commit 40d0193

Please sign in to comment.