main.py
import gym_2048
import gym
import torch
from dqn_model import DQN
import numpy as np
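
# A minimal sketch (not called below) of how the loaded network could pick a
# greedy action instead of the random policy this baseline plays. The input
# handling here is an assumption: it flattens the 4x4 board to float32;
# adapt it to whatever preprocessing dqn_model.DQN actually expects.
def greedy_action(net, board, device):
    state = torch.as_tensor(np.asarray(board, dtype=np.float32).flatten(),
                            device=device)
    with torch.no_grad():
        q_values = net(state.unsqueeze(0))  # add a batch dimension
    return int(q_values.argmax(dim=1).item())
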
if __name__ == '__main__':
    env = gym.make('2048-v0')
    env.seed()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # The trained network is loaded for reference but never queried below:
    # this script plays a purely random policy as a baseline.
    online_net = DQN(size_state=4, num_actions=4)
    online_net.load_state_dict(
        torch.load('./Double_dqn/Past_results/best_model.pth',
                   map_location=device))
    online_net.eval()

    obs = env.reset()
    env.render()
    done = False
    moves = 0
    count_action = 0
    old_action = 0
    episode = 0
    max_values = []  # highest tile reached in each episode

    while episode < 1000:
        # Sample a random action; if the same action has repeated more than
        # 10 times in a row (likely stuck against a wall), resample once.
        action = env.np_random.choice(range(4), 1).item()
        if old_action == action:
            count_action += 1
            if count_action > 10:
                action = env.np_random.choice(range(4), 1).item()
                count_action = 0
        else:
            count_action = 0

        next_state, reward, done, info = env.step(action)
        moves += 1
        obs = next_state
        old_action = action

        if done:
            max_values.append(obs.max())
            obs = env.reset()
            episode += 1

        print('Next Action: "{}"\n\nReward: {}'.format(
            gym_2048.Base2048Env.ACTION_STRING[action], reward))
        env.render()

    # Persist the per-episode maximum tiles for later analysis.
    with open('values_state_random.npy', 'wb') as f:
        np.save(f, max_values)
    print('\nTotal Moves: {}'.format(moves))
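
# Usage sketch: the saved array can be reloaded later for analysis, e.g.
#   max_values = np.load('values_state_random.npy')
#   print(max_values.mean(), max_values.max())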