-
Notifications
You must be signed in to change notification settings - Fork 0
/
LunarLander.py
78 lines (68 loc) · 2.14 KB
/
LunarLander.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from typing import List
import gym
from game_env import Game
import numpy as np
from config import MuZeroConfig
class LunarLander(Game):
    """Gym LunarLander-v2 environment wrapped to the project's Game interface.

    Observations from gym (shape (8,)) are wrapped as ``np.array([[obs]])``
    so they match the ``observation_shape`` (1, 1, 8) declared in the config.
    """

    def __init__(self, action_space_size, discount):
        super().__init__(action_space_size, discount)
        # Keep our own copy so legal_actions() does not depend on the
        # (unseen) Game base class storing this attribute.
        self.action_space_size = action_space_size
        self.env = gym.make("LunarLander-v2")
        self.done = False

    def step(self, action):
        """Apply *action* to the environment.

        Returns:
            Tuple of (wrapped observation with shape (1, 1, 8), scaled reward).
        """
        observation, reward, done, _ = self.env.step(action)
        self.done = done
        # Reward is scaled down by 3 — presumably to keep returns within the
        # value network's support range; TODO confirm against training code.
        return np.array([[observation]]), reward / 3

    def terminal(self):
        """Is the game finished?"""
        return self.done

    def legal_actions(self):
        """Return the legal actions available at this instant.

        All actions are always legal in LunarLander; derive the list from the
        configured action-space size instead of hard-coding 4.
        """
        return list(range(self.action_space_size))

    def reset(self):
        """
        Reset the game for a new game.

        Returns:
            Initial observation of the game, wrapped to shape (1, 1, 8).
        """
        return np.array([[self.env.reset()]])

    def close(self):
        """
        Properly close the game.
        """
        self.env.close()
def make_lunarlander_config():
    """Build the MuZeroConfig used to train and evaluate on LunarLander-v2."""

    def visit_softmax_temperature(num_moves, training_steps):
        # Anneal exploration: high temperature early, lower as training
        # progresses (thresholds at 50% and 75% of training_steps).
        if num_moves < 0.5 * training_steps:
            return 1.0
        if num_moves < 0.75 * training_steps:
            return 0.5
        return 0.25

    # Network hyper-parameters, kept in one place for readability.
    network_args = {
        'support_size': 10,
        'encoding_size': 10,
        'rep_hidden': [],
        'dyn_hidden': [64],
        'rew_hidden': [64],
        'val_hidden': [64],
        'pol_hidden': [],
        'observation_shape': (1, 1, 8),
    }

    return MuZeroConfig(
        game=LunarLander,
        action_space_size=4,
        max_moves=500,
        discount=0.997,
        dirichlet_alpha=0.25,
        num_simulations=50,
        num_training_loop=50,
        num_epochs=200000,
        batch_size=32,
        td_steps=50,
        num_train_episodes=30,
        num_eval_episodes=10,
        lr_init=0.05,
        lr_decay_steps=1000,
        max_priority=False,
        visit_softmax_temperature_fn=visit_softmax_temperature,
        network_args=network_args,
        result_path="lunarlander.weights",
    )