forked from WantEat-Mao/EdgeAISIM

DQN.py
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

class QNetwork(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=64, dropout_prob=0.5):
        super(QNetwork, self).__init__()
        self.state_dim = state_dim        # State dimension
        self.action_dim = action_dim      # Action dimension
        self.hidden_dim = hidden_dim      # Hidden layer dimension
        self.dropout_prob = dropout_prob  # Dropout probability

        # Q-network layers
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, action_dim)

        # Dropout layers
        self.dropout1 = nn.Dropout(p=dropout_prob)
        self.dropout2 = nn.Dropout(p=dropout_prob)

    def forward(self, state):
        x = torch.relu(self.fc1(state))  # Apply ReLU activation
        x = self.dropout1(x)             # Apply dropout
        x = torch.relu(self.fc2(x))
        x = self.dropout2(x)             # Apply dropout
        x = torch.relu(self.fc3(x))
        q_values = self.fc4(x)           # One Q-value per action
        return q_values
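
# A minimal shape sanity check (not part of the original file): with assumed
# dimensions state_dim=4 and action_dim=3, a batch of states should map to one
# Q-value per action. Guarded so importing this module is unaffected.
if __name__ == "__main__":
    _net = QNetwork(state_dim=4, action_dim=3)
    _states = torch.randn(8, 4)           # batch of 8 four-dimensional states
    assert _net(_states).shape == (8, 3)  # one row of Q-values per state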

class DQNAgent:
    def __init__(self, state_dim, action_dim, hidden_dim=64, lr=0.05, gamma=0.99,
                 epsilon=1.0, epsilon_decay=0.997):
        self.state_dim = state_dim    # State, action and hidden dimensions
        self.action_dim = action_dim
        self.hidden_dim = hidden_dim

        # Hyperparameters
        self.lr = lr
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay

        # Initialize Q-network and target network
        self.q_network = QNetwork(state_dim, action_dim, hidden_dim)
        self.target_network = QNetwork(state_dim, action_dim, hidden_dim)
        self.target_network.load_state_dict(self.q_network.state_dict())  # Copy Q-network parameters to the target network
        self.target_network.eval()
        # Optimizer for the Q-network, using the configured learning rate
        self.optimizer = optim.SGD(self.q_network.parameters(), lr=self.lr)
        # Loss function for the Q-network
        self.criterion = nn.MSELoss()
        self.loss = None

    def choose_action(self, state):
        # Epsilon-greedy action selection
        if np.random.rand() < self.epsilon:
            # Explore: take a random action
            action = np.random.randint(self.action_dim)
        else:
            # Exploit: take the action with the maximum Q-value
            with torch.no_grad():
                state_tensor = torch.FloatTensor(state).unsqueeze(0)
                q_values = self.q_network(state_tensor)
                action = q_values.argmax(dim=1).item()
        return action

    def update(self, state, action, next_state, reward, done, if_dynamic=False):
        # Convert numpy arrays / scalars to tensors
        state_tensor = torch.FloatTensor(state).unsqueeze(0)
        action_tensor = torch.FloatTensor([action]).unsqueeze(0)
        next_state_tensor = torch.FloatTensor(next_state).unsqueeze(0)
        reward_tensor = torch.FloatTensor([reward]).unsqueeze(0)
        done_tensor = torch.FloatTensor([done]).unsqueeze(0)

        # Q-value of the chosen action under the current Q-network
        q_values = self.q_network(state_tensor)
        q_value = q_values.gather(1, action_tensor.long())

        # Bootstrapped target from the target network (no gradient through it)
        next_q_values = self.target_network(next_state_tensor).detach()
        max_next_q_value = next_q_values.max(dim=1)[0]
        expected_q_value = reward_tensor + self.gamma * max_next_q_value * (1 - done_tensor)

        # Compute the TD loss and backpropagate it through the Q-network
        loss = self.criterion(q_value, expected_q_value)
        self.loss = loss
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_target_network(self):
        # Copy the current Q-network parameters into the target network
        self.target_network.load_state_dict(self.q_network.state_dict())
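
# A hedged usage sketch (not part of the original file): drive the agent on a
# toy problem with random transitions, purely to illustrate the API. The
# dimensions, step count, reward signal, and update schedule below are
# illustrative assumptions, not values taken from EdgeAISIM.
if __name__ == "__main__":
    state_dim, action_dim = 4, 3
    agent = DQNAgent(state_dim, action_dim)

    for step in range(100):
        state = np.random.rand(state_dim).astype(np.float32)
        action = agent.choose_action(state)          # epsilon-greedy choice
        next_state = np.random.rand(state_dim).astype(np.float32)
        reward = -float(np.linalg.norm(next_state))  # placeholder reward
        done = (step % 20 == 19)                     # placeholder episode boundary
        agent.update(state, action, next_state, reward, done)

        agent.epsilon *= agent.epsilon_decay         # decay exploration rate outside the class
        if step % 10 == 0:
            agent.update_target_network()            # periodic target-network sync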