Agent.py
from copy import deepcopy
from typing import List
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torch.optim import Adam


class Agent:
    """A single agent in MADDPG."""

    def __init__(self, obs_dim, act_dim, global_obs_dim, actor_lr, critic_lr, device):
        """
        Create one of the agents in MADDPG.
        :param obs_dim: observation dimension of the current agent, i.e. the local observation space
        :param act_dim: action dimension of the current agent, i.e. the local action space
        :param global_obs_dim: input dimension of this agent's centralized critic; with 3 agents,
            for example, the critic input is the concatenation (obs1, obs2, obs3, act1, act2, act3)
        """
        # the actor outputs a logit for each discrete action
        self.actor = MLPNetwork(obs_dim, act_dim).to(device)
        # the critic takes the states and actions of all agents as input
        self.critic = MLPNetwork(global_obs_dim, 1).to(device)
        self.actor_optimizer = Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_optimizer = Adam(self.critic.parameters(), lr=critic_lr)
        self.target_actor = deepcopy(self.actor).to(device)
        self.target_critic = deepcopy(self.critic).to(device)
        self.device = device

    @staticmethod
    def gumbel_softmax(logits, tau=1.0, eps=1e-20):
        # NOTE: PyTorch provides a similar function (torch.nn.functional.gumbel_softmax),
        # but the docs mention it may be removed in the future, so it is implemented here.
        # Gumbel-Max trick: g = -log(-log(u)) with u ~ Uniform(0, 1) is Gumbel noise; adding
        # it to the logits and applying a temperature-scaled softmax gives a differentiable
        # approximation of sampling from the categorical distribution.
        epsilon = torch.rand_like(logits)
        gumbel_noise = -torch.log(-torch.log(epsilon + eps) + eps)
        # avoid modifying `logits` in place so callers keep the original actor output
        return F.softmax((logits + gumbel_noise) / tau, dim=-1)

    def action(self, obs, *, model_out=False):
        """
        Choose an action according to the given `obs`.
        :param model_out: whether to also return the raw actor output (the logits of each
            action). By default (model_out=False) the logits are passed through
            `gumbel_softmax` and only the resulting action is returned; if True,
            return both the action and the logits.
        """
        # this method is called in two cases:
        # a) interacting with the environment
        # b) computing the action during the actor update, where the input `obs` is sampled
        #    from the replay buffer with shape torch.Size([batch_size, state_dim])
        logits = self.actor(obs)  # torch.Size([batch_size, action_size])
        action = self.gumbel_softmax(logits)
        if model_out:
            return action, logits
        return action

    def target_action(self, obs):
        # when computing the target critic value in MADDPG, the target actor is used to
        # get the next action given the next states, which are sampled from the replay
        # buffer with shape torch.Size([batch_size, state_dim])
        logits = self.target_actor(obs)  # torch.Size([batch_size, action_size])
        action = self.gumbel_softmax(logits)
        return action.squeeze(0).detach()

    def critic_value(self, state_list: List[Tensor], act_list: List[Tensor]):
        x = torch.cat(state_list + act_list, 1)
        return self.critic(x).squeeze(1)  # tensor of shape (batch_size,)

    def target_critic_value(self, state_list: List[Tensor], act_list: List[Tensor]):
        x = torch.cat(state_list + act_list, 1)
        return self.target_critic(x).squeeze(1)  # tensor of shape (batch_size,)

    def update_actor(self, loss):
        self.actor_optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.actor.parameters(), 0.5)
        self.actor_optimizer.step()

    def update_critic(self, loss):
        self.critic_optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 0.5)
        self.critic_optimizer.step()


class MLPNetwork(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_dim=64, non_linear=nn.ReLU()):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            non_linear,
            nn.Linear(hidden_dim, hidden_dim),
            non_linear,
            nn.Linear(hidden_dim, out_dim),
        ).apply(self.init)

    @staticmethod
    def init(m):
        """initialize the parameters of a module"""
        gain = nn.init.calculate_gain('relu')
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight, gain=gain)
            m.bias.data.fill_(0.01)

    def forward(self, x):
        return self.net(x)
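

A minimal usage sketch follows; it is not part of the original file, and all dimensions, learning rates, and the batch size are made up for illustration. It assumes the file is run directly as a script.

if __name__ == "__main__":
    # hypothetical setup: 3 agents, local observation dim 8, action dim 5
    n_agents, obs_dim, act_dim = 3, 8, 5
    global_obs_dim = n_agents * (obs_dim + act_dim)  # centralized critic sees all obs and actions
    device = torch.device('cpu')

    agent = Agent(obs_dim, act_dim, global_obs_dim, actor_lr=1e-2, critic_lr=1e-2, device=device)

    # act on a batch of local observations
    obs = torch.rand(4, obs_dim)              # torch.Size([batch_size, obs_dim])
    action = agent.action(obs)                # torch.Size([4, 5])
    print(action.shape, action.sum(dim=-1))   # each gumbel-softmax row sums to ~1

    # evaluate the centralized critic on all agents' states and actions
    state_list = [torch.rand(4, obs_dim) for _ in range(n_agents)]
    act_list = [torch.rand(4, act_dim) for _ in range(n_agents)]
    q = agent.critic_value(state_list, act_list)  # torch.Size([4])
    print(q.shape)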