erl_config.py (forked from AI4Finance-Foundation/ElegantRL)

import os
import gymnasium as gym
import torch as th
import numpy as np


class Config:
    """Hyperparameters for the DRL agent, the environment, training, device, and evaluation."""
def __init__(self, agent_class=None, env_class=None, env_args=None):
self.agent_class = agent_class # agent = agent_class(...)
        self.if_off_policy = self.get_if_off_policy()  # whether the DRL algorithm is off-policy or on-policy
self.env_class = env_class # env = env_class(**env_args)
self.env_args = env_args # env = env_class(**env_args)
if env_args is None: # dummy env_args
env_args = {'env_name': None, 'state_dim': None, 'action_dim': None, 'if_discrete': None}
        self.env_name = env_args['env_name']  # the name of the environment, used to set 'cwd'
self.state_dim = env_args['state_dim'] # vector dimension (feature number) of state
self.action_dim = env_args['action_dim'] # vector dimension (feature number) of action
self.if_discrete = env_args['if_discrete'] # discrete or continuous action space
'''Arguments for reward shaping'''
self.gamma = 0.99 # discount factor of future rewards
        self.reward_scale = 1.0  # reward scaling factor; choose it so the cumulative return is roughly close to 256
'''Arguments for training'''
        self.net_dims = [64, 32]  # hidden layer dimensions of the MLP (Multi-Layer Perceptron)
self.learning_rate = 6e-5 # 2 ** -14 ~= 6e-5
self.soft_update_tau = 5e-3 # 2 ** -8 ~= 5e-3
        if self.if_off_policy:  # off-policy
            self.batch_size = int(64)  # number of transitions sampled from the replay buffer per update
            self.horizon_len = int(512)  # collect horizon_len steps while exploring, then update the network
            self.buffer_size = int(1e6)  # ReplayBuffer size (first in, first out) for off-policy
            self.repeat_times = 1.0  # how many times the collected data is reused for network updates
        else:  # on-policy
            self.batch_size = int(128)  # number of transitions sampled from the buffer per update
            self.horizon_len = int(2048)  # collect horizon_len steps while exploring, then update the network
            self.buffer_size = None  # on-policy empties the ReplayBuffer after each update, so no fixed size
            self.repeat_times = 8.0  # how many times the collected data is reused for network updates
'''Arguments for device'''
self.gpu_id = int(0) # `int` means the ID of single GPU, -1 means CPU
        self.thread_num = int(8)  # number of CPU threads for PyTorch, set via th.set_num_threads(self.thread_num)
self.random_seed = int(0) # initialize random seed in self.init_before_training()
'''Arguments for evaluate'''
self.cwd = None # current working directory to save model. None means set automatically
        self.if_remove = True  # whether to remove the cwd folder (True: remove, False: keep, None: ask)
self.break_step = +np.inf # break training if 'total_step > break_step'
        self.eval_times = int(32)  # number of evaluation episodes used to estimate the episodic cumulative return
        self.eval_per_step = int(1e4)  # evaluate the agent once every `eval_per_step` training steps
def init_before_training(self):
np.random.seed(self.random_seed)
th.manual_seed(self.random_seed)
th.set_num_threads(self.thread_num)
th.set_default_dtype(th.float32)
if self.cwd is None: # set cwd (current working directory) for saving model
            self.cwd = f'./{self.env_name}_{self.agent_class.__name__[5:]}_{self.random_seed}'  # [5:] strips the leading 'Agent' prefix
if self.if_remove is None: # remove or keep the history files
self.if_remove = bool(input(f"| Arguments PRESS 'y' to REMOVE: {self.cwd}? ") == 'y')
if self.if_remove:
import shutil
shutil.rmtree(self.cwd, ignore_errors=True)
print(f"| Arguments Remove cwd: {self.cwd}")
else:
print(f"| Arguments Keep cwd: {self.cwd}")
os.makedirs(self.cwd, exist_ok=True)
def get_if_off_policy(self) -> bool:
agent_name = self.agent_class.__name__ if self.agent_class else ''
on_policy_names = ('SARSA', 'VPG', 'A2C', 'A3C', 'TRPO', 'PPO', 'MPO')
        return all(s not in agent_name for s in on_policy_names)
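

# A minimal illustration (hypothetical agent class names, not defined in this file) of how
# `get_if_off_policy` chooses the hyperparameter defaults in `Config.__init__`:
#   Config(agent_class=AgentPPO).if_off_policy  -> False  # 'PPO' is in on_policy_names, so horizon_len=2048, repeat_times=8.0
#   Config(agent_class=AgentSAC).if_off_policy  -> True   # no on-policy name matches, so buffer_size=int(1e6), repeat_times=1.0
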
def get_gym_env_args(env, if_print: bool) -> dict:
    """Get the ``env_args`` dict describing a standard Gymnasium (OpenAI Gym) environment.

    :param env: a standard Gymnasium environment instance
    :param if_print: whether to print the env information dict
    :return: env_args [dict]
        env_args = {
            'env_name': env_name,        # [str] the environment name, such as 'XxxXxx-v0'
            'state_dim': state_dim,      # [int] the dimension (feature number) of the state
            'action_dim': action_dim,    # [int] the dimension of a continuous action, or the number of discrete actions
            'if_discrete': if_discrete,  # [bool] whether the action space is discrete
        }
    """
if {'unwrapped', 'observation_space', 'action_space', 'spec'}.issubset(dir(env)): # isinstance(env, gym.Env):
env_name = env.unwrapped.spec.id
state_shape = env.observation_space.shape
        state_dim = state_shape[0] if len(state_shape) == 1 else state_shape  # image-like observations keep the full shape tuple
if_discrete = isinstance(env.action_space, gym.spaces.Discrete)
if if_discrete: # make sure it is discrete action space
action_dim = env.action_space.n
elif isinstance(env.action_space, gym.spaces.Box): # make sure it is continuous action space
action_dim = env.action_space.shape[0]
            # ElegantRL expects the continuous action space to be normalized to the range (-1, 1)
            if any(env.action_space.high - 1):
                print('WARNING: env.action_space.high', env.action_space.high)
            if any(env.action_space.low + 1):
                print('WARNING: env.action_space.low', env.action_space.low)
else:
            raise RuntimeError('\n| Error in get_gym_env_args(). Please set these values manually:'
                               '\n  `state_dim=int; action_dim=int; if_discrete=bool;`'
                               '\n  And keep action_space in range (-1, 1).')
else:
env_name = env.env_name
state_dim = env.state_dim
action_dim = env.action_dim
if_discrete = env.if_discrete
env_args = {'env_name': env_name,
'state_dim': state_dim,
'action_dim': action_dim,
'if_discrete': if_discrete, }
if if_print:
env_args_str = repr(env_args).replace(',', f",\n{'':11}")
print(f"env_args = {env_args_str}")
return env_args
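
# Example usage (a sketch, assuming the Gymnasium classic-control env 'Pendulum-v1' is installed):
#   env = gym.make('Pendulum-v1')
#   env_args = get_gym_env_args(env, if_print=True)
#   # -> {'env_name': 'Pendulum-v1', 'state_dim': 3, 'action_dim': 1, 'if_discrete': False}
#   # Pendulum-v1 has action bounds (-2, 2), so the bound warnings above would also be printed.
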
def kwargs_filter(function, kwargs: dict) -> dict:
    """Keep only the items of `kwargs` whose keys match a parameter name of `function`."""
    import inspect
    sign = inspect.signature(function).parameters.values()
    sign = {val.name for val in sign}
    common_args = sign.intersection(kwargs.keys())
    return {key: kwargs[key] for key in common_args}  # filtered kwargs
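
# Example (a sketch with a hypothetical callable): `kwargs_filter` drops entries the callable does not accept,
# so `env_class(**kwargs_filter(env_class.__init__, env_args.copy()))` never fails on extra keys.
#   def make_env(env_name, state_dim): ...
#   kwargs_filter(make_env, {'env_name': 'X', 'state_dim': 3, 'if_discrete': True})
#   # -> {'env_name': 'X', 'state_dim': 3}   ('if_discrete' is not a parameter, so it is filtered out)
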
def build_env(env_class=None, env_args=None):
    if env_class.__module__ == 'gymnasium.envs.registration':  # special rule: env_class is gym.make, so build by registered id
        env = env_class(id=env_args['env_name'])
else:
env = env_class(**kwargs_filter(env_class.__init__, env_args.copy()))
    for attr_str in ('env_name', 'state_dim', 'action_dim', 'if_discrete'):  # attach the env meta info for uniform access
        setattr(env, attr_str, env_args[attr_str])
return env
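

if __name__ == '__main__':
    # A minimal end-to-end sketch of how the helpers above fit together. It assumes Gymnasium's
    # 'Pendulum-v1' is available; no agent class is defined in this file, so agent_class stays None
    # and init_before_training (which derives cwd from the agent class name) is not called here.
    demo_env = gym.make('Pendulum-v1')
    demo_env_args = get_gym_env_args(demo_env, if_print=True)  # state_dim=3, action_dim=1, if_discrete=False

    # gym.make lives in gymnasium.envs.registration, so build_env rebuilds the env by its id.
    demo_env2 = build_env(env_class=gym.make, env_args=demo_env_args)
    print(demo_env2.env_name, demo_env2.state_dim, demo_env2.action_dim, demo_env2.if_discrete)

    # A Config for this env: with agent_class=None, get_if_off_policy() defaults to off-policy.
    demo_args = Config(agent_class=None, env_class=gym.make, env_args=demo_env_args)
    print(demo_args.if_off_policy, demo_args.horizon_len, demo_args.buffer_size)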