"""Monte Carlo Tree Search (MCTS) over the network's learned hidden states."""
import numpy as np
import torch

from model import support_to_scalar


class Node(object):
    """A single node of the search tree."""

    def __init__(self, prior):
        self.visit_count = 0
        self.to_play = -1
        self.prior = prior
        self.value_sum = 0
        self.children = {}
        self.hidden_state = None
        self.reward = 0

    def expanded(self):
        return len(self.children) > 0

    def value(self):
        # Mean value of the node; None signals "not visited yet" and is
        # treated as 0 by MinMaxStats.normalize below.
        if self.visit_count == 0:
            return None
        return self.value_sum / self.visit_count


class MinMaxStats(object):
    """A class that holds the min-max values of the tree."""

    def __init__(self, known_bounds):
        self.maximum = known_bounds.max if known_bounds else -np.inf
        self.minimum = known_bounds.min if known_bounds else np.inf

    def update(self, value):
        self.maximum = max(self.maximum, value)
        self.minimum = min(self.minimum, value)

    def normalize(self, value):
        if value is None:
            return 0.0
        if self.maximum > self.minimum:
            # We normalize only when we have set the maximum and minimum values.
            return (value - self.minimum) / (self.maximum - self.minimum)
        return value
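

# --- Illustrative sketch (not in the original file) -------------------------
# MinMaxStats only requires `config.known_bounds` to be either falsy or an
# object exposing `.min` and `.max`.  A small namedtuple such as the
# hypothetical `KnownBounds` below is one way to satisfy that; the name is an
# assumption used purely for illustration.
import collections

KnownBounds = collections.namedtuple('KnownBounds', ['min', 'max'])
# e.g. MinMaxStats(KnownBounds(min=-1.0, max=1.0)) for a bounded two-player
# game, or MinMaxStats(None) to learn the bounds online from observed values.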


# Core Monte Carlo Tree Search algorithm.
# To decide on an action, we run N simulations, always starting at the root of
# the search tree and traversing the tree according to the UCB formula until we
# reach a leaf node.  The root is expected to have been expanded (and, during
# self-play, noised) by the caller before run_mcts is invoked.
def run_mcts(config, root, game, network):
    min_max_stats = MinMaxStats(config.known_bounds)

    for _ in range(config.num_simulations):
        node = root
        search_path = [node]

        while node.expanded():
            action, node = select_child(config, node, min_max_stats)
            search_path.append(node)

        # Inside the search tree we use the dynamics function to obtain the next
        # hidden state given an action and the previous hidden state.
        parent = search_path[-2]
        network_output = network.recurrent_inference(parent.hidden_state,
                                                     torch.tensor([[action]]))
        expand_node(config, node, game.to_play(), game.legal_actions(), network_output)

        backpropagate(search_path, support_to_scalar(network_output.value, config.support_size),
                      game.to_play(), config.discount, min_max_stats)


def select_action(config, num_moves, node, train):
    def softmax_sample(visit_counts, actions, t):
        # Sample an action from a softmax over visit counts with temperature t
        # (higher t -> closer to uniform, small t -> close to greedy).
        logits = np.array(visit_counts, dtype=np.float64) / t
        logits -= logits.max()  # for numerical stability
        probs = np.exp(logits) / np.sum(np.exp(logits))
        return actions[np.random.choice(len(actions), p=probs)]

    visit_counts = [child.visit_count for child in node.children.values()]
    actions = list(node.children.keys())
    if train:
        t = config.visit_softmax_temperature_fn(
            num_moves=num_moves, training_steps=config.get_counter())
        action = softmax_sample(visit_counts, actions, t)
    else:
        action, _ = max(node.children.items(), key=lambda item: item[1].visit_count)
    return action
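

# --- Illustrative sketch (not in the original file) -------------------------
# `config.visit_softmax_temperature_fn` is only assumed to accept the keyword
# arguments used above and return a positive temperature.  One plausible
# schedule, shown purely as an example, keeps exploration high early in a game
# and becomes near-greedy later; the 30-move cut-off is an assumption.
def example_visit_softmax_temperature(num_moves, training_steps):
    # training_steps is accepted to match the keyword call above; this simple
    # schedule does not use it.
    if num_moves < 30:
        return 1.0
    return 0.1  # near-greedy; keep strictly positive so softmax_sample is defined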


# Select the child with the highest UCB score.
def select_child(config, node, min_max_stats):
    _, action, child = max(
        (ucb_score(config, node, child, min_max_stats), action, child)
        for action, child in node.children.items())
    return action, child


# The score for a node is based on its value, plus an exploration bonus based on
# the prior.
def ucb_score(config, parent, child, min_max_stats):
    pb_c = np.log((parent.visit_count + config.pb_c_base + 1) /
                  config.pb_c_base) + config.pb_c_init
    pb_c *= np.sqrt(parent.visit_count) / (child.visit_count + 1)

    prior_score = pb_c * child.prior
    value_score = min_max_stats.normalize(child.value())
    return prior_score + value_score
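

# The function above computes the pUCT rule used by AlphaZero/MuZero-style
# search.  Written out (a sketch of the same computation):
#
#     UCB(s, a) = Q_norm(s, a)
#                 + P(s, a) * sqrt(N(s)) / (1 + N(s, a))
#                   * (c_init + log((N(s) + c_base + 1) / c_base))
#
# where N(s) = parent.visit_count, N(s, a) = child.visit_count, P(s, a) is the
# child prior, and Q_norm is the child's mean value normalised by MinMaxStats.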


# We expand a node using the value, reward and policy prediction obtained from
# the neural network.
def expand_node(config, node, to_play, actions, network_output):
    node.to_play = to_play
    node.hidden_state = network_output.hidden_state
    node.reward = support_to_scalar(network_output.reward, config.support_size)

    # Softmax over the legal actions' logits gives the children's priors.
    policy = {a: np.exp(network_output.policy_logits[0][a]) for a in actions}
    policy_sum = sum(policy.values())
    for action, p in policy.items():
        node.children[action] = Node(p / policy_sum)
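

# --- Illustrative sketch (not in the original file) -------------------------
# The search only assumes that `network.recurrent_inference(...)` returns an
# object exposing the four attributes read above (`value`, `reward`,
# `policy_logits`, `hidden_state`).  A minimal container with that shape could
# look like the hypothetical class below; the actual type returned by the
# network in `model.py` may well differ.
from typing import Any, NamedTuple

class ExampleNetworkOutput(NamedTuple):
    value: Any          # support-encoded value prediction
    reward: Any         # support-encoded reward prediction
    policy_logits: Any  # indexed as policy_logits[0][action]
    hidden_state: Any   # latent state passed back into recurrent_inference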


# At the end of a simulation, we propagate the evaluation all the way up the
# tree to the root, walking the search path from the leaf back to the root and
# discounting the value through the rewards encountered along the way.
def backpropagate(search_path, value, to_play, discount, min_max_stats):
    for node in reversed(search_path):
        node.value_sum += value if node.to_play == to_play else -value
        node.visit_count += 1
        min_max_stats.update(node.value())

        value = node.reward + discount * value
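
# A tiny worked example of the backup above (illustrative numbers only): with
# search_path = [root, child, leaf], discount = 0.99 and a leaf value estimate
# of 1.0, the leaf's value_sum grows by 1.0, the child's by
# leaf.reward + 0.99 * 1.0, and the root's by
# child.reward + 0.99 * (leaf.reward + 0.99 * 1.0), with the sign flipped at
# any node whose `to_play` differs from the player the value was predicted for.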


# At the start of each search, we add Dirichlet noise to the prior of the root
# to encourage the search to explore new actions.
def add_exploration_noise(config, node):
    actions = list(node.children.keys())
    noise = np.random.dirichlet([config.root_dirichlet_alpha] * len(actions))
    frac = config.root_exploration_fraction
    for a, n in zip(actions, noise):
        node.children[a].prior = node.children[a].prior * (1 - frac) + n * frac
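

# --- Illustrative sketch (not in the original file) -------------------------
# A minimal sketch of how the pieces above are typically wired together for a
# single self-play move.  Beyond what this file defines, it assumes that the
# network exposes an `initial_inference(observation)` returning the same kind
# of output object as `recurrent_inference`, and that the game object provides
# a current observation via a hypothetical `make_image(-1)` in addition to the
# `to_play()` / `legal_actions()` methods already used above.
def example_play_one_move(config, game, network, num_moves, train=True):
    root = Node(0)
    root_output = network.initial_inference(game.make_image(-1))  # assumed API
    expand_node(config, root, game.to_play(), game.legal_actions(), root_output)
    add_exploration_noise(config, root)      # self-play exploration at the root
    run_mcts(config, root, game, network)    # config.num_simulations simulations
    return select_action(config, num_moves, root, train)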