infer.py
# Imitation-learning inference with the ACT network
import os
import time
import pickle
import argparse
from copy import deepcopy
import ctypes
import sys
import cv2
import torch
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from einops import rearrange
from utils import set_seed, load_data_test, compute_dict_mean, detach_dict
from constants import TASK_CONFIG, STATE_DIM, SN_idx2key
from policy import ACTPolicy, CNNMLPPolicy
from gi_env import GIDataEnv, GIRealEnv
from config import infer_configs
import IPython
e = IPython.embed
device = "cuda" if torch.cuda.is_available() else "cpu"
def is_admin():
    try:
        return ctypes.windll.shell32.IsUserAnAdmin()
    except Exception:
        return False
def main(args):
    # read the config name first; it selects all the other parameters
config_name = args['config_name']
config = infer_configs[config_name]
# basic config in config.py
set_seed(config['seed'])
global device
device = config['device'] if device == 'cuda' else 'cpu'
ckpt_dir = config['ckpt_dir']
policy_class = config['policy_class']
task_name = config['task_name']
chunk_size = config['chunk_size']
backbone = config['backbone']
ckpt_name = config['ckpt_name']
real_GI = config['real_GI']
save_video = config['save_video']
onscreen_render = config['onscreen_render']
temporal_agg = config['temporal_agg']
# get task parameters
task_config = TASK_CONFIG[task_name]
dataset_dir = task_config['dataset_dir']
num_episodes = task_config['num_episodes']
episode_len = task_config['episode_len'] # TODO: rm this config
camera_names = task_config['camera_names']
# fixed parameters
state_dim = STATE_DIM
# TODO: make these configurable
# ACT parameters
enc_layers = 4
dec_layers = 7
nheads = 8
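    # these values match the defaults of the original ACT implementation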
policy_config = {
'chunk_size' : chunk_size,
'hidden_dim' : config['hidden_dim'],
'dim_feedforward' : config['dim_feedforward'],
'backbone' : backbone,
'enc_layers' : enc_layers,
'dec_layers' : dec_layers,
'nheads' : nheads,
'camera_names' : camera_names,
'device' : device,
}
config = {
'ckpt_dir' : ckpt_dir,
'ckpt_name' : ckpt_name,
'num_episodes' : num_episodes,
'episode_len' : episode_len,
'state_dim' : state_dim,
'policy_class' : policy_class,
'onscreen_render' : onscreen_render,
'policy_config' : policy_config,
'task_name' : task_name,
'seed' : config['seed'],
'temporal_agg' : temporal_agg,
'camera_names' : camera_names,
'real_GI' : real_GI,
'save_video' : save_video,
'video_dir' : config['video_dir'],
}
if save_video and not os.path.exists(config['video_dir']):
os.makedirs(config['video_dir'])
    # run the evaluation
test_bc(config)
def test_bc(config):
print(config)
set_seed(config['seed'])
ckpt_dir = config['ckpt_dir']
ckpt_name = config['ckpt_name']
state_dim = config['state_dim']
onscreen_render = config['onscreen_render']
policy_config = config['policy_config']
camera_names = config['camera_names']
task_name = config['task_name']
max_timestamps = config['episode_len']
max_episodes = config['num_episodes']
temporal_agg = config['temporal_agg']
save_video = config['save_video']
real_GI = config['real_GI']
# load policy and stats
ckpt_path = os.path.join(ckpt_dir, ckpt_name)
if config['policy_class'] == 'mlp':
policy = CNNMLPPolicy(policy_config)
elif config['policy_class'] == 'act':
policy = ACTPolicy(policy_config)
    loading_status = policy.load_state_dict(torch.load(ckpt_path))
print(f'loading status: {loading_status}')
policy = policy.to(device)
policy.eval()
print(f'policy loaded from {ckpt_path}')
# load dataset stats
    stats_path = os.path.join(ckpt_dir, 'dataset_stats.pkl')
with open(stats_path, 'rb') as f:
stats = pickle.load(f)
    # States are already in [0, 1], so no preprocessing is needed.
    # Arguably this belongs in the env instead.
    pre_process = lambda state: state
def post_process(action):
        # de-normalize the mouse dx, dy back to pixel deltas
action[-2:] = action[-2:] * (stats['mouse_action_max'] - stats['mouse_action_min']) + stats['mouse_action_min']
        # squash the keyboard-action logits into probabilities with a sigmoid
action[:-2] = 1 / (1 + np.exp(-action[:-2]))
return action
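    # Worked example with illustrative (not actual) stats: if mouse_action_min
    # were -50 and mouse_action_max were 50, a raw dx of 0.6 would map back to
    # 0.6 * 100 - 50 = 10 pixels.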
print(f'stats: {stats}')
# load env
if real_GI:
env = GIRealEnv(config)
else:
env = GIDataEnv(config)
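    # GIRealEnv presumably drives the live game, while GIDataEnv replays
    # recorded episodes (its step() is effectively a no-op; see below)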
    # TODO: decouple the query frequency from the chunk size so that temporal
    # ensembling can lag a few timesteps behind
    # TODO: study how a decoupled query frequency affects the success rate
    query_frequency = policy_config['chunk_size']
    if temporal_agg:
        query_frequency = 1
    chunk_size = policy_config['chunk_size']  # number of future actions predicted per policy query
# TODO: make this scale configurable
max_timestamps = int(max_timestamps * 1)
for episode_id in range(max_episodes):
print(f'episode {episode_id}')
# reset env
env.reset()
if onscreen_render:
cv2.namedWindow('image', cv2.WINDOW_NORMAL)
if temporal_agg:
            # shape: (max_ts, max_ts + chunk_size, state_dim)
            # buffer of every predicted action chunk, indexed by query timestep
all_time_actions = torch.zeros([max_timestamps, max_timestamps+chunk_size, state_dim])
all_time_actions = all_time_actions.to(device)
        # In Yaa, the state is the mouse/keyboard (mskb) state.
        # We additionally assume it starts out as [0] * state_dim,
        # so a running state has to be maintained during inference.
state_history = torch.zeros([1, max_timestamps, state_dim])
state_history = state_history.to(device)
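        # (state_history is recorded every step but not currently consumed)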
        # mskb state; the last two entries are the mouse dx, dy and carry no state
curr_state = np.zeros([state_dim]) # easier for cpu operations
image_list = [] # for video?
state_list = []
target_state_list = []
if save_video:
video_path = os.path.join(config['video_dir'], f'{task_name}_{ckpt_name}_{episode_id}.mp4')
            # TODO: fix the hardcoded frame size
out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), 10, (640, 480))
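            # 10 FPS, mp4v codec; note that cv2.VideoWriter silently drops
            # frames whose size differs from the one declared here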
with torch.inference_mode():
for t in range(max_timestamps):
obs = env.observation()
# print(curr_state)
                # the state needs preprocessing
state_numpy = deepcopy(curr_state)
state = pre_process(state_numpy)
# make it to [1, state_dim]
state = torch.from_numpy(state).float().unsqueeze(0)
state = state.to(device)
state_history[0, t] = state
# get image from obs
image_dict, ground_action = obs
# print(image_dict.keys())
curr_image = image_preprocess(image_dict, camera_names)
ground_action = [0]*state_dim if ground_action is None else ground_action
# torch.Size([1, 2, 3, 480, 640])
# print(curr_image.shape)
# feed image & state to policy
                if t % query_frequency == 0:
                    # predict one action chunk of shape (1, chunk_size, state_dim)
t0 = time.time()
print(np.round(state_numpy, 2))
all_actions = policy(state, curr_image)
print(f'policy cost time: {time.time() - t0}')
                # temporal ensembling
if temporal_agg:
all_time_actions[[t], t:t+chunk_size] = all_actions
actions_for_curr_step = all_time_actions[:, t]
actions_populated = torch.all(actions_for_curr_step != 0, axis=1)
actions_for_curr_step = actions_for_curr_step[actions_populated]
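                    # exponentially weighted average over every chunk covering
                    # timestep t: index 0 is the oldest prediction and receives
                    # the largest weight, smoothing behavior between queries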
# TODO: make it configurable
k = 0.01
exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step)))
exp_weights = exp_weights / exp_weights.sum()
exp_weights = torch.from_numpy(exp_weights).unsqueeze(dim=1)
exp_weights = exp_weights.to(device)
raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True)
                else:
                    # without ensembling, just query once every chunk_size steps
                    raw_action = all_actions[:, t % query_frequency]
                # post-process the action (the sigmoid happens inside post_process)
raw_action = raw_action.squeeze(0).cpu().numpy()
action = post_process(raw_action)
                # discretize the action: entries < min_thre become 0,
                # entries >= max_thre become 1
                # TODO: make the thresholds configurable
min_thre = 0.5
max_thre = 0.5
action_bin = np.zeros_like(action, dtype=np.int8)
action_bin[action < min_thre] = 0
action_bin[action >= max_thre] = 1
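                # with min_thre == max_thre == 0.5 this is a plain 0.5 threshold;
                # distinct thresholds would leave the in-between entries at 0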
target_state = action
# print(np.max(action), np.min(action))
                # the action updates curr_state; diff the keyboard state against
                # the previous one to recover the actual human actions
human_actions = []
human_actions_gt = []
show_updown = False
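                # show_updown=True would print key transitions (up/down events);
                # False prints the currently pressed keys instead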
for state_id in range(state_dim-2):
if show_updown:
if action_bin[state_id] == curr_state[state_id]:
continue
else:
human_action = f"{SN_idx2key[state_id]} {'up' if action_bin[state_id] == 0 else 'down'}"
if human_action[0] == ' ':
human_action = 'sp' + human_action[1:]
# print(f'append in {state_id}')
human_actions.append(human_action)
else:
if action_bin[state_id]:
human_action = f"{SN_idx2key[state_id]}"
human_actions.append(human_action)
if ground_action[state_id]:
human_action = f"{SN_idx2key[state_id]}"
human_actions_gt.append(human_action)
curr_state[state_id] = action_bin[state_id]
dx, dy = action[-2], action[-1]
dx_gt, dy_gt = ground_action[-2], ground_action[-1]
print(human_actions)
print(human_actions_gt)
                # env.step is effectively a no-op in the data env
if not env.step(action):
break
# for visualization
state_list.append(state_numpy)
target_state_list.append(target_state)
# update curr frame
# TODO: plot dx dy in frame
image_dict = env.render()
                # concatenate the camera images side by side
curr_image = np.concatenate([image_dict[cam_name] for cam_name in camera_names], axis=1)
                # overlay the key-event strings on the frame
kb_events_str = ','.join(human_actions)
kb_events_str_gt = ','.join(human_actions_gt)
curr_image = cv2.putText(curr_image, kb_events_str, (10, 100), cv2.FONT_HERSHEY_SIMPLEX, 3, (0, 0, 255), 3, cv2.LINE_AA)
curr_image = cv2.putText(curr_image, kb_events_str_gt, (10, 200), cv2.FONT_HERSHEY_SIMPLEX, 3, (0, 255, 0), 3, cv2.LINE_AA)
                # draw an arrow from the frame center showing the predicted dx, dy
dx, dy = int(dx), int(dy)
dx_gt, dy_gt = int(dx_gt), int(dy_gt)
                curr_image = cv2.arrowedLine(curr_image, (320, 240), (320 + dx, 240 + dy), (0, 0, 255), 2)
                # draw the ground-truth dx, dy next to it in green
                curr_image = cv2.arrowedLine(curr_image, (320, 240), (320 + dx_gt, 240 + dy_gt), (0, 255, 0), 2)
if onscreen_render:
cv2.imshow('image', curr_image)
cv2.waitKey(1)
if save_video:
out.write(curr_image)
print(f'step {t}/{max_timestamps}')
def image_preprocess(image_dict, camera_names):
    # preprocess the images before they enter the policy
    # (that said, the policy additionally normalizes to the ImageNet distribution)
curr_images = []
for cam_name in camera_names:
curr_image = rearrange(image_dict[cam_name], 'h w c -> c h w')
curr_images.append(curr_image)
curr_image = np.stack(curr_images, axis=0)
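    # add a batch dim and scale to [0, 1] -> shape (1, num_cameras, C, H, W)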
curr_image = torch.from_numpy(curr_image / 255.0).float().unsqueeze(0)
curr_image = curr_image.to(device)
return curr_image
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--config_name', '-c',
action='store', type=str,
help='which config to be used',
choices=infer_configs.keys(),
required=True)
if is_admin():
main(vars(parser.parse_args()))
else:
        # relaunch the program with admin rights, forwarding the arguments
ctypes.windll.shell32.ShellExecuteW(None, "runas", sys.executable, " ".join(sys.argv), None, 1)
sys.exit(0)
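# Example usage (the config name must be a key of infer_configs; 'act_demo'
# is only a placeholder):
#   python infer.py -c act_demo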