run_ppo_diff.py (forked from sttkm/POET-Evogym)

import os
import csv
import time

import numpy as np
import torch

from ppo import Policy, PPO
from gym_utils import make_vec_envs
import ppo_config as default_config
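

# Roll out `policy` on the vectorized evaluation envs until `num_eval` episodes
# have finished, and return the mean episode reward (taken from each env's
# 'episode' info entry).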
def evaluate(policy, envs, num_eval=1, deterministic=False):
    obs = envs.reset()
    episode_rewards = []
    while len(episode_rewards) < num_eval:
        with torch.no_grad():
            action = policy.predict(obs, deterministic=deterministic)
        obs, _, done, infos = envs.step(action)
        for info in infos:
            if 'episode' in info:
                episode_rewards.append(info['episode']['r'])
    return np.mean(episode_rewards)
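

# Continue PPO training of a warm-started policy on `env_id` for `train_iters`
# updates. `model` is a [policy state_dict, obs_rms] pair. Every `eval_interval`
# iterations the policy is evaluated; the best mean reward is tracked and
# checkpoints are written either as a single best-model file (save_iter=False)
# or once per evaluation together with a CSV reward history (save_iter=True).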
def run_ppo(env_id, robot, train_iters, eval_interval, save_file, model,
            config=None, deterministic=False, save_iter=False, history_file=None):

    if config is None:
        config = default_config
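
    # Vectorized training envs normalize observations and rewards; the
    # warm-start observation statistics (model[1]) are restored onto them.
    # Separate eval envs keep their normalization statistics frozen.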
    train_envs = make_vec_envs(env_id, robot, config.seed, config.num_processes,
                               gamma=config.gamma, vecnormalize=True)
    train_envs.obs_rms = model[1]

    eval_envs = make_vec_envs(env_id, robot, config.seed, config.eval_processes,
                              gamma=None, vecnormalize=True)
    eval_envs.training = False
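
    # Rebuild the policy network and load the warm-start weights (model[0]).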
    policy = Policy(
        train_envs.observation_space,
        train_envs.action_space,
        init_log_std=config.init_log_std,
        device='cpu'
    )
    policy.load_state_dict(model[0])
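
    # Configure PPO; all hyperparameters are taken from the config module.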
    algo = PPO(
        policy,
        train_envs,
        learning_rate=config.learning_rate,
        n_steps=config.steps,
        batch_size=config.steps * config.num_processes // config.num_mini_batch,
        n_epochs=config.epochs,
        gamma=config.gamma,
        gae_lambda=config.gae_lambda,
        clip_range=config.clip_range,
        clip_range_vf=config.clip_range,
        normalize_advantage=True,
        ent_coef=config.ent_coef,
        vf_coef=config.vf_coef,
        max_grad_norm=config.max_grad_norm,
        device='cpu',
        lr_decay=config.lr_decay,
        max_iter=train_iters * 10,
    )
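
    # In per-iteration mode, save the initial model as '0.pt' and start a CSV
    # reward-history file with an initial zero-reward row.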
    if save_iter:
        interval = time.time()
        torch.save([policy.state_dict(), train_envs.obs_rms], os.path.join(save_file, '0.pt'))

        history_header = ['iteration', 'reward']
        items = {
            'iteration': 0,
            'reward': 0
        }
        with open(history_file, 'w', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=history_header)
            writer.writeheader()
            writer.writerow(items)
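
    # Main training loop: one PPO update per iteration, periodic evaluation.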
    max_reward = float('-inf')
    for iter in range(train_iters):
        algo.step()

        if (iter+1) % eval_interval == 0:
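            # Sync the latest observation-normalization statistics into the
            # eval envs before measuring performance.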
            eval_envs.obs_rms = train_envs.obs_rms.copy()
            reward = evaluate(policy, eval_envs, num_eval=config.eval_processes,
                              deterministic=deterministic)
            if reward > max_reward:
                max_reward = reward
                if not save_iter:
                    torch.save([policy.state_dict(), train_envs.obs_rms], save_file + '.pt')
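
            # Per-iteration mode: log progress, save this iteration's
            # checkpoint, and append the reward to the CSV history.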
            if save_iter:
                now = time.time()
                log_std = policy.log_std.mean()
                print(f'iteration: {iter+1:=5} elapsed time: {now-interval:.3f} reward: {reward:6.3f} log_std: {log_std:.5f}')
                interval = now
                torch.save([policy.state_dict(), train_envs.obs_rms], os.path.join(save_file, f'{iter+1}.pt'))
                items = {
                    'iteration': iter+1,
                    'reward': reward
                }
                with open(history_file, 'a', newline='') as f:
                    writer = csv.DictWriter(f, fieldnames=history_header)
                    writer.writerow(items)

    train_envs.close()
    eval_envs.close()

    return max_reward
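

# --- Usage sketch (not part of the original script) ---------------------------
# A minimal, hypothetical example of how run_ppo might be called. The env id,
# robot body array, and file paths below are assumptions for illustration;
# `model` must be a [policy state_dict, obs_rms] pair, i.e. exactly what this
# script itself writes with torch.save.
#
#   import numpy as np
#   import torch
#   import ppo_config as config
#
#   robot = np.full((5, 5), 3)                      # assumed 5x5 voxel body
#   model = torch.load('./results/warm_start.pt')   # [state_dict, obs_rms]
#   best_reward = run_ppo('Walker-v0', robot, train_iters=100, eval_interval=10,
#                         save_file='./results/walker', model=model,
#                         config=config, deterministic=True)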