# Author: Aqeel Anwar(ICSRL)
# Created: 5/16/2019, 3:17 PM
# Email: [email protected]
# The code uses the already available TelloPy repository, which has been modified to fit the needs of this code
# https://github.com/hanyazou/TelloPy
import traceback
from tellopy.tello_drone import tello_drone
from DeepNet.network.agent import DeepAgent
from DeepNet.network.heat_map import heat_map
from DeepNet.util.Memory import Memory
from RL_functions import *
import numpy as np
from sys import platform
import time
import os, sys
import configparser as cp
from configs.read_cfg import read_cfg
import cv2
from dotmap import DotMap
# Read the configuration file
cfg = read_cfg(verbose=True)
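# The config (see configs/read_cfg) is expected to expose, among others: run_name, env_type,
# num_actions, buffer_len, load_data, load_data_path, custom_load, custom_load_path,
# crash_thresh, Q_clip, gamma, lr, batch_size, dropout_rate, epsilon_saturation,
# wait_before_train and update_target_interval -- all of these are referenced below.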
# ---------- Initialize necessary variables
stat_path = 'DeepNet/models/'+cfg.run_name+'/'+cfg.env_type+'/stat.npy'
network_path_half = 'DeepNet/models/'+cfg.run_name+'/'+cfg.env_type+'/'
network_path = network_path_half + '/agent/agent'
data_path = network_path_half+'data_tuple.npy'
input_size = 227
epsilon = 0
action_type = 'Begin'
data_tuple, stat, iteration = load_data(cfg.load_data, cfg.load_data_path)
epi=0
old_state = []
choose = False
reverse_action = [3, 2, 1, 0]
recovering = False
ReplayMemory = Memory(cfg.buffer_len, cfg.load_data, cfg.load_data_path)
action = 0
avoid_action = -1
num_actions = cfg.num_actions
# ---------- Initialize agents
agent = DeepAgent(input_size=input_size,
num_actions=cfg.num_actions,
train_fc='e2e',
name='agent',
env_type='Indoor',
custom_load=cfg.custom_load,
custom_load_path=cfg.custom_load_path,
tensorboard_path = cfg.load_data_path)
target_agent = DeepAgent(input_size=input_size,
num_actions=cfg.num_actions,
train_fc='e2e',
name='target_agent',
env_type='Indoor',
custom_load=cfg.custom_load,
custom_load_path=cfg.custom_load_path,
tensorboard_path = cfg.load_data_path)
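# The two networks form a double-DQN-style pair: the 'choose' flag (toggled every
# cfg.update_target_interval iterations in the training loop) decides which network is
# trained and which one supplies the bootstrap targets inside minibatch_double.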
# Load heat map network and initialize the drone connection
DepthAgent = heat_map()
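# DepthAgent turns each camera frame into depth/heat maps (depth_map_gen below); those maps
# drive agent.reward_gen, which uses cfg.crash_thresh to produce the reward and flag crashes.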
# --------- Initialize drone
drone = tello_drone()
# --------- Initialize dict
dict = DotMap()
dict.stat_path = stat_path
dict.network_path = network_path
dict.agent = agent
dict.target_agent = target_agent
dict.data_tuple = data_tuple
dict.data_path = data_path
dict.stat = stat
dict.load_path = cfg.load_data_path
dict.Replay_memory = ReplayMemory
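# This DotMap bundles paths, agents, stats and the replay memory so that drone.check_action
# (called while in manual mode) has everything it needs, presumably to save or restore the
# run on a keyboard command.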
just_begin = True
# I am running the code on two platforms. My MacBook is much slower
# than my GPU-equipped Windows laptop, hence the higher skip_frame on macOS
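# skip_frame is the number of decoded video frames dropped between processed ones, keeping
# the processed frame close to real time; it is recomputed adaptively at the end of every
# iteration from the measured processing time.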
if platform == 'win32':
    skip_frame = 60
elif platform == 'darwin':
    skip_frame = 150
else:
    # Fallback for other platforms (e.g. Linux): assumed default, tune to your hardware
    skip_frame = 60
if __name__ == '__main__':
screen = drone.pygame_connect(960, 720)
container, drone_handle = drone.connect()
manual = True
frame_skip = skip_frame
while True:
# flightDataHandler()
try:
for frame in container.decode(video=0):
if 0 < frame_skip:
frame_skip = frame_skip - 1
continue
# print(frame)
else:
start_time = time.time()
frame_skip = skip_frame
# Define control take-over
if manual:
# print("Entering manual mode")
drone.take_action_3(drone_handle, -1)
manual = drone.check_action(drone_handle, manual, dict)
drone.display(frame, manual)
else:
# Check whether the pilot has taken over manual control
manual = drone.if_takeover(drone_handle)
# Update necessary variables here
iteration += 1
drone_stat = drone.get_drone_stats()
dict.iteration = iteration
dict.data_tuple = data_tuple
dict.stat = stat
dict.Replay_memory = ReplayMemory
# Do calculations here
# Display image from front camera
drone.display(frame, manual)
new_state = agent.state_from_frame(frame)
depth_map_3D, depth_float_2D, global_depth = DepthAgent.depth_map_gen(frame)
reward, done = agent.reward_gen(depth_float_2D, depth_map_3D, action, crash_threshold=cfg.crash_thresh, display=True)
# print('Reward is: ', reward)
if not just_begin:
if not recovering: # or remove this condition
data_tuple = []
data_tuple.append([old_state, action, new_state, reward, action_type])
err = get_errors(data_tuple, choose, input_size, agent, target_agent, cfg.Q_clip, cfg.gamma)
ReplayMemory.add(err, data_tuple)
stat.append([iteration, epi, action, action_type, epsilon, reward, cfg.lr])
else:
data_tuple=[]
data_tuple.append([new_state, action, new_state, reward, action_type])
err = get_errors(data_tuple, choose, input_size, agent, target_agent, cfg.Q_clip, cfg.gamma)
ReplayMemory.add(err, data_tuple)
data_tuple = []
data_tuple.append([new_state, 0, new_state, reward, action_type])
err = get_errors(data_tuple, choose, input_size, agent, target_agent, cfg.Q_clip, cfg.gamma)
ReplayMemory.add(err, data_tuple)
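# Note: in the recovery branch above, both stored transitions use the current frame as the
# old and the new state (once with the last action, once with action 0); get_errors scores
# each tuple before it is pushed into the replay memory.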
if reward == -1:
epi = epi+1
action_type = 'Rcvr'
recovering = True
# reverse action
rev_action = reverse_action[action]
drone.take_action_3(drone_handle, rev_action)
time.sleep(0.7)
drone.take_action_3(drone_handle, -1)
time.sleep(0.4)
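# A reward of -1 marks a crash: the episode counter increments and the drone backs out by
# executing the reverse of the last action (reverse_action maps each of the four actions to
# its opposite), then sends -1, which appears to act as a stop/hover command.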
# Add data augmentation to tuple
# ----------------------- End of episode ___________________#
else:
recovering = False
# Step 1: Policy evaluation
action, action_type, epsilon = agent.policy(
epsilon=epsilon,
curr_state=new_state,
iter=iteration,
eps_sat=cfg.epsilon_saturation,
eps_model='exponential',
avoid_action=avoid_action)
# action = 1
drone.mark_frame(action, num_actions, frame)
drone.take_action_3(drone_handle, action)
time.sleep(0.8)
drone.take_action_3(drone_handle, -1)
time.sleep(0.4)
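# agent.policy performs the epsilon-greedy action selection (eps_model='exponential',
# saturating after cfg.epsilon_saturation iterations) and may be told to avoid a particular
# action via avoid_action. The chosen action runs for ~0.8 s before a -1 (stop) command,
# presumably so the next frame is captured from a roughly steady drone.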
# Train if required
if iteration >= cfg.wait_before_train:
if iteration % cfg.update_target_interval == 0:
choose = not choose
agent.save_network(iteration, network_path)
old_states, Qvals, actions, err, idx = minibatch_double(
data_tuple=data_tuple,
batch_size=cfg.batch_size,
choose=choose,
ReplayMemory=ReplayMemory,
input_size=input_size,
agent=agent,
target_agent=target_agent,
Q_clip=cfg.Q_clip,
gamma=cfg.gamma)
for i in range(cfg.batch_size):
ReplayMemory.update(idx[i], err[i])
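# minibatch_double samples a batch from the replay memory, builds the training targets with
# whichever network currently acts as the target, and returns per-sample errors; writing the
# errors back via ReplayMemory.update keeps the memory's error estimates current.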
if choose:
target_agent.train_n(old_states, Qvals, actions, cfg.batch_size, cfg.dropout_rate, cfg.lr, epsilon,
iteration)
else:
agent.train_n(old_states, Qvals, actions, cfg.batch_size, cfg.dropout_rate, cfg.lr,
epsilon,
iteration)
# save network
if iteration % 100 == 0:
print('Saving the learned network...')
np.save(stat_path, stat)
agent.save_network(iteration, network_path)
np.save(data_path, data_tuple)
ReplayMemory.save(cfg.load_data_path)  # persist the replay memory instance (assumes Memory.save(path) is an instance method)
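# Every 100 iterations the stats, network weights, collected transitions and the replay
# memory are checkpointed; cfg.load_data / cfg.custom_load can then be used to resume a run
# from these files.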
# # Training
# # Generate Heat Map
# depth_map_3D, depth_float_2D, global_depth = DepthAgent.depth_map_gen(frame, display=False)
# reward, done = agent.reward_gen(depth_float_2D, depth_map_3D, act, crash_threshold=2.0)
# print('Reward is: ', reward)
# reward = gen_reward(heat_map)
# data_tuple([s, a, s_, r])
old_state = new_state
print_action = action
if recovering:
print_action = rev_action
print(
'Iteration: {:>4d} / {:<3d} Action: {:<3d} - {:>4s} Eps: {:<1.4f} Reward: {:>+1.4f} lr: {:>f} len D: {:<5d}'.format(
iteration,
epi,
print_action,
action_type,
epsilon,
reward,
cfg.lr,
len(data_tuple)
)
)
just_begin=False
end_time = time.time()
time_per_iter = end_time-start_time
frame_skip = int(time_per_iter*60)
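# Processing one frame took time_per_iter seconds; multiplying by 60 (the assumed frame rate
# of the video stream) estimates how many frames arrived in the meantime and should be
# skipped so the next processed frame is current.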
# print("Frame skip: ", frame_skip)
except Exception as e:
print('------------- Error -------------')
exc_type, exc_obj, exc_tb = sys.exc_info()
fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
print(exc_type, fname, exc_tb.tb_lineno)
print(e)
print(traceback.format_exc())
print('Landing the drone and shutting down')
print('---------------------------------')
drone_handle.land()
time.sleep(5)
drone_handle.quit()
exit(1)