-
Notifications
You must be signed in to change notification settings - Fork 1
/
utils.py
212 lines (168 loc) · 6.52 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
import os
from typing import Optional
import matplotlib.pyplot as plt
import numpy as np
import torch
from lpips import LPIPS
from PIL import Image
from torchvision.transforms import Normalize
def show_images_horizontally(
    list_of_files: np.array, output_file: Optional[str] = None, interact: bool = False
) -> None:
    """
    Plot images side by side in a single row, then show or save the figure.

    Args:
        list_of_files: Sequence of N images, each an array of shape (H, W, C)
            (for example a numpy array of shape (N, H, W, C)).
        output_file: Path the figure is saved to; matplotlib picks the format
            from the extension. Required when ``interact`` is False.
        interact: If True, display the figure with ``plt.show()`` (e.g. in a
            Jupyter notebook) instead of saving it.

    Raises:
        ValueError: If no images are given, or if ``interact`` is False and no
            ``output_file`` is given.
    """
    number_of_files = len(list_of_files)
    if number_of_files == 0:
        raise ValueError("The image list is empty.")
    if not interact and output_file is None:
        raise ValueError("output_file is required when interact is False.")

    # Each panel is fig_width inches wide, so the figure height follows the
    # images' aggregate aspect ratio (height / width; H / W for equal sizes).
    heights = [image.shape[0] for image in list_of_files]
    widths = [image.shape[1] for image in list_of_files]
    fig_width = 8.0  # inches per image
    fig_height = fig_width * sum(heights) / sum(widths)

    # squeeze=False always returns a 2-D array of Axes; without it a single
    # image yields a bare Axes object and indexing it would fail.
    fig, axs = plt.subplots(
        1,
        number_of_files,
        figsize=(fig_width * number_of_files, fig_height),
        squeeze=False,
    )
    for ax, image in zip(axs[0], list_of_files):
        ax.imshow(image)
        ax.axis("off")
    # Lay out after the axes have content so spacing reflects what is drawn.
    fig.tight_layout()

    if interact:
        plt.show()
    else:
        fig.savefig(output_file, bbox_inches="tight", pad_inches=0.25)
        # Release the figure; repeated calls would otherwise accumulate open figures.
        plt.close(fig)
def image_grids(images, rows=None, cols=None):
    """
    Paste PIL images into a single RGB grid image, row by row.

    Every cell has the size of the first image; images are placed left to
    right, top to bottom.

    Args:
        images: Non-empty sequence of PIL images.
        rows: Number of grid rows. When omitted it is derived from ``cols``.
        cols: Number of grid columns. When omitted it is derived from ``rows``
            if that is given, otherwise ``int(sqrt(len(images)))``.

    Returns:
        A new RGB ``PIL.Image`` of size (cols * width, rows * height).

    Raises:
        ValueError: If ``images`` is empty, ``rows``/``cols`` is not positive,
            or the grid has fewer cells than there are images.
    """
    n_images = len(images)
    if n_images == 0:
        raise ValueError("The image list is empty.")
    for label, value in (("rows", rows), ("cols", cols)):
        if value is not None and value <= 0:
            raise ValueError(f"{label} must be a positive integer.")

    if cols is None:
        if rows is None:
            # Aim for a roughly square grid; leftovers spill into extra rows.
            cols = int(n_images**0.5)
        else:
            # With rows fixed, use just enough columns to hold every image.
            cols = -(-n_images // rows)  # ceiling division
    if rows is None:
        rows = -(-n_images // cols)  # ceiling division
    if rows * cols < n_images:
        # Images past the last cell would be pasted off-canvas and silently lost.
        raise ValueError(f"A {rows}x{cols} grid cannot hold {n_images} images.")

    width, height = images[0].size
    grid_image = Image.new("RGB", (cols * width, rows * height))
    for i, image in enumerate(images):
        row, col = divmod(i, cols)
        grid_image.paste(image, (col * width, row * height))
    return grid_image
def save_image(image: np.array, file_name: str) -> None:
    """
    Write a single image array to disk.

    PIL chooses the file format from the extension of ``file_name``
    (e.g. ``.jpg`` or ``.png``).

    Args:
        image: Image array of shape (H, W, C). Integer arrays are passed to PIL
            unchanged. Floating-point (H, W, C) arrays are assumed to hold values
            in [0, 1] (the range ``load_and_process_images`` produces) and are
            clipped and converted to uint8 first.
        file_name: Destination path.
    """
    image = np.asarray(image)
    if image.ndim == 3 and np.issubdtype(image.dtype, np.floating):
        # PIL.Image.fromarray cannot build a multi-channel image from float
        # data, so map [0, 1] floats to 8-bit values (out-of-range is clipped).
        image = (np.clip(image, 0.0, 1.0) * 255.0).round().astype(np.uint8)
    Image.fromarray(image).save(file_name)
def load_and_process_images(load_dir: str) -> list:
    """
    Load the numerically named ``.jpg`` frames of a directory, in numeric order.

    Only entries whose name ends in ``.jpg`` are read; anything else in the
    directory (e.g. ``.DS_Store`` or notes files) is ignored. Frames are ordered
    by the integer before the first "." in their name, so ``2.jpg`` sorts
    before ``10.jpg``.

    Args:
        load_dir: The directory to load the images from.

    Returns:
        images: A list of N numpy arrays, one per image (e.g. (H, W, C) for RGB
            files), with pixel values scaled to [0, 1].

    Raises:
        ValueError: If a ``.jpg`` file name does not start with an integer.
    """
    print(load_dir)
    # Filter before sorting so non-image entries never reach the int() key.
    filenames = sorted(
        (name for name in os.listdir(load_dir) if name.endswith(".jpg")),
        key=lambda name: int(name.split(".")[0]),
    )
    images = []
    for filename in filenames:
        # The context manager closes the file deterministically.
        with Image.open(os.path.join(load_dir, filename)) as img:
            # Convert to a numpy array and scale pixel values to [0, 1].
            images.append(np.asarray(img) / 255.0)
    return images
def compute_lpips(images: np.array, lpips_model: LPIPS) -> np.array:
    """
    Compute the LPIPS distance between every pair of adjacent images.

    Args:
        images: Sequence of N images, each an (H, W, C) array. Presumably RGB
            with values in [0, 1], as produced by ``load_and_process_images``
            (TODO confirm for other callers).
        lpips_model: The LPIPS model used to compute perceptual distances.

    Returns:
        distances: 1-D numpy array of length max(N - 1, 0), where
            ``distances[i]`` is the LPIPS between images ``i`` and ``i + 1``.
    """
    if len(images) < 2:
        # No adjacent pairs to compare (and a 0-image input cannot be permuted).
        return np.array([])

    # Run on the same device as the model's weights.
    device = next(lpips_model.parameters()).device
    # Stack into one contiguous float32 array first; torch.tensor() on a list
    # of ndarrays is slow.
    batch = torch.as_tensor(np.asarray(images, dtype=np.float32)).to(device)
    batch = batch.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
    # NOTE(review): the lpips package documents its inputs as RGB scaled to
    # [-1, 1]; this ImageNet mean/std normalization may not match that. It is
    # kept unchanged so scores stay comparable with earlier results. Confirm
    # before changing it.
    normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    batch = normalize(batch)

    distances = []
    # Pure inference: no_grad avoids building an autograd graph for each pair.
    with torch.no_grad():
        for i in range(batch.shape[0] - 1):
            loss = lpips_model(batch[i : i + 1], batch[i + 1 : i + 2])
            distances.append(loss.item())
    return np.array(distances)
def compute_gini(distances: np.array) -> float:
    """
    Compute the Gini index of the input distances.

    Equivalent to ``sum_{i,j} |d_i - d_j| / (2 * n^2 * mean(d))``. It is
    evaluated in O(n log n) using the identity
    ``sum_{i,j} |x_i - x_j| = 2 * sum_k (2k - n + 1) * x_(k)``, where ``x_(k)``
    are the values sorted ascending and ``k`` is 0-based.

    Args:
        distances: 1-D sequence of non-negative values (e.g. LPIPS distances).

    Returns:
        gini: The Gini index, in [0, 1) for non-negative input. It is 0.0 when
            there are fewer than two values or when every value is zero.
    """
    values = np.sort(np.asarray(distances, dtype=float))
    n = values.size
    if n < 2:
        return 0.0  # Gini index is 0 for fewer than two elements
    total = values.sum()
    if total == 0:
        # All distances are zero: perfectly equal spacing, and avoids 0 / 0.
        return 0.0
    weights = 2 * np.arange(n) - n + 1
    return float(np.dot(weights, values) / (n * total))
def compute_smoothness_and_consistency(images: np.array, lpips_model: LPIPS) -> tuple:
    """
    Compute smoothness, consistency and the largest step of an image sequence.

    All three metrics are derived from the LPIPS distances between consecutive
    images.

    Args:
        images: The input images as numpy array with shape (N, H, W, C), N >= 2.
        lpips_model: The LPIPS model used to compute perceptual distances.

    Returns:
        smoothness: One minus the Gini index of the consecutive LPIPS distances.
        consistency: The mean LPIPS of consecutive images.
        max_lpips: The maximum LPIPS of consecutive images.

    Raises:
        ValueError: If fewer than two images are given (no consecutive pairs).
    """
    if len(images) < 2:
        raise ValueError(
            "At least two images are required to compare consecutive images."
        )
    distances = compute_lpips(images, lpips_model)
    smoothness = 1 - compute_gini(distances)
    consistency = np.mean(distances)
    max_lpips = np.max(distances)
    return smoothness, consistency, max_lpips
def separate_source_and_interpolated_images(images: np.array) -> tuple:
    """
    Split an interpolation sequence into its two endpoints and its interior.

    The first and last frames are the source images; every frame strictly
    between them is an interpolated frame.

    Args:
        images: Sequence of N >= 2 frames, e.g. a numpy array of shape (N, H, W, C).

    Returns:
        A ``(source, interpolation)`` tuple. ``source`` is a numpy array of shape
        (2, H, W, C) holding the first and last frames. ``interpolation`` is
        ``images[1:-1]``: the N - 2 inner frames, in the same container type as
        ``images``.

    Raises:
        ValueError: If fewer than two frames are given.
    """
    frame_count = len(images)
    if frame_count >= 2:
        endpoints = np.array([images[0], images[frame_count - 1]])
        inner_frames = images[1:-1]
        return endpoints, inner_frames
    raise ValueError("The input array should have at least two elements.")