This fork of google-deepmind/mctx adds subtree persistence, a feature used in AlphaZero: a search can continue from a subtree of the previous state's search output.
Monte Carlo Tree Search can then start from an already-initialized, partially populated search tree, so work done in one call persists to the next instead of being repeated.
mctx-az introduces a new policy, alphazero_policy, which allows the user to pass a pre-initialized Tree to continue the search with.
Then, get_subtree can be used to extract the subtree rooted at a particular child node of the root, corresponding to a taken action.
In cases where the search tree should not be saved, such as when an episode terminates, reset_search_tree can be used to clear the tree.
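To make the idea concrete, here is a minimal pure-Python sketch of subtree extraction and reset on a fixed-capacity tree. It is not the mctx-az implementation: ToyTree, add_node, toy_get_subtree and toy_reset are hypothetical names invented for this illustration, and the sketch tracks only visit counts for a single, unbatched tree.

```python
UNUSED = -1  # marks a missing child or "no parent"


class ToyTree:
    """Hypothetical fixed-capacity tree in flat lists; slot 0 holds the root."""

    def __init__(self, capacity, num_actions):
        self.capacity = capacity
        self.num_actions = num_actions
        self.size = 0  # number of slots currently in use
        self.visits = [0] * capacity
        # children[node][action] is the slot of that child, or UNUSED
        self.children = [[UNUSED] * num_actions for _ in range(capacity)]


def add_node(tree, parent, action, visits):
    """Store a node in the next free slot and link it to its parent."""
    slot = tree.size
    tree.visits[slot] = visits
    if parent != UNUSED:
        tree.children[parent][action] = slot
    tree.size += 1
    return slot


def toy_get_subtree(tree, action):
    """Return a new tree rooted at the old root's child for `action`."""
    new_tree = ToyTree(tree.capacity, tree.num_actions)
    old_root = tree.children[0][action]
    if old_root == UNUSED:
        return new_tree  # child was never expanded, nothing carries over
    # Breadth-first copy that re-indexes nodes so the new root is slot 0.
    queue = [(old_root, UNUSED, UNUSED)]  # (old slot, new parent, action)
    while queue:
        old, new_parent, act = queue.pop(0)
        new = add_node(new_tree, new_parent, act, tree.visits[old])
        for a, child in enumerate(tree.children[old]):
            if child != UNUSED:
                queue.append((child, new, a))
    return new_tree


def toy_reset(tree):
    """Discard all stored nodes, e.g. after an episode terminates."""
    return ToyTree(tree.capacity, tree.num_actions)


# Build a small tree: root -> {left (action 0), right (action 1)}, left -> child
tree = ToyTree(capacity=8, num_actions=2)
root = add_node(tree, UNUSED, UNUSED, visits=5)
left = add_node(tree, root, 0, visits=3)
right = add_node(tree, root, 1, visits=1)
add_node(tree, left, 1, visits=2)

sub = toy_get_subtree(tree, 0)  # keep only what lies below action 0
print(sub.size, sub.visits[:sub.size])  # 2 [3, 2]
```

The statistics of the chosen child and its descendants survive, while the old root and the sibling branch are dropped, which frees capacity for the next search.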
To initialize a new tree, pass tree=None to alphazero_policy, along with max_nodes to specify the capacity of the tree, which in most cases should be >= num_simulations.
alphazero_policy otherwise functions exactly the same as muzero_policy.
import mctx

# first search: no existing tree, so allocate one with room for max_nodes nodes
policy_output = mctx.alphazero_policy(params, rng_key, root, recurrent_fn,
                                      num_simulations=32, tree=None, max_nodes=48)
tree = policy_output.search_tree
# get chosen action from policy output
action = policy_output.action
# extract the subtree corresponding to the chosen action
tree = mctx.get_subtree(tree, action)
# go to next environment state
env_state = env.step(env_state, action)
# reset the search tree where the environment has terminated
tree = mctx.reset_search_tree(tree, env_state.terminated)
# new search with subtree
# (max_nodes has no effect when a tree is passed)
policy_output = mctx.alphazero_policy(params, rng_key, root, recurrent_fn,
                                      num_simulations=32, tree=tree)
A call to any mctx policy will expand num_simulations nodes (assuming max_depth is not breached).
Given that alphazero_policy accepts a pre-populated Tree, it is possible that there will not be enough room left for num_simulations new nodes.
When the tree is full, values and visit counts are still propagated backwards to all nodes along the visit path, just as they would be if the expansion were in bounds. However, the new node is not created or stored in the search tree; only its in-bounds predecessors are updated.
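The full-tree behavior described above can be illustrated with a toy model. This is only a sketch under simplifying assumptions, not mctx-az code: BoundedTree and simulate are hypothetical names, the tree is unbatched, and rewards and discounting are ignored so that the same leaf value is added to every node on the path.

```python
class BoundedTree:
    """Hypothetical fixed-capacity search tree; slot 0 holds the root."""

    def __init__(self, max_nodes):
        self.max_nodes = max_nodes
        self.parent = [-1]      # parent slot of each stored node (root: -1)
        self.visits = [0]       # visit count of each stored node
        self.value_sum = [0.0]  # sum of backed-up values per stored node

    def simulate(self, leaf_parent, leaf_value):
        """Try to expand a child of `leaf_parent`, then back up the value."""
        if len(self.parent) < self.max_nodes:
            # In bounds: store the new node and start the backup from it.
            self.parent.append(leaf_parent)
            self.visits.append(0)
            self.value_sum.append(0.0)
            node = len(self.parent) - 1
            expanded = True
        else:
            # Tree is full: the new node is not stored, so the backup
            # starts from its in-bounds parent instead.
            node = leaf_parent
            expanded = False
        # Walk back to the root, updating every stored node on the path.
        while node != -1:
            self.visits[node] += 1
            self.value_sum[node] += leaf_value
            node = self.parent[node]
        return expanded


tree = BoundedTree(max_nodes=3)
first = tree.simulate(leaf_parent=0, leaf_value=1.0)    # stored in slot 1
second = tree.simulate(leaf_parent=1, leaf_value=0.5)   # stored in slot 2
third = tree.simulate(leaf_parent=2, leaf_value=0.25)   # tree is full
print(first, second, third)            # True True False
print(len(tree.parent), tree.visits)   # 3 [3, 3, 2]
```

The third simulation still raises the visit counts and value sums of slots 2, 1 and 0, but the tree does not grow past max_nodes.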
The mctx readme links to a simple Connect4 example: https://github.com/Carbon225/mctx-classic
I modified this example to demonstrate the use of alphazero_policy and get_subtree. You can see it here.
If you run into problems or need help, please create an Issue and I will do my best to assist you promptly.