diff --git a/benchmarl/conf/experiment/base_experiment.yaml b/benchmarl/conf/experiment/base_experiment.yaml index 913b936d..14bae670 100644 --- a/benchmarl/conf/experiment/base_experiment.yaml +++ b/benchmarl/conf/experiment/base_experiment.yaml @@ -13,7 +13,7 @@ share_policy_params: True prefer_continuous_actions: True # Discount factor -gamma: 0.99 +gamma: 0.9 # Learning rate lr: 0.00005 # Clips grad norm if true and clips grad value if false @@ -35,6 +35,8 @@ exploration_eps_end: 0.01 # Number of frames collected and each experiment iteration collected_frames_per_batch: 6000 +# Number of initial collection batches containing random interactions +init_random_batches: 0 # Number of environments used for collection # If the environment is vectorized, this will be the number of batched environments. # Otherwise this batching will be simulated and each env will be run sequentially. @@ -73,5 +75,5 @@ create_json: True save_folder: null # Absolute path to a checkpoint file where the experiment was saved. If null the experiment is started fresh. restore_file: null -# Interval for experiment saving in terms of experiment iterations +# Interval for experiment saving in terms of experiment iterations. 
Set it to 0 to disable checkpointing checkpoint_interval: 50 diff --git a/benchmarl/environments/common.py b/benchmarl/environments/common.py index 76265527..afc0602f 100644 --- a/benchmarl/environments/common.py +++ b/benchmarl/environments/common.py @@ -11,7 +11,7 @@ from torchrl.data import CompositeSpec from torchrl.envs import EnvBase -from benchmarl.utils import read_yaml_config +from benchmarl.utils import DEVICE_TYPING, read_yaml_config def _load_config(name: str, config: Dict[str, Any]): @@ -57,6 +57,7 @@ def get_env_fun( num_envs: int, continuous_actions: bool, seed: Optional[int], + device: DEVICE_TYPING, ) -> Callable[[], EnvBase]: raise NotImplementedError diff --git a/benchmarl/environments/pettingzoo/common.py b/benchmarl/environments/pettingzoo/common.py index 92502658..48bbbae3 100644 --- a/benchmarl/environments/pettingzoo/common.py +++ b/benchmarl/environments/pettingzoo/common.py @@ -5,6 +5,8 @@ from benchmarl.environments.common import Task +from benchmarl.utils import DEVICE_TYPING + class PettingZooTask(Task): MULTIWALKER = None @@ -15,12 +17,14 @@ def get_env_fun( num_envs: int, continuous_actions: bool, seed: Optional[int], + device: DEVICE_TYPING, ) -> Callable[[], EnvBase]: if self.supports_continuous_actions() and self.supports_discrete_actions(): self.config.update({"continuous_actions": continuous_actions}) return lambda: PettingZooEnv( categorical_actions=True, + device=device, seed=seed, parallel=True, return_state=self.has_state(), diff --git a/benchmarl/environments/smacv2/common.py b/benchmarl/environments/smacv2/common.py index de7a0ad3..9d24015b 100644 --- a/benchmarl/environments/smacv2/common.py +++ b/benchmarl/environments/smacv2/common.py @@ -7,6 +7,7 @@ from torchrl.envs.libs.smacv2 import SMACv2Env from benchmarl.environments.common import Task +from benchmarl.utils import DEVICE_TYPING class Smacv2Task(Task): @@ -17,8 +18,11 @@ def get_env_fun( num_envs: int, continuous_actions: bool, seed: Optional[int], + device: 
DEVICE_TYPING, ) -> Callable[[], EnvBase]: - return lambda: SMACv2Env(categorical_actions=True, seed=seed, **self.config) + return lambda: SMACv2Env( + categorical_actions=True, seed=seed, device=device, **self.config + ) def supports_continuous_actions(self) -> bool: return False diff --git a/benchmarl/environments/vmas/common.py b/benchmarl/environments/vmas/common.py index 23153964..6931ad9c 100644 --- a/benchmarl/environments/vmas/common.py +++ b/benchmarl/environments/vmas/common.py @@ -5,6 +5,7 @@ from torchrl.envs.libs.vmas import VmasEnv from benchmarl.environments.common import Task +from benchmarl.utils import DEVICE_TYPING class VmasTask(Task): @@ -17,12 +18,14 @@ def get_env_fun( num_envs: int, continuous_actions: bool, seed: Optional[int], + device: DEVICE_TYPING, ) -> Callable[[], EnvBase]: return lambda: VmasEnv( scenario=self.name.lower(), num_envs=num_envs, continuous_actions=continuous_actions, seed=seed, + device=device, categorical_actions=True, **self.config, ) diff --git a/benchmarl/experiment/experiment.py b/benchmarl/experiment/experiment.py index 83875c19..92d9eeed 100644 --- a/benchmarl/experiment/experiment.py +++ b/benchmarl/experiment/experiment.py @@ -52,6 +52,7 @@ class ExperimentConfig: exploration_eps_end: float = MISSING collected_frames_per_batch: int = MISSING + init_random_batches: int = MISSING n_envs_per_worker: int = MISSING n_iters: int = MISSING n_optimizer_steps: int = MISSING @@ -106,6 +107,10 @@ def total_frames(self) -> int: def exploration_anneal_frames(self) -> int: return self.total_frames // 3 + @property + def init_random_frames(self) -> int: + return self.init_random_batches * self.collected_frames_per_batch + @staticmethod def get_from_yaml(path: Optional[str] = None): if path is None: @@ -191,6 +196,7 @@ def _setup_task(self): num_envs=self.config.evaluation_episodes, continuous_actions=self.continuous_actions, seed=self.seed, + device=self.config.sampling_device, ) )() env_func = 
self.model_config.process_env_fun( @@ -198,6 +204,7 @@ def _setup_task(self): num_envs=self.config.n_envs_per_worker, continuous_actions=self.continuous_actions, seed=self.seed, + device=self.config.sampling_device, ) ) @@ -219,7 +226,7 @@ def _setup_task(self): else: self.env_func = lambda: TransformedEnv(env_func(), transform.clone()) - self.test_env = test_env + self.test_env = test_env.to(self.config.sampling_device) def _setup_algorithm(self): self.algorithm = self.algorithm_config.get_algorithm( @@ -248,7 +255,7 @@ def _setup_algorithm(self): } self.optimizers = { group: { - loss_name: torch.optim.Adam(params, lr=self.config.lr, eps=1e-4) + loss_name: torch.optim.Adam(params, lr=self.config.lr, eps=1e-6) for loss_name, params in self.algorithm.get_parameters(group).items() } for group in self.group_map.keys() @@ -270,6 +277,7 @@ def _setup_collector(self): storing_device=self.config.train_device, frames_per_batch=self.config.collected_frames_per_batch, total_frames=self.config.total_frames, + init_random_frames=self.config.init_random_frames, ) def _setup_name(self): @@ -278,6 +286,12 @@ def _setup_name(self): self.environment_name = self.task.env_name().lower() self.task_name = self.task.name.lower() + if self.config.restore_file is not None and self.config.save_folder is not None: + raise ValueError( + "Experiment restore file and save folder have both been specified. " + "Do not set a save_folder when you are reloading an experiment as " + "it will by default be reloaded into the old folder." + ) if self.config.restore_file is None: if self.config.save_folder is not None: folder_name = Path(self.config.save_folder)