# main.py -- training entry point (forked from dmqasc2/RL_pysc2)
from absl import app
from absl import flags
import sys
import torch
from utils import arglist
from runs.minigame import MiniGame
from utils.preprocess import Preprocess
# Use float32 as the default tensor type.
torch.set_default_tensor_type('torch.FloatTensor')
# Seed torch from arglist for reproducibility.
torch.manual_seed(arglist.SEED)
# Define command-line flags before parsing them; pysc2 requires flags to be
# parsed before any environment is created.
flags.DEFINE_bool("render", False, "Whether to render with pygame.")
FLAGS = flags.FLAGS
FLAGS(sys.argv)
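# Note: FLAGS.render is defined above but never read in this file; presumably
# it would be forwarded to the SC2 environment inside runs.minigame, though
# that wiring is an assumption, not verified here.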
# Mini-game maps to train on.
env_names = ["DefeatZerglingsAndBanelings", "DefeatRoaches",
             "CollectMineralShards", "MoveToBeacon", "FindAndDefeatZerglings",
             "BuildMarines", "CollectMineralsAndGas"]
# RL algorithm to use: 'ddpg' or 'ppo'.
rl_algo = 'ppo'
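# Switching rl_algo swaps both the networks and the memory type in main():
# 'ddpg' trains off-policy from a replay buffer (utils.memory.Memory), while
# 'ppo' trains on-policy from whole episodes (utils.memory.EpisodeMemory).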

def main(_):
    # Train an agent on each mini-game in turn; algorithm-specific modules
    # are imported lazily so only the selected agent's code is loaded.
    for map_name in env_names:
        if rl_algo == 'ddpg':
            from agent.ddpg import DDPGAgent
            from networks.acnetwork_q_seperated import ActorNet, CriticNet
            from utils.memory import Memory
            actor = ActorNet()
            critic = CriticNet()
            memory = Memory(limit=arglist.memory_limit,
                            action_shape=arglist.action_shape,
                            observation_shape=arglist.observation_shape)
            learner = DDPGAgent(actor, critic, memory)
        elif rl_algo == 'ppo':
            from agent.ppo import PPOAgent
            from networks.acnetwork_v_seperated import ActorNet, CriticNet
            from utils.memory import EpisodeMemory
            actor = ActorNet()
            critic = CriticNet()
            memory = EpisodeMemory(limit=arglist.PPO.memory_limit,
                                   action_shape=arglist.action_shape,
                                   observation_shape=arglist.observation_shape)
            learner = PPOAgent(actor, critic, memory)
        else:
            raise NotImplementedError()
        preprocess = Preprocess()
        # Run the selected agent on the current map.
        game = MiniGame(map_name, learner, preprocess, nb_episodes=10000)
        game.run()
    return 0

if __name__ == '__main__':
    app.run(main)
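# Example invocation (assumes PySC2 and the StarCraft II mini-game maps are
# installed; 'render' uses standard absl boolean flag syntax):
#   python main.py --render=false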