-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_runner.py
67 lines (55 loc) · 1.71 KB
/
test_runner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import multiprocessing
import os
from typing import Dict, List
import ray
import tensorflow as tf
from mcts import MonteCarloTreeSearch, Node
def _test(
    config: "MuZeroConfig",
    render: bool,
    network: "MuZeroNetwork",
    storage: "CentralActorStorage",
    save_path: str,
    ep_idx: int,
    save_video: bool = False
):
    """Play one evaluation episode with MCTS and return its total reward.

    Args:
        config: Experiment configuration; supplies the environment factory,
            the network factory, and MCTS / value-support settings.
        render: If True, render the environment at every step.
        network: Network to evaluate. If None, an inference-only network is
            built from ``config`` and its parameters are fetched from
            ``storage``.
        storage: Remote (ray actor) parameter store; only used when
            ``network`` is None.
        save_path: Directory where episode recordings are written.
            (Was mis-annotated as ``bool``; the caller passes a path string.)
        ep_idx: Episode index, used as the recording uid.
        save_video: If True, record a video of the episode.

    Returns:
        The undiscounted sum of environment rewards for the episode.
    """
    env = config.new_game(
        save_video=save_video,
        save_path=save_path,
        uid=ep_idx,
        video_callable=lambda episode_id: True,
    )
    ep_reward = 0
    try:
        done = False
        obs = env.reset()
        if network is None:
            # Lazily build an inference-only network and sync its weights
            # from the central storage actor.
            network = config.get_init_network_obj(training=False)
            network.set_params(ray.get(storage.get_params.remote()))
        while not done:
            if render:
                env.render()
            root = Node(0)
            # Batch dimension of 1 for the network's initial inference.
            obs = tf.expand_dims(tf.convert_to_tensor(obs, dtype=tf.float32), axis=0)
            output = network.initial_inference(obs, config.value_support)
            root.expand(env.to_play(), env.legal_actions(), output)
            MonteCarloTreeSearch(config).run(root, network, env.action_history())
            # Deterministic (random=False) action selection from the search tree.
            action, _ = root.select_action(temperature=1, random=False)
            obs, reward, done, _ = env.step(action.index)
            ep_reward += reward
    finally:
        # Guarantee the environment (and any video recorder) is released
        # even if inference, MCTS, or the environment itself raises.
        env.close()
    return ep_reward
def test(
config: "MuZeroConfig",
num_episodes: int,
render: bool,
network: "MuZeroNetwork" = None,
storage: "CentralActorStorage" = None,
save_video: bool = False,
) -> float:
save_path = os.path.join(config.exp_path, "recordings")
ep_reward = 0
for ep_idx in range(num_episodes):
ep_reward += _test(
config, render, network, storage, save_path, ep_idx, save_video
)
return ep_reward / num_episodes