-
Notifications
You must be signed in to change notification settings - Fork 20
/
inference_single_image.py
111 lines (86 loc) · 3.82 KB
/
inference_single_image.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
# Copyright Niantic 2019. Patent Pending. All rights reserved.
#
# This software is licensed under the terms of the Monodepth2 licence
# which allows for non-commercial use only, the full terms of which are made
# available in the LICENSE file.
from __future__ import absolute_import, division, print_function
import os
import argparse
import numpy as np
import PIL.Image as pil
import matplotlib as mpl
import matplotlib.cm as cm
import torch
from torchvision import transforms
import networks
from layers import disp_to_depth
def parse_args():
    """Parse the command-line options for single-image inference.

    Returns:
        argparse.Namespace with ``image_path`` (str), ``load_weights_folder``
        (str) and ``no_cuda`` (bool, False unless the flag is given).
        Both path options are mandatory.
    """
    cli = argparse.ArgumentParser(description='Inference on one Single Image.')

    # The two mandatory options share the same shape: a required string.
    required_options = (
        ('--image_path', 'path to a test image'),
        ("--load_weights_folder", "name of model to load"),
    )
    for flag, help_text in required_options:
        cli.add_argument(flag, type=str, help=help_text, required=True)

    cli.add_argument("--no_cuda", action='store_true', help='if set, disables CUDA')

    return cli.parse_args()
def prepare_model_for_test(args, device):
    """Build the encoder/decoder pair and load their weights from disk.

    Args:
        args: parsed options; only ``args.load_weights_folder`` is read. The
            folder must contain ``encoder.pth`` and ``depth.pth``.
        device: torch device the checkpoints are mapped to and the networks
            are moved to.

    Returns:
        Tuple ``(encoder, decoder, height, width)``. Both networks are in eval
        mode; ``height`` and ``width`` are the values stored under the
        ``'height'`` and ``'width'`` keys of the encoder checkpoint (the
        caller uses them as the network input size).

    Raises:
        KeyError: if the encoder checkpoint lacks ``'height'`` or ``'width'``.
    """
    model_path = args.load_weights_folder
    print("-> Loading model from ", model_path)
    encoder_path = os.path.join(model_path, "encoder.pth")
    decoder_path = os.path.join(model_path, "depth.pth")

    # NOTE(review): torch.load unpickles the file, so only point this at
    # checkpoints from a trusted source.
    encoder_dict = torch.load(encoder_path, map_location=device)
    decoder_dict = torch.load(decoder_path, map_location=device)

    # NOTE(review): presumably (num_layers=18, pretrained=False) -- confirm
    # against networks.ResnetEncoder.
    encoder = networks.ResnetEncoder(18, False)
    decoder = networks.DepthDecoder(
        num_ch_enc=encoder.num_ch_enc,
        scales=range(1),
    )

    # The encoder checkpoint also carries non-weight entries ('height',
    # 'width'), so keep only the keys the network itself declares. The key set
    # is computed once instead of rebuilding state_dict() for every entry.
    expected_keys = set(encoder.state_dict())
    encoder.load_state_dict(
        {k: v for k, v in encoder_dict.items() if k in expected_keys})
    decoder.load_state_dict(decoder_dict)

    encoder = encoder.to(device).eval()
    decoder = decoder.to(device).eval()

    return encoder, decoder, encoder_dict['height'], encoder_dict['width']
def inference(args):
    """Run the depth network on one image and save the results next to it.

    Two files are written beside the input image, using its path with the
    extension removed as the stem:
      * ``<stem>_depth.npy`` -- the first value returned by
        ``disp_to_depth(disp, 0.1, 10)`` at network resolution.
      * ``<stem>_depth.png`` -- the disparity resized to the original image
        size and colormapped with 'magma'.

    Args:
        args: parsed options providing ``image_path``, ``load_weights_folder``
            and ``no_cuda``.
    """
    if torch.cuda.is_available() and not args.no_cuda:
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    encoder, decoder, thisH, thisW = prepare_model_for_test(args, device)
    image_path = args.image_path
    print("-> Inferencing on image ", image_path)

    # splitext strips only the trailing extension, so paths without an
    # extension, or with the extension text repeated in a directory name,
    # still produce sensible output names.
    output_stem = os.path.splitext(image_path)[0]

    with torch.no_grad():
        # Load image and preprocess; the context manager closes the file
        # handle once the RGB copy has been made.
        with pil.open(image_path) as raw_image:
            input_image = raw_image.convert('RGB')
        original_width, original_height = input_image.size
        # PIL's resize expects (width, height), not (height, width).
        input_image = input_image.resize((thisW, thisH), pil.LANCZOS)
        input_image = transforms.ToTensor()(input_image).unsqueeze(0)

        # PREDICTION
        input_image = input_image.to(device)
        outputs = decoder(encoder(input_image))
        disp = outputs[("disp", 0)]
        # torch's interpolate takes the target size as (height, width).
        disp_resized = torch.nn.functional.interpolate(
            disp, (original_height, original_width), mode="bilinear", align_corners=False)

        # Saving numpy file
        name_dest_npy = output_stem + '_depth.npy'
        print("-> Saving depth npy to ", name_dest_npy)
        # NOTE(review): the saved array is the first return value of
        # disp_to_depth (named scaled_disp here), computed from the
        # un-resized prediction -- confirm this is what '_depth.npy'
        # consumers expect.
        scaled_disp, _ = disp_to_depth(disp, 0.1, 10)
        np.save(name_dest_npy, scaled_disp.cpu().numpy())

        # Saving colormapped depth image
        disp_resized_np = disp_resized.squeeze().cpu().numpy()
        # Clip the colour scale at the 95th percentile so a few very large
        # values do not wash out the rest of the map.
        vmax = np.percentile(disp_resized_np, 95)
        normalizer = mpl.colors.Normalize(vmin=disp_resized_np.min(), vmax=vmax)
        mapper = cm.ScalarMappable(norm=normalizer, cmap='magma')
        # Drop the alpha channel and convert to 8-bit RGB.
        colormapped_im = (mapper.to_rgba(disp_resized_np)[:, :, :3] * 255).astype(np.uint8)
        im = pil.fromarray(colormapped_im)

        name_dest_im = output_stem + '_depth.png'
        print("-> Saving depth png to ", name_dest_im)
        im.save(name_dest_im)

    print('-> Done!')
if __name__ == '__main__':
    # Script entry point: read the CLI options, then run inference once.
    cli_options = parse_args()
    inference(cli_options)