# main_sliding.py
# Sliding-window denoising pipeline: cuts training and test images into
# overlapping patches with slide(), fits/loads a DeconvDenoisingNet, then
# reassembles the patch predictions with reconstruct_sliding() and writes
# predicted, original and thresholded images to the Predictions/ folder.
import numpy as np
from pathlib import Path
import cv2
import os
from Source import DenoisingNet, MiniDenoisingNet, DeconvDenoisingNet, InterpolatingDenoisingNet, \
deflatten, threshold, threshold_v2,\
crop, slide,\
reconstruct, reconstruct_sliding,\
write_results, write_info
# --- Locations ----------------------------------------------------------
# Every location is a plain string rooted at the current working directory.
# Directory strings end with "/" so filenames can be appended directly.
path = Path()
d = path.resolve()
root = str(d)

train_images_path = f"{root}/Data/train/"
train_images_cleaned_path = f"{root}/Data/train_cleaned/"
test_path = f"{root}/Data/test/"
predictions_path = f"{root}/Predictions/"
sample_path = f"{predictions_path}sampleSubmission.csv"
demo_path = f"{predictions_path}demo.csv"

# Checkpoint locations handed to model.fit() for saving and loading.
weight_save_path = f"{root}/weights/model_deconv.ckpt"
weight_load_path = f"{root}/weights/14/model_deconv.ckpt"

# Files receiving the per-position normalisation statistics.
mean_path = f"{root}/mean.npy"
std_path = f"{root}/std.npy"

# --- Patch accumulators (filled in place by slide()) --------------------
X_train, y_train, X_test = [], [], []

# --- Geometry and hyper-parameters --------------------------------------
# NOTE(review): image_width/image_height are not referenced anywhere else in
# this script; per-image sizes are read from the images themselves.
image_width, image_height = 420, 540
mini_img_width = mini_img_height = 32  # patch size given to slide() and the model
stride = 16  # forwarded to slide() as its `strides` argument
num_epoch = 0  # forwarded to model.fit(); presumably 0 means "load weights only"
thres = 0.75  # cut-off forwarded to threshold_v2()
# Build the training set. Each dirty image in train_images_path is paired
# with the same-named file in train_images_cleaned_path. Both are read as
# grayscale, scaled from 0..255 to [0, 1], and slide() appends their patches
# to X_train (inputs) and y_train (targets).
# sorted() makes the patch order reproducible; os.listdir order is arbitrary.
for filename in sorted(os.listdir(train_images_path)):
    image_path = train_images_path + filename
    raw = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if raw is None:
        # cv2.imread signals an unreadable/non-image file by returning None
        # rather than raising; skip such entries instead of crashing on
        # `None / 255`.
        print(f"Skipping unreadable training file: {image_path}")
        continue
    image_path_y = train_images_cleaned_path + filename
    raw_y = cv2.imread(image_path_y, cv2.IMREAD_GRAYSCALE)
    if raw_y is None:
        # An input without its cleaned target cannot be used for training;
        # fail loudly rather than silently dropping ground truth.
        raise FileNotFoundError(
            f"Missing or unreadable cleaned target for {image_path}: {image_path_y}"
        )
    img = raw / 255
    img_y = raw_y / 255
    # Both images are read before either is sliced. A failure therefore never
    # leaves X_train holding patches that have no matching y_train entries.
    slide(img, X_train, mini_height=mini_img_height, mini_width=mini_img_width, strides=stride)
    slide(img_y, y_train, mini_height=mini_img_height, mini_width=mini_img_width, strides=stride)
# Build the test set. For every readable test image:
# - its patches are appended to X_test by slide(..., reconstructed=True);
# - the two values slide() returns are kept in n_subimages / sub_ind for
#   reconstruct_sliding() (presumably the patch count and patch positions;
#   confirm against slide());
# - the file stem and the (rows, cols) size go into parallel lists.
# All four lists are appended together, so entry i always refers to the
# same image.
sub_ind = []
n_subimages = []
image_sizes = []
file_indices = []
for filename in sorted(os.listdir(test_path)):
    image_path = test_path + filename
    raw = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if raw is None:
        # cv2.imread returns None for files it cannot decode; skip them so
        # they neither crash the run nor leave a dangling file index.
        print(f"Skipping unreadable test file: {image_path}")
        continue
    img = raw / 255
    n_image, indices = slide(img, X_test, mini_height=mini_img_height,
                             mini_width=mini_img_width, strides=stride, reconstructed=True)
    # splitext strips an extension of any length, unlike slicing off 4 chars.
    ind = os.path.splitext(filename)[0]
    file_indices.append(ind)
    image_sizes.append([img.shape[0], img.shape[1]])
    n_subimages.append(n_image)
    sub_ind.append(indices)
