enjoy.py
import os
import sys
import time

import gym
from gym import envs as gym_envs
from stable_baselines3 import PPO

from envs.quad import QuadEnv
from utils import denormalize
if __name__ == '__main__':
    if len(sys.argv) < 2:
        raise ValueError('ckpt path not provided')
    checkpoint_file_name = sys.argv[1]

    quad_instance = QuadEnv()
    episode_duration = 5  # secs
    step_freq = quad_instance.step_freq()
    steps_per_episode = int(episode_duration * step_freq)

    # Register the environment so gym.make() can resolve the id below.
    gym_envs.register(
        id="Quad-v0",
        entry_point="envs.quad:QuadEnv",
        max_episode_steps=steps_per_episode,
        reward_threshold=0.0,
    )
    #### Show and log the model's performance #################
    quad_env = gym.make('Quad-v0')
    data_dir = os.path.join(os.path.dirname(__file__), 'data')
    tuned_model_path = os.path.join(data_dir, checkpoint_file_name)
    tuned_model = PPO.load(tuned_model_path)
    obs = quad_env.reset()

    # Render frequency and control frequency are equal.
    sim_duration = 5  # secs
    while True:
        start_time = time.time()  # reset the wall clock each episode
        for i in range(int(sim_duration * step_freq)):
            action_normed, _states = tuned_model.predict(
                obs,
                deterministic=True,  # OPTIONAL: deterministic=False
            )
            # Denormalized action is kept for inspection only; the env
            # itself consumes the normalized action.
            action = denormalize(action_normed, quad_instance._get_action_space())
            # print('action: ', action)
            obs, reward, done, info = quad_env.step(action_normed)
            quad_env.render()
            # Sleep off any surplus so simulated time tracks real time.
            elapsed_real = time.time() - start_time
            elapsed_sim = i * (1. / step_freq)
            time.sleep(max(elapsed_sim - elapsed_real, 0.))
            if done:
                obs = quad_env.reset()  # OPTIONAL EPISODE HALT
                break
    quad_env.close()
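
The script expects the checkpoint file name as its only argument, resolved relative to the data/ directory next to enjoy.py. A minimal invocation sketch (the checkpoint name here is hypothetical):

    python enjoy.py ppo_quad.zip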
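
The sleep arithmetic in the main loop is a generic real-time pacing pattern: each step compares where the simulation should be (elapsed_sim) against the wall clock (elapsed_real) and sleeps off the surplus, so rendering never runs faster than real time. A standalone sketch of the same idea, with a hypothetical step frequency:

    import time

    STEP_FREQ = 100      # hypothetical control/render frequency (Hz)
    DT = 1.0 / STEP_FREQ
    SIM_DURATION = 5     # simulated seconds

    start_time = time.time()
    for i in range(SIM_DURATION * STEP_FREQ):
        # ... advance the simulation by one step here ...
        elapsed_sim = i * DT                     # where the sim should be
        elapsed_real = time.time() - start_time  # where the wall clock is
        time.sleep(max(elapsed_sim - elapsed_real, 0.0))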