DGN.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd

USE_CUDA = torch.cuda.is_available()

# Wrap a tensor as an autograd.Variable, moved to the GPU when one is available.
# Variable has been deprecated since PyTorch 0.4 (plain tensors now carry
# autograd state), so this helper is kept only for compatibility.
Variable = lambda *args, **kwargs: (
    autograd.Variable(*args, **kwargs).cuda() if USE_CUDA
    else autograd.Variable(*args, **kwargs)
)

class Encoder(nn.Module):
    """Embeds each agent's raw observation into a hidden feature vector."""

    def __init__(self, din=32, hidden_dim=128):
        super(Encoder, self).__init__()
        self.fc = nn.Linear(din, hidden_dim)

    def forward(self, x):
        embedding = F.relu(self.fc(x))
        return embedding

class AttModel(nn.Module):
    """One graph-attention layer: each agent aggregates the features of the
    neighbours selected by the adjacency mask."""

    def __init__(self, n_node, din, hidden_dim, dout):
        super(AttModel, self).__init__()
        self.fcv = nn.Linear(din, hidden_dim)  # value projection
        self.fck = nn.Linear(din, hidden_dim)  # key projection
        self.fcq = nn.Linear(din, hidden_dim)  # query projection
        self.fcout = nn.Linear(hidden_dim, dout)
        self.n_node = n_node

    def forward(self, x, mask):
        # x:    (batch, n_node, din) agent features
        # mask: (batch, n_node, n_node), 1 where agent j is visible to agent i
        v = F.relu(self.fcv(x))
        q = F.relu(self.fcq(x))
        k = F.relu(self.fck(x)).permute(0, 2, 1)
        # Zero the scores of non-neighbours, then push them to -inf so the
        # softmax assigns them (numerically) zero attention weight.
        att = F.softmax(torch.mul(torch.bmm(q, k), mask) - 9e15 * (1 - mask), dim=2)
        out = torch.bmm(att, v)
        # out = torch.add(out, v)  # optional residual connection
        out = F.relu(self.fcout(out))
        return out
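
# A minimal sketch (not part of the original file) of how the adjacency mask
# consumed by AttModel.forward might be constructed. It assumes `adj` is a
# 0/1 tensor of shape (n_node, n_node) and adds self-loops so every agent
# also attends to its own features; the helper name is hypothetical.
def build_mask(adj, batch_size=1):
    mask = (adj + torch.eye(adj.size(0))).clamp(max=1.0)  # add self-loops
    return mask.unsqueeze(0).expand(batch_size, -1, -1)   # (batch, n, n)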

class Q_Net(nn.Module):
    """Linear head mapping each agent's features to per-action Q-values."""

    def __init__(self, hidden_dim, dout):
        super(Q_Net, self).__init__()
        self.fc = nn.Linear(hidden_dim, dout)

    def forward(self, x):
        q = self.fc(x)
        return q

class DGN(nn.Module):
    """Full model: observation encoder, two stacked attention layers, and a
    Q-value head, as in Graph Convolutional Reinforcement Learning (DGN)."""

    def __init__(self, n_agent, num_inputs, hidden_dim, num_actions):
        super(DGN, self).__init__()
        self.encoder = Encoder(num_inputs, hidden_dim)
        self.att_1 = AttModel(n_agent, hidden_dim, hidden_dim, hidden_dim)
        self.att_2 = AttModel(n_agent, hidden_dim, hidden_dim, hidden_dim)
        self.q_net = Q_Net(hidden_dim, num_actions)

    def forward(self, x, mask):
        h1 = self.encoder(x)
        h2 = self.att_1(h1, mask)
        h3 = self.att_2(h2, mask)
        q = self.q_net(h3)
        return q
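
# A minimal usage sketch, not part of the original file; the sizes below are
# illustrative assumptions (4 agents, 32-dim observations, 5 actions).
if __name__ == "__main__":
    n_agent, obs_dim, hidden_dim, n_action = 4, 32, 128, 5
    model = DGN(n_agent, obs_dim, hidden_dim, n_action)
    obs = torch.randn(1, n_agent, obs_dim)           # (batch, n_node, din)
    mask = build_mask(torch.ones(n_agent, n_agent))  # fully connected graph
    q = model(obs, mask)
    print(q.shape)  # torch.Size([1, 4, 5])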