-
Notifications
You must be signed in to change notification settings - Fork 0
/
selfplay.py
executable file
·145 lines (119 loc) · 5.61 KB
/
selfplay.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import os
import chess
import concurrent.futures
from collections import Counter
from utils import convert_result_string_to_value
from nnmodel import NNModel
from player import ModelPlayer, TablebasePlayer
def simulate_game_from_position(start_position, white_player, black_player, backtrack_with_learning_signal,
                                decay=0.99):
    """Play one game from ``start_position`` and optionally collect training lines.

    Moves are requested from ``white_player`` / ``black_player`` via
    ``get_next_move(fen)``. The returned move is pushed onto a
    ``chess.Board`` until ``is_game_over()`` is true.

    Args:
        start_position: FEN string of the position the game starts from.
        white_player: object whose ``get_next_move(fen)`` returns a move
            accepted by ``chess.Board.push`` for the white side.
        black_player: same, for the black side.
        backtrack_with_learning_signal: when true, walk the finished game
            backwards and emit one ``"<fen>,<signal>"`` line per position,
            including the final and the starting position.
        decay: factor the signal is multiplied by for every ply walked back
            from the final position (default 0.99).

    Returns:
        ``(value, lines)``. ``value`` is
        ``convert_result_string_to_value(game.result())``. ``lines`` is the
        list of training lines, empty when backtracking is disabled.

    Every position is written from the side-to-move's perspective.
    White-to-move positions are written as-is with ``signal``. Black-to-move
    positions are written as ``board.mirror()`` with ``-signal``.
    """
    game = chess.Board(start_position)
    while not game.is_game_over():
        player = white_player if game.turn == chess.WHITE else black_player
        game.push(player.get_next_move(game.fen()))

    value = convert_result_string_to_value(game.result())

    fen_signal_list = []
    if backtrack_with_learning_signal:
        signal = value
        # Record the current position, then step one ply back. The start
        # position goes through the same turn-aware branch as every other
        # position, so a black-to-move start is mirrored correctly too.
        while True:
            if game.turn == chess.WHITE:
                fen_signal_list.append("{},{}".format(game.fen(), signal))
            else:
                fen_signal_list.append("{},{}".format(game.mirror().fen(), -signal))
            if not game.move_stack:
                break
            signal = signal * decay
            game.pop()
    return value, fen_signal_list
def simulate_games(start_position, simulations, white_model_path, black_model_path, exploration,
                   output_file='/dev/null', random_seed=None, multi_process=False):
    """Simulate ``simulations`` games from ``start_position`` and return their results.

    A player path containing ``'tablebases'`` is played by one shared
    ``TablebasePlayer``. Any other path is loaded once as an ``NNModel``, and
    a fresh ``ModelPlayer`` wrapping it is built for every game.

    Args:
        start_position: FEN string every game starts from.
        simulations: number of games to play.
        white_model_path: model or tablebase path for the white player.
        black_model_path: model or tablebase path for the black player.
        exploration: exploration parameter forwarded to ``ModelPlayer``.
        output_file: file that receives the training lines produced by
            ``simulate_game_from_position``. When it starts with
            ``'/dev/null'``, backtracking is skipped entirely.
        random_seed: base seed. Game ``i`` uses ``random_seed + i``, and a
            seed of 0 is honoured. ``None`` leaves the players unseeded.
        multi_process: run the games in a ``ProcessPoolExecutor`` instead of
            sequentially in this process.

    Returns:
        List of per-game result values, in game (submission) order.
    """
    white_tablebase_player = None
    white_model = None
    if 'tablebases' in white_model_path:
        white_tablebase_player = TablebasePlayer(white_model_path)
    else:
        white_model = NNModel(random_seed=random_seed, model_path=white_model_path)

    black_tablebase_player = None
    black_model = None
    if 'tablebases' in black_model_path:
        black_tablebase_player = TablebasePlayer(black_model_path)
    else:
        black_model = NNModel(random_seed=random_seed, model_path=black_model_path)

    # Backtracking only produces training lines; skip it when they would be
    # discarded anyway (saves computation time).
    backtrack_with_learning_signal = not output_file.startswith('/dev/null')

    def players_for_game(index):
        # `is not None` so that an explicit seed of 0 still yields
        # deterministic, per-game seeds instead of "no seed".
        seed = random_seed + index if random_seed is not None else None
        white = (white_tablebase_player if white_tablebase_player is not None
                 else ModelPlayer(white_model, exploration, seed))
        black = (black_tablebase_player if black_tablebase_player is not None
                 else ModelPlayer(black_model, exploration, seed))
        return white, black

    if multi_process:
        # ThreadPoolExecutor() would result in race conditions to the np.random object
        # resulting in non-deterministic outputs between simulations for fixed random seed
        with concurrent.futures.ProcessPoolExecutor() as ppe:
            futures = []
            for i in range(simulations):
                white_player, black_player = players_for_game(i)
                futures.append(ppe.submit(simulate_game_from_position, start_position, white_player,
                                          black_player, backtrack_with_learning_signal))
            # Collect in submission order (not as_completed) so the output
            # file contents are reproducible for a fixed seed.
            outcomes = [future.result() for future in futures]
    else:
        # No executor is created at all in the sequential case.
        outcomes = []
        for i in range(simulations):
            white_player, black_player = players_for_game(i)
            outcomes.append(simulate_game_from_position(start_position, white_player, black_player,
                                                        backtrack_with_learning_signal))

    results = []
    with open(output_file, 'w') as out:
        for result, selfplay_training_data in outcomes:
            results.append(result)
            out.writelines(line + "\n" for line in selfplay_training_data)
    return results
