-
Notifications
You must be signed in to change notification settings - Fork 0
/
image_ai.py
119 lines (84 loc) · 3.32 KB
/
image_ai.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import numpy as np
import cv2
from collections import Counter
from skimage.color import rgb2lab, deltaE_cie76
import os
# %matplotlib inline   <- uncomment when running inside a Jupyter notebook

# Load the sample image and inspect what OpenCV hands back.
image = cv2.imread('sample_image.jpg')
if image is None:
    # cv2.imread reports a missing or unreadable file by returning None rather
    # than raising; without this check the failure would surface below as a
    # confusing AttributeError on image.shape.
    raise FileNotFoundError("Could not read image: 'sample_image.jpg'")
print("The type of this input is {}".format(type(image)))
print("Shape: {}".format(image.shape))
# OpenCV stores channels as BGR while matplotlib assumes RGB, so this first
# plot shows red and blue swapped; the converted image below is correct.
plt.imshow(image)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
plt.imshow(image)

# Optional experiment (disabled): grayscale view of the now-RGB image.
#     gray_image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
#     plt.imshow(gray_image, cmap='gray')

resized_image = cv2.resize(image, (1200, 600))
plt.imshow(resized_image)

# --- OpenCV part -------------------------------------------------------------
def RGB2HEX(color):
    """Format the first three channels of *color* as a ``#rrggbb`` string.

    Each channel is passed through ``int()`` (truncating floats such as
    k-means cluster centres toward zero) and written as two lowercase hex
    digits. Any channels beyond the third are ignored.
    """
    hex_digits = "".join(format(int(color[channel]), "02x") for channel in range(3))
    return "#" + hex_digits
def get_image(image_path):
    """Load the image at *image_path* and return it with RGB channel order.

    ``cv2.imread`` yields BGR pixel data; it is converted to RGB so the result
    can be shown with matplotlib and compared against RGB colours.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file. ``cv2.imread``
            signals this by returning ``None`` instead of raising, which would
            otherwise make ``cv2.cvtColor`` fail with an unhelpful error.
    """
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError("Could not read image: {!r}".format(image_path))
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
def get_colors(image, number_of_colors, show_chart):
    """Find the dominant colours of *image* by k-means clustering its pixels.

    The image is first shrunk to 600x400 (width x height) with area
    interpolation so clustering stays fast, then flattened to one row per
    pixel and clustered into *number_of_colors* groups with ``KMeans``.
    Each cluster centre is reported as one dominant colour.

    Args:
        image: image array with exactly three channels (the reshape below
            fails otherwise). Presumably RGB, e.g. from ``get_image`` —
            colours are labelled and compared as RGB elsewhere in this file.
        number_of_colors: number of k-means clusters to fit.
        show_chart: if truthy, open a new 8x6 figure and draw a pie chart of
            cluster sizes, each slice labelled and filled with its hex colour.

    Returns:
        list: cluster-centre colours (float arrays of length 3), one per
        cluster label that occurs in the fitted labels, in the same order as
        the pie-chart slices. The list can be shorter than
        *number_of_colors* if some label never occurs.
    """
    modified_image = cv2.resize(image, (600, 400), interpolation=cv2.INTER_AREA)
    height, width = modified_image.shape[:2]
    pixels = modified_image.reshape(height * width, 3)

    clf = KMeans(n_clusters=number_of_colors)
    labels = clf.fit_predict(pixels)
    counts = Counter(labels)
    center_colors = clf.cluster_centers_

    # Fix one ordering of cluster ids (the Counter's insertion order) and index
    # the centres by it exactly once. Slice sizes, hex labels and returned
    # colours are all derived from that same ordering, so entry i of each list
    # refers to the same cluster. Re-indexing an already reordered list by
    # label would mismatch labels and sizes, and could go out of range when a
    # label is missing from ``counts``.
    cluster_ids = list(counts.keys())
    rgb_colors = [center_colors[cluster_id] for cluster_id in cluster_ids]
    hex_colors = [RGB2HEX(color) for color in rgb_colors]

    if show_chart:
        plt.figure(figsize=(8, 6))
        plt.pie([counts[cluster_id] for cluster_id in cluster_ids],
                labels=hex_colors, colors=hex_colors)
    return rgb_colors
