# runner.py  (forked from starry-sky6688/MARL-Algorithms)

import os

import numpy as np
import matplotlib.pyplot as plt

from agent.agent import Agents, CommAgents
from common.rollout import RolloutWorker, CommRolloutWorker
from common.replay_buffer import ReplayBuffer


class Runner:
    def __init__(self, env, args):
        self.env = env
        if args.alg.find('commnet') > -1 or args.alg.find('g2anet') > -1:  # agents that communicate
            self.agents = CommAgents(args)
            self.rolloutWorker = CommRolloutWorker(env, self.agents, args)
        else:  # agents without communication
            self.agents = Agents(args)
            self.rolloutWorker = RolloutWorker(env, self.agents, args)
        # coma, central_v and reinforce are on-policy, so they do not use a replay buffer
        if not args.evaluate and args.alg.find('coma') == -1 and args.alg.find('central_v') == -1 and args.alg.find('reinforce') == -1:
            self.buffer = ReplayBuffer(args)
        self.args = args
        self.win_rates = []
        self.episode_rewards = []
        # directory for saving the plots and the evaluation results
        self.save_path = self.args.result_dir + '/' + args.alg + '/' + args.map
        if not os.path.exists(self.save_path):
            os.makedirs(self.save_path)

    def run(self, num):
        time_steps, train_steps, evaluate_steps = 0, 0, -1
        while time_steps < self.args.n_steps:
            print('Run {}, time_steps {}'.format(num, time_steps))
            # evaluate every evaluate_cycle (e.g. 5000) environment steps
            if time_steps // self.args.evaluate_cycle > evaluate_steps:
                if self.args.matrix_game:
                    if self.args.alg.find('task_decomposition_all') > -1 or self.args.alg.find('qmix') > -1:
                        self.evaluate_matrix(time_steps)
                else:
                    win_rate, episode_reward = self.evaluate()
                    print('win_rate is ', win_rate)
                    print('episode reward: ', episode_reward)
                    self.win_rates.append(win_rate)
                    self.episode_rewards.append(episode_reward)
                    self.plt(num)
                evaluate_steps += 1
            episodes = []
            # collect self.args.n_episodes (1) episodes
            for episode_idx in range(self.args.n_episodes):
                episode, _, _, steps = self.rolloutWorker.generate_episode(episode_idx)
                episodes.append(episode)
                time_steps += steps
            # each entry of an episode is a 4-D array of shape (1, episode_len, n_agents, feature_dim);
            # agent-independent entries such as s are only 3-D: (1, episode_len, feature_dim).
            # concatenate the episodes key by key along the first axis
            episode_batch = episodes[0]
            episodes.pop(0)
            for episode in episodes:  # with n_episodes == 1 this loop is never entered
                for key in episode_batch.keys():
                    episode_batch[key] = np.concatenate((episode_batch[key], episode[key]), axis=0)
            if self.args.alg.find('coma') > -1 or self.args.alg.find('central_v') > -1 or self.args.alg.find('reinforce') > -1:
                # on-policy algorithms train directly on the freshly collected batch
                self.agents.train(episode_batch, train_steps, self.rolloutWorker.epsilon)
                train_steps += 1
            elif self.args.alg.find('task_decomposition') > -1:
                self.buffer.store_episode(episode_batch)
                for train_step in range(self.args.train_steps):
                    mini_batch = self.buffer.sample(min(self.buffer.current_size, self.args.batch_size))
                    self.agents.train(mini_batch, train_steps, self.rolloutWorker.epsilon)
                    train_steps += 1
            else:
                self.buffer.store_episode(episode_batch)
                for train_step in range(self.args.train_steps):
                    mini_batch = self.buffer.sample(min(self.buffer.current_size, self.args.batch_size))
                    self.agents.train(mini_batch, train_steps)
                    train_steps += 1
        # final evaluation after the step budget is exhausted
        if not self.args.matrix_game:
            win_rate, episode_reward = self.evaluate()
            print('win_rate is ', win_rate)
            print('reward is ', episode_reward)
            self.win_rates.append(win_rate)
            self.episode_rewards.append(episode_reward)
            self.plt(num)
        else:
            self.evaluate_matrix(time_steps)

    def evaluate(self):
        # run evaluate_epoch greedy episodes and return the mean win rate and mean episode reward
        win_number = 0
        episode_rewards = 0
        for epoch in range(self.args.evaluate_epoch):
            _, episode_reward, win_tag, _ = self.rolloutWorker.generate_episode(epoch, evaluate=True)
            episode_rewards += episode_reward
            if win_tag:
                win_number += 1
        return win_number / self.args.evaluate_epoch, episode_rewards / self.args.evaluate_epoch

    def evaluate_matrix(self, time_steps):
        # run a single greedy episode on the matrix game
        print(time_steps)
        _, episode_reward, win_tag, _ = self.rolloutWorker.generate_episode(0, evaluate=True)

    def plt(self, num):
        plt.figure()
        plt.ylim([0, 105])
        plt.cla()
        plt.subplot(2, 1, 1)
        plt.plot(range(len(self.win_rates)), self.win_rates)
        plt.xlabel('step*{}'.format(self.args.evaluate_cycle))
        plt.ylabel('win_rates')

        plt.subplot(2, 1, 2)
        plt.plot(range(len(self.episode_rewards)), self.episode_rewards)
        plt.xlabel('step*{}'.format(self.args.evaluate_cycle))
        plt.ylabel('episode_rewards')

        plt.savefig(self.save_path + '/plt_{}.png'.format(num), format='png')
        np.save(self.save_path + '/win_rates_{}'.format(num), self.win_rates)
        np.save(self.save_path + '/episode_rewards_{}'.format(num), self.episode_rewards)
        plt.close()
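

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original file: in this repository
    # family a separate main script normally builds the SMAC environment and
    # the argument namespace before handing both to Runner. The names
    # get_common_args / get_mixer_args below are assumptions about what
    # common.arguments exposes in this fork; adjust them to the real helpers.
    from smac.env import StarCraft2Env
    from common.arguments import get_common_args, get_mixer_args  # hypothetical helpers

    args = get_mixer_args(get_common_args())
    env = StarCraft2Env(map_name=args.map, seed=123)

    # copy the environment dimensions that the networks and the replay buffer need
    env_info = env.get_env_info()
    args.n_actions = env_info['n_actions']
    args.n_agents = env_info['n_agents']
    args.state_shape = env_info['state_shape']
    args.obs_shape = env_info['obs_shape']
    args.episode_limit = env_info['episode_limit']

    runner = Runner(env, args)
    if not args.evaluate:
        runner.run(num=0)  # num only labels the saved plot and .npy files
    else:
        win_rate, episode_reward = runner.evaluate()
        print('Final win rate: {}, episode reward: {}'.format(win_rate, episode_reward))
    env.close()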