Skip to content

Commit

Permalink
相似图片搜索
Browse files Browse the repository at this point in the history
  • Loading branch information
zhai_pro committed Jan 13, 2019
1 parent a34fb8b commit 0343c5f
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 12 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ docs/_build/
target/

*.jpg
*.png
*.npy
*.txt
*.npz
*.pkl
*.h5
24 changes: 24 additions & 0 deletions category_images.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import numpy as np

import mlearn
from pretreatment import load_data


def learn():
    """Build a label histogram for every distinct image and save the result.

    Loads the text crops and image data through ``load_data()``, classifies
    the text crops with ``mlearn.predict`` and, for each distinct 64-bit
    image value, counts how often every predicted class index occurs on the
    rows that contain that value.

    Writes ``images.npz`` with two entries: ``images`` (the sorted unique
    64-bit values) and ``labels`` (one histogram per unique value, each at
    least 80 bins long).
    """
    texts, imgs = load_data()
    # Keep only the index of the highest-scoring class for each row.
    labels = mlearn.predict(texts).argmax(axis=1)
    # Reinterpret the raw bytes as uint64 values laid out 8 per row.
    # view()/reshape() replace in-place assignment to .dtype/.shape, which
    # NumPy discourages.
    # NOTE(review): assumes every image is exactly 8 bytes wide and that each
    # text row comes with 8 images -- confirm against get_imgs().
    hashes = imgs.view(np.uint64).reshape(-1, 8)
    unique_imgs = np.unique(hashes)
    print(unique_imgs.shape)
    imgs_labels = []
    for img in unique_imgs:
        # Row index of every cell equal to this value; a row that holds the
        # value more than once is counted once per occurrence.
        rows = np.where(hashes == img)[0]
        imgs_labels.append(np.bincount(labels[rows], minlength=80))
    np.savez('images.npz', images=unique_imgs, labels=imgs_labels)


# Script entry point: run learn() only when this file is executed directly.
if __name__ == '__main__':
learn()
11 changes: 8 additions & 3 deletions mlearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,19 @@ def main():
model.save('model.h5')


# predict: load the saved Keras model from 'model.h5' and return its output
# for a batch of text images.
# NOTE(review): this span reads like a rendered diff with the +/- markers
# stripped -- both a zero-argument and a one-argument signature appear.
# Presumably `predict(texts)` is the current form (a later function calls
# predict(texts)) -- confirm against the repository.
def predict():
def predict(texts):
# keras is imported here rather than at the top of the module.
from keras import models
model = models.load_model('model.h5')
# NOTE(review): presumably a superseded line -- with the one-argument
# signature this would overwrite the caller's `texts`; confirm it is gone.
texts = np.load('data.npy')
# Division yields a new array, so the shape assignment below does not touch
# the caller's array. Assumes 8-bit pixel values (0-255) -- TODO confirm.
texts = texts / 255.0
_, h, w = texts.shape
# Append a trailing axis of length 1 to the 3-D batch.
texts.shape = (-1, h, w, 1)
labels = model.predict(texts)
return labels


def _predict():
    """Run ``predict`` on the contents of ``data.npy`` and store the output.

    Reads the array saved in ``data.npy``, hands it to ``predict`` and
    writes whatever comes back to ``labels.npy``.
    """
    predictions = predict(np.load('data.npy'))
    np.save('labels.npy', predictions)


Expand All @@ -75,5 +80,5 @@ def show():

# Script entry point: run main(), then the prediction step, then show().
# NOTE(review): both `predict()` and `_predict()` are listed; this looks like
# a diff rendering where `predict()` is the superseded call -- confirm.
if __name__ == '__main__':
main()
predict()
_predict()
show()
22 changes: 13 additions & 9 deletions pretreatment.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,22 +60,26 @@ def get_imgs(img):
# pretreat: read every file under PATH as a grayscale image and collect what
# get_text() / get_imgs() return for each one.
# NOTE(review): this span reads like a rendered diff with the +/- markers
# stripped -- superseded and replacement lines sit side by side (two list
# initialisations, two `return` statements). Presumably the `texts, imgs`
# variants are the current ones -- confirm against the repository.
def pretreat():
# Fetch the images first when the PATH directory does not exist yet.
if not os.path.isdir(PATH):
download_images()
imgs = []
texts, imgs = [], []
for img in os.listdir(PATH):
# `img` is rebound in turn: file name -> full path -> decoded image.
img = os.path.join(PATH, img)
img = cv2.imread(img, cv2.IMREAD_GRAYSCALE)
imgs.append(get_text(img))
return imgs
texts.append(get_text(img))
imgs.append(get_imgs(img))
return texts, imgs


# load_data: return the preprocessed data stored at `path`, building and
# saving it with pretreat() first when the file does not exist.
# NOTE(review): this span reads like a rendered diff with the +/- markers
# stripped -- a single-array .npy variant (np.save / np.load) and a two-array
# .npz variant (np.savez with 'texts' and 'images') are interleaved.
# Presumably the .npz variant is the current one -- confirm.
def load_data(path='data.npy'):
def load_data(path='data.npz'):
if not os.path.isfile(path):
imgs = pretreat()
np.save(path, imgs)
return np.load(path)
texts, imgs = pretreat()
np.savez(path, texts=texts, images=imgs)
# NOTE(review): the object returned by np.load is never closed here;
# consider a `with` block -- confirm.
f = np.load(path)
return f['texts'], f['images']


# Script entry point: load the data and print some shape information.
# NOTE(review): interleaved diff lines again -- `imgs = load_data()` and the
# cv2.imwrite call presumably belong to the superseded single-array version;
# confirm against the repository.
if __name__ == '__main__':
imgs = load_data()
texts, imgs = load_data()
print(texts.shape)
print(imgs.shape)
cv2.imwrite('temp.jpg', imgs[0])
# Group values 8 per row, then print the shape of the array of distinct rows.
imgs = imgs.reshape(-1, 8)
print(np.unique(imgs, axis=0).shape)

0 comments on commit 0343c5f

Please sign in to comment.