# Draw the dominant-colour pie chart for the sample image.
get_colors(get_image('sample_image.jpg'), 8, True)

IMAGE_DIRECTORY = 'images'
# Reference colours as RGB triples (0-255) used for colour-based image search.
COLORS = {
    'GREEN': [0, 128, 0],
    'BLUE': [0, 0, 128],
    'YELLOW': [255, 255, 0]
}

# Load every non-hidden regular file in IMAGE_DIRECTORY. os.listdir returns
# names in arbitrary order, so sort them to keep `images` (and every plot
# built from it) deterministic; subdirectories are skipped since they are not
# images.
images = [
    get_image(os.path.join(IMAGE_DIRECTORY, name))
    for name in sorted(os.listdir(IMAGE_DIRECTORY))
    if not name.startswith('.')
    and os.path.isfile(os.path.join(IMAGE_DIRECTORY, name))
]

# Preview all loaded images side by side in a single row.
plt.figure(figsize=(20, 10))
for position, loaded_image in enumerate(images, start=1):
    plt.subplot(1, len(images), position)
    plt.imshow(loaded_image)
def match_image_by_color(image, color, threshold=60, number_of_colors=10):
    """Return True if any dominant colour of *image* is close to *color*.

    The dominant colours come from ``get_colors`` (k-means with
    *number_of_colors* clusters, no chart). The target and each dominant
    colour are cast to uint8, converted to CIE L*a*b* with ``rgb2lab`` and
    compared with the CIE76 colour difference ``deltaE_cie76``.

    Args:
        image: image array accepted by ``get_colors``.
        color: target colour as an RGB triple of 0-255 values, e.g.
            ``[0, 128, 0]``.
        threshold: a dominant colour matches when its CIE76 difference from
            *color* is strictly below this value.
        number_of_colors: number of clusters requested from ``get_colors``.

    Returns:
        bool: True if at least one dominant colour matches, else False.
    """
    image_colors = get_colors(image, number_of_colors, False)
    # Shape the single colour as a 1x1 RGB image for rgb2lab.
    selected_color = rgb2lab(np.uint8(np.asarray([[color]])))
    # Iterate over the colours actually returned, not range(number_of_colors):
    # get_colors returns one colour per cluster label present, so the list may
    # be shorter than requested and fixed indexing could raise IndexError.
    for image_color in image_colors:
        curr_color = rgb2lab(np.uint8(np.asarray([[image_color]])))
        if deltaE_cie76(selected_color, curr_color) < threshold:
            return True  # one close colour is enough; skip the rest
    return False
def show_selected_images(images, color, threshold, colors_to_match, columns=5):
    """Plot every image in *images* whose dominant colours include *color*.

    Each image is tested with ``match_image_by_color``. Matches are drawn in
    input order on the current figure, in a grid *columns* wide with as many
    rows as needed.

    Args:
        images: sequence of image arrays.
        color: target RGB colour (0-255 values).
        threshold: CIE76 colour difference below which a colour counts as a
            match.
        colors_to_match: number of dominant colours to extract per image.
        columns: grid width (default 5). With at most ``columns`` matches the
            grid is a single row of ``columns`` slots.
    """
    selected = [img for img in images
                if match_image_by_color(img, color, threshold, colors_to_match)]
    # A single fixed row cannot hold more than `columns` images: matplotlib
    # rejects a subplot index larger than rows * columns. Add rows instead.
    rows = (len(selected) + columns - 1) // columns
    for index, img in enumerate(selected, start=1):
        plt.subplot(rows, columns, index)
        plt.imshow(img)


# --- GREEN example -----------------------------------------------------------
# Search the loaded images for ones with a dominant colour near GREEN.
plt.figure(figsize=(20, 10))
show_selected_images(images, COLORS['GREEN'], 60, 5)