-
Notifications
You must be signed in to change notification settings - Fork 2
/
player.py
47 lines (35 loc) · 1.88 KB
/
player.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import numpy as np
import agents
# Root directory for agent checkpoints; Player builds per-agent, per-env
# checkpoint paths beneath it.
CHECKPOINTS = 'checkpoints'
class Player(object):
    """Evaluate several candidate agents on one environment and use the best.

    The player owns a fixed dictionary of agents keyed by name
    ('universe', 'tensorpack', 'random').  ``choose`` scores every agent
    and remembers the name of the top scorer in ``self.best``; the other
    methods replay or submit that agent.
    """

    def __init__(self, env, monitor='output/', seed=None):
        """Build the candidate agents for *env*.

        Args:
            env: Environment identifier.  It is concatenated into the
                checkpoint paths, so it is expected to be a string.
            monitor: Output path prefix; the agent name is appended
                directly (no separator is inserted), so it should end
                with a path separator.
            seed: Value forwarded as ``seed`` to every ``play`` call.
        """
        self.env = env
        # NOTE(review): the trailing ``1`` passed to A3C / TPAgent is kept
        # as-is; its meaning is defined by the ``agents`` module.
        self.agents = {
            'universe': agents.A3C(env, monitor + 'universe/',
                                   CHECKPOINTS + '/universe/' + env + '/', 1),
            'tensorpack': agents.TPAgent(env, monitor + 'tensorpack/',
                                         CHECKPOINTS + '/tensorpack/' + env + '/' + env, 1),
            'random': agents.RandomAgent(),
        }
        self.seed = seed
        # Name of the best agent; empty until choose() has run.
        self.best = ''

    def choose(self, num_episodes_eval=100):
        """Score every agent without recording and store the best one's name.

        Each agent's ``play`` return value is used as its score.  On a tie
        the first agent in dictionary iteration order wins.

        Args:
            num_episodes_eval: Episode count passed to each agent's ``play``.
        """
        scores = {
            name: agent.play(num_episodes_eval,
                             env=self.env,
                             record=False,
                             seed=self.seed)
            for name, agent in self.agents.items()
        }
        self.best = max(scores, key=scores.get)

    def choose_and_record(self, num_episodes_eval=100, num_episodes_run=100):
        """Pick the best agent, then replay it with recording enabled.

        Args:
            num_episodes_eval: Episode count used to score each agent.
            num_episodes_run: Episode count for the recorded run of the
                winning agent.
        """
        # Reuse choose() rather than duplicating the evaluation loop.
        self.choose(num_episodes_eval)
        self.agents[self.best].play(num_episodes_run,
                                    env=self.env,
                                    record=True,
                                    seed=self.seed)

    def upload(self, outputm, api_key=''):
        """Submit the previously chosen agent's results.

        Args:
            outputm: Output location forwarded to the agent's ``do_submit``.
            api_key: API key forwarded to ``do_submit``.

        Raises:
            RuntimeError: If no agent has been chosen yet.
        """
        if not self.best:
            raise RuntimeError('no agent selected; call choose() or '
                               'choose_and_record() before upload()')
        self.agents[self.best].do_submit(outputm, api_key)