-
Notifications
You must be signed in to change notification settings - Fork 0
/
FrozenLake.py
83 lines (72 loc) · 2.26 KB
/
FrozenLake.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from typing import List
import gym
from game_env import Game
import numpy as np
from config import MuZeroConfig
class FrozenLake(Game):
    """Gym ``FrozenLake-v0`` environment exposed through the ``Game`` interface.

    Raw observations from the environment are converted to one-hot arrays of
    shape ``(1, 1, 16)`` before being returned to callers.
    """

    def __init__(self, action_space_size, discount):
        """Create the underlying gym environment.

        Args:
            action_space_size: Forwarded unchanged to ``Game.__init__``.
            discount: Forwarded unchanged to ``Game.__init__``.
        """
        super().__init__(action_space_size, discount)
        self.env = gym.make("FrozenLake-v0")
        # Mirrors the ``done`` flag reported by the most recent ``step`` call.
        self.done = False

    def step(self, action):
        """Apply ``action`` to the environment.

        Args:
            action: Action passed straight through to ``env.step``.

        Returns:
            A ``(observation, reward)`` pair, where ``observation`` is the
            one-hot encoding produced by ``obs_transform``.
        """
        # The environment is unpacked as a 4-tuple; the trailing info value
        # is not used.
        observation, reward, done, _ = self.env.step(action)
        self.done = done
        return self.obs_transform(observation), reward

    def obs_transform(self, obs):
        """One-hot encode ``obs`` into a float array of shape ``(1, 1, 16)``.

        Args:
            obs: Index into the last axis of the array (length 16).

        Returns:
            An all-zero array with a single 1 at position ``[0, 0, obs]``.
        """
        one_hot = np.zeros((1, 1, 16))
        one_hot[0, 0, obs] = 1
        return one_hot

    def terminal(self):
        """Return whether the game is finished."""
        return self.done

    def legal_actions(self):
        """Return the legal actions available at this instant.

        Every one of the four actions is always reported as legal.
        """
        return list(range(4))

    def reset(self):
        """Reset the game for a new game.

        Returns:
            Initial observation of the game, one-hot encoded.
        """
        initial_obs = self.env.reset()
        # Clear the flag left over from a previous episode; otherwise
        # ``terminal()`` would keep reporting True on the fresh game.
        self.done = False
        return self.obs_transform(initial_obs)

    def close(self):
        """Properly close the game by closing the underlying environment."""
        self.env.close()
def make_frozenlake_config():
    """Build the ``MuZeroConfig`` used to train on ``FrozenLake``.

    Returns:
        A ``MuZeroConfig`` instance populated with the FrozenLake game class,
        the hyperparameters below, and the nested temperature schedule.
    """

    def visit_softmax_temperature(num_moves, training_steps):
        """Return 1.0, 0.5 or 0.25 depending on ``num_moves``.

        The thresholds are 50% and 75% of ``training_steps``.
        """
        # NOTE(review): the move count is compared against fractions of the
        # training-step count; confirm this is the intended schedule.
        if num_moves < 0.5 * training_steps:
            return 1.0
        if num_moves < 0.75 * training_steps:
            return 0.5
        return 0.25

    # Keyword arguments describing the network; passed through as-is.
    network_args = {
        'support_size': 1,
        'encoding_size': 10,
        'rep_hidden': [],
        'dyn_hidden': [16],
        'rew_hidden': [16],
        'val_hidden': [16],
        'pol_hidden': [],
        # Matches the (1, 1, 16) one-hot arrays built by FrozenLake.
        'observation_shape': (1, 1, 16),
    }

    return MuZeroConfig(
        game=FrozenLake,
        action_space_size=4,
        max_moves=20,
        discount=0.997,
        dirichlet_alpha=0.25,
        num_simulations=50,
        num_training_loop=50,
        num_epochs=200000,
        batch_size=32,
        td_steps=50,
        num_train_episodes=30,
        num_eval_episodes=1,
        lr_init=0.1,
        lr_decay_steps=1000,
        max_priority=False,
        visit_softmax_temperature_fn=visit_softmax_temperature,
        network_args=network_args,
        result_path="frozenlake.weights",
    )