print("Finish Sliding")
# Stack the accumulated patch lists into arrays. Inputs get a trailing
# single-channel axis; targets are flattened to one vector per patch.
# NOTE(review): these reshapes pass (width, height), while slide() receives
# mini_height/mini_width separately. The order is irrelevant while both are
# 32, but confirm what DeconvDenoisingNet expects before using non-square
# patches.
X_train = np.array(X_train).reshape(-1, mini_img_width, mini_img_height, 1)
y_train = np.array(y_train).reshape(-1, mini_img_width * mini_img_height)
X_test = np.array(X_test).reshape(-1, mini_img_width, mini_img_height, 1)
if X_train.size == 0:
    # With no training patches the statistics below would be NaN and would
    # silently poison every input fed to the model.
    raise ValueError(f"No training patches were produced from {train_images_path}")
# Zero-center with the per-position mean of the training patches. The same
# statistics are applied to the test patches.
mean = np.mean(X_train, axis=0)
X_train = X_train - mean
X_test = X_test - mean
# Standardize with the per-position spread of the centred training patches.
std = np.std(X_train, axis=0)
# A position that never varies across the training patches has std == 0.
# Dividing by it would give NaN (0/0) for training inputs and +/-inf for
# test inputs. Divide by 1 there instead, which leaves those entries
# zero-centred.
std = np.where(std == 0, 1.0, std)
X_train = X_train / std
X_test = X_test / std
# Persist the (guarded) statistics so the normalisation can be re-applied or
# undone later without reintroducing a division by zero.
np.save(mean_path, mean)
np.save(std_path, std)
# Create the network for the configured patch size and call fit().
# NOTE(review): with num_epoch == 0 this presumably only restores
# weight_load_path; confirm against DeconvDenoisingNet.fit.
model = DeconvDenoisingNet(inp_w=mini_img_width, inp_h=mini_img_height)
model.fit(
    X_train,
    y_train,
    num_epoch=num_epoch,
    weight_load_path=weight_load_path,
    weight_save_path=weight_save_path,
    print_every=100,
)

# Reassembly settings shared by both reconstruct_sliding() calls below.
patch_shape = (-1, mini_img_width, mini_img_height)
sliding_layout = dict(
    image_sizes=image_sizes,
    ind_list=sub_ind,
    n_subimages=n_subimages,
    mini_width=mini_img_width,
    mini_height=mini_img_height,
)

# Predict every test patch, stitch the patches back into whole images, then
# binarise them with the configured cut-off.
predictions = model.predict(X_test)
predictions_reconstructed = reconstruct_sliding(predictions.reshape(patch_shape), **sliding_layout)
predictions_thresholded = threshold_v2(predictions_reconstructed, threshold=thres)

# Undo the standardisation so the original inputs can be stitched back too.
X_test = X_test * std + mean
X_test_reconstructed = reconstruct_sliding(X_test.reshape(patch_shape), **sliding_layout)
print("Finish reconstructing")
def _to_png_pixels(image):
    """Return *image* (values expected in [0, 1]) as a uint8 array in [0, 255].

    cv2.imwrite stores only 8-bit (or 16-bit) data in PNG files. Converting
    here makes the rounding and saturation explicit, instead of relying on
    OpenCV's implicit conversion of float input. It also accepts boolean or
    integer arrays (e.g. thresholded output), which imwrite may reject.
    """
    scaled = np.asarray(image, dtype=np.float64) * 255
    return np.clip(np.rint(scaled), 0, 255).astype(np.uint8)


# Write three PNGs per test image into predictions_path: the reconstructed
# prediction, the reconstructed (de-standardised) input, and the thresholded
# prediction. The directory is created if missing, since cv2.imwrite would
# otherwise just fail.
os.makedirs(predictions_path, exist_ok=True)
for stem, predicted, original, thresholded in zip(
    file_indices, predictions_reconstructed, X_test_reconstructed, predictions_thresholded
):
    for tag, image in (
        ("_slided_predicted_", predicted),
        ("_slided_original_", original),
        ("_slided_thresholded_", thresholded),
    ):
        out_path = predictions_path + tag + str(stem) + ".png"
        # cv2.imwrite reports failure by returning False instead of raising.
        if not cv2.imwrite(out_path, _to_png_pixels(image)):
            raise OSError(f"cv2.imwrite failed to write {out_path}")
#
#