From 0343c5fb4cb5c6387581bd1dea764f2a413bbd16 Mon Sep 17 00:00:00 2001 From: zhai_pro Date: Sun, 13 Jan 2019 12:11:58 +0800 Subject: [PATCH] =?UTF-8?q?=E7=9B=B8=E4=BC=BC=E5=9B=BE=E7=89=87=E6=90=9C?= =?UTF-8?q?=E7=B4=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 2 ++ category_images.py | 24 ++++++++++++++++++++++++ mlearn.py | 11 ++++++++--- pretreatment.py | 22 +++++++++++++--------- 4 files changed, 47 insertions(+), 12 deletions(-) create mode 100644 category_images.py diff --git a/.gitignore b/.gitignore index d275fe7..7a7bc5d 100644 --- a/.gitignore +++ b/.gitignore @@ -54,7 +54,9 @@ docs/_build/ target/ *.jpg +*.png *.npy *.txt *.npz *.pkl +*.h5 diff --git a/category_images.py b/category_images.py new file mode 100644 index 0000000..9688d25 --- /dev/null +++ b/category_images.py @@ -0,0 +1,24 @@ +import numpy as np + +import mlearn +from pretreatment import load_data + + +def learn(): + texts, imgs = load_data() + labels = mlearn.predict(texts) + labels = labels.argmax(axis=1) + imgs.dtype = np.uint64 + imgs.shape = (-1, 8) + unique_imgs = np.unique(imgs) + print(unique_imgs.shape) + imgs_labels = [] + for img in unique_imgs: + idxs = np.where(imgs == img)[0] + counts = np.bincount(labels[idxs], minlength=80) + imgs_labels.append(counts) + np.savez('images.npz', images=unique_imgs, labels=imgs_labels) + + +if __name__ == '__main__': + learn() diff --git a/mlearn.py b/mlearn.py index 0af1aca..67f5850 100644 --- a/mlearn.py +++ b/mlearn.py @@ -51,14 +51,19 @@ def main(): model.save('model.h5') -def predict(): +def predict(texts): from keras import models model = models.load_model('model.h5') - texts = np.load('data.npy') texts = texts / 255.0 _, h, w = texts.shape texts.shape = (-1, h, w, 1) labels = model.predict(texts) + return labels + + +def _predict(): + texts = np.load('data.npy') + labels = predict(texts) np.save('labels.npy', labels) @@ -75,5 +80,5 @@ def show(): if __name__ == '__main__': main() - predict() + _predict() show() diff --git a/pretreatment.py b/pretreatment.py index 0da6943..408db1f 100644 --- a/pretreatment.py +++ b/pretreatment.py @@ -60,22 +60,26 @@ def get_imgs(img): def pretreat(): if not os.path.isdir(PATH): download_images() - imgs = [] + texts, imgs = [], [] for img in os.listdir(PATH): img = os.path.join(PATH, img) img = cv2.imread(img, cv2.IMREAD_GRAYSCALE) - imgs.append(get_text(img)) - return imgs + texts.append(get_text(img)) + imgs.append(get_imgs(img)) + return texts, imgs -def load_data(path='data.npy'): +def load_data(path='data.npz'): if not os.path.isfile(path): - imgs = pretreat() - np.save(path, imgs) - return np.load(path) + texts, imgs = pretreat() + np.savez(path, texts=texts, images=imgs) + f = np.load(path) + return f['texts'], f['images'] if __name__ == '__main__': - imgs = load_data() + texts, imgs = load_data() + print(texts.shape) print(imgs.shape) - cv2.imwrite('temp.jpg', imgs[0]) + imgs = imgs.reshape(-1, 8) + print(np.unique(imgs, axis=0).shape)