def main(arg1, arg2, arg3, arg4, arg5, arg6=None):
    """Simulate games from a hard-coded KQ-vs-K position and print the tally.

    Args:
        arg1: output file for the self-play training lines (made absolute).
        arg2: number of games to simulate; must parse as ``int``.
        arg3: white player path (made absolute). Paths containing
            ``'tablebases'`` are reported as-is, without loading an ``NNModel``.
        arg4: black player path, handled like ``arg3``.
        arg5: exploration value; must parse as ``float``.
        arg6: optional random seed; must parse as ``int`` when given.

    Raises:
        ValueError: if ``arg2``, ``arg5`` or ``arg6`` cannot be parsed. A hint
            is printed before the exception is re-raised.
    """
    try:
        output_games_file = os.path.abspath(arg1)
        nr_of_simulations = int(arg2)
        white_path = os.path.abspath(arg3)
        black_path = os.path.abspath(arg4)
        exploration = float(arg5)
        random_seed = int(arg6) if arg6 else None
    except ValueError as e:
        print("Second parameter (simulations) should be an integer, the fifth (exploration) a float "
              "and the optional sixth (random seed) an integer.")
        print(e)
        raise

    # hardcoded starting position: white Kd3 + Qe3 vs black Kd5, white to move
    KQ_vs_K = "8/8/8/3k4/8/3KQ3/8/8 w - - 0 1"

    for colour, path in (("White", white_path), ("Black", black_path)):
        if 'tablebases' in path:
            label = path
        else:
            # Loaded only to report whether the model is a random one.
            model = NNModel(random_seed=random_seed, model_path=path)
            label = 'random' if model.is_random() else path
        print("{} player: '{}'".format(colour, label))

    print("(random_seed: {}, exploration: {})".format(random_seed, exploration))
    print("Simulating {} games from position".format(nr_of_simulations))
    print(chess.Board(KQ_vs_K))
    print()

    # Temporarily disabling multi-processing of game simulation because of an unknown bug
    multi_process = False
    results = simulate_games(KQ_vs_K, nr_of_simulations, white_path, black_path, exploration, output_games_file,
                             random_seed, multi_process)

    counter = Counter(results)
    print(" -> White wins: {}".format(counter[1]))
    print(" -> Draws: {}".format(counter[0]))
    print(" -> Black wins: {}".format(counter[-1]))
if __name__ == '__main__':
    import sys

    # Five positional arguments are required and a sixth (random seed) is
    # optional. Print usage instead of crashing with an IndexError.
    if len(sys.argv) < 6:
        print("Usage: {} OUTPUT_FILE SIMULATIONS WHITE_PATH BLACK_PATH EXPLORATION [RANDOM_SEED]"
              .format(sys.argv[0]), file=sys.stderr)
        sys.exit(2)
    # Arguments beyond the sixth are ignored.
    main(*sys.argv[1:7])