diff --git a/.github/workflows/cpu-tests.yaml b/.github/workflows/cpu-tests.yaml index e4850f63..4a145332 100644 --- a/.github/workflows/cpu-tests.yaml +++ b/.github/workflows/cpu-tests.yaml @@ -41,7 +41,7 @@ jobs: - name: Install packages run: | python -m pip install -U pip - python -m pip install .[atari,test,dev] + python -m pip install -e .[atari,test,dev] - name: Run tests run: | diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 00000000..c3670a1c --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,4 @@ +recursive-include sheeprl *.py +recursive-include sheeprl *.yaml +global-exclude *.pyc +global-exclude __pycache__ diff --git a/README.md b/README.md index 79d37639..538efcde 100644 --- a/README.md +++ b/README.md @@ -155,13 +155,13 @@ Now you can use one of the already available algorithms, or create your own. For example, to train a PPO agent on the CartPole environment with only vector-like observations, just run ```bash -python sheeprl.py ppo exp=ppo env=gym env.id=CartPole-v1 +python sheeprl.py exp=ppo env=gym env.id=CartPole-v1 ``` You can check all the available algorithms with ```bash -python sheeprl.py --sheeprl_help +python sheeprl/available_agents.py ``` That's all it takes to train an agent with SheepRL! 🎉 @@ -194,17 +194,17 @@ What you run is the PPO algorithm with the default configuration. But you can al For example, in the default configuration, the number of parallel environments is 4. Let's try to change it to 8 by passing the `env.num_envs` argument: ```bash -python sheeprl.py ppo exp=ppo env=gym env.id=CartPole-v1 num_envs=8 +python sheeprl.py exp=ppo env=gym env.id=CartPole-v1 env.num_envs=8 ``` All the available arguments, with their descriptions, are listed in the `sheeprl/configs` directory. You can find more information about the hierarchy of configs [here](./howto/run_experiments.md). ### Running with Lightning Fabric -To run the algorithm with Lightning Fabric, you need to call Lightning with its parameters. For example, to run the PPO algorithm with 4 parallel environments on 2 nodes, you can run: +To run the algorithm with Lightning Fabric, you need to specify the Fabric parameters through the CLI. For example, to run the PPO algorithm with 4 parallel environments across 2 devices, you can run: ```bash -lightning run model --accelerator=cpu --strategy=ddp --devices=2 sheeprl.py ppo exp=ppo env=gym env.id=CartPole-v1 +python sheeprl.py fabric.accelerator=cpu fabric.strategy=ddp fabric.devices=2 exp=ppo env=gym env.id=CartPole-v1 ``` You can check the available parameters for Lightning Fabric [here](https://lightning.ai/docs/fabric/stable/api/fabric_args.html). 
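As a further illustration of the new Fabric overrides, here is a minimal sketch of a single-GPU, mixed-precision run. It assumes the `fabric` config group exposes `accelerator`, `devices` and `precision` keys mirroring the Fabric constructor arguments; check `sheeprl/configs/fabric/default.yaml` for the exact key names before relying on them.

```bash
# Sketch (assumed keys): run PPO on one GPU with mixed precision.
# `fabric.precision` is assumed to map to Fabric's `precision` argument.
python sheeprl.py exp=ppo env=gym env.id=CartPole-v1 \
  fabric.accelerator=gpu fabric.devices=1 fabric.precision=16-mixed
```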
diff --git a/howto/configs.md b/howto/configs.md index ed0617f4..20e3bc97 100644 --- a/howto/configs.md +++ b/howto/configs.md @@ -22,9 +22,11 @@ sheeprl/configs โ”‚ โ”œโ”€โ”€ droq.yaml โ”‚ โ”œโ”€โ”€ p2e_dv1.yaml โ”‚ โ”œโ”€โ”€ p2e_dv2.yaml +โ”‚ โ”œโ”€โ”€ ppo_decoupled.yaml โ”‚ โ”œโ”€โ”€ ppo_recurrent.yaml โ”‚ โ”œโ”€โ”€ ppo.yaml โ”‚ โ”œโ”€โ”€ sac_ae.yaml +โ”‚ โ”œโ”€โ”€ sac_decoupled.yaml โ”‚ โ””โ”€โ”€ sac.yaml โ”œโ”€โ”€ buffer โ”‚ โ””โ”€โ”€ default.yaml @@ -47,17 +49,26 @@ sheeprl/configs โ”‚ โ”œโ”€โ”€ dreamer_v1.yaml โ”‚ โ”œโ”€โ”€ dreamer_v2_ms_pacman.yaml โ”‚ โ”œโ”€โ”€ dreamer_v2.yaml +โ”‚ โ”œโ”€โ”€ dreamer_v3_100k_boxing.yaml โ”‚ โ”œโ”€โ”€ dreamer_v3_100k_ms_pacman.yaml +โ”‚ โ”œโ”€โ”€ dreamer_v3_dmc_walker_walk.yaml +โ”‚ โ”œโ”€โ”€ dreamer_v3_L_doapp_128px_gray_combo_discrete.yaml โ”‚ โ”œโ”€โ”€ dreamer_v3_L_doapp.yaml โ”‚ โ”œโ”€โ”€ dreamer_v3_L_navigate.yaml โ”‚ โ”œโ”€โ”€ dreamer_v3.yaml โ”‚ โ”œโ”€โ”€ droq.yaml โ”‚ โ”œโ”€โ”€ p2e_dv1.yaml โ”‚ โ”œโ”€โ”€ p2e_dv2.yaml +โ”‚ โ”œโ”€โ”€ ppo_decoupled.yaml โ”‚ โ”œโ”€โ”€ ppo_recurrent.yaml โ”‚ โ”œโ”€โ”€ ppo.yaml โ”‚ โ”œโ”€โ”€ sac_ae.yaml +โ”‚ โ”œโ”€โ”€ sac_decoupled.yaml โ”‚ โ””โ”€โ”€ sac.yaml +โ”œโ”€โ”€ fabric +โ”‚ โ”œโ”€โ”€ ddp-cpu.yaml +โ”‚ โ”œโ”€โ”€ ddp-cuda.yaml +โ”‚ โ””โ”€โ”€ default.yaml โ”œโ”€โ”€ hydra โ”‚ โ””โ”€โ”€ default.yaml โ”œโ”€โ”€ __init__.py @@ -86,9 +97,10 @@ defaults: - buffer: default.yaml - checkpoint: default.yaml - env: default.yaml - - exp: null - - hydra: default.yaml + - fabric: default.yaml - metric: default.yaml + - hydra: default.yaml + - exp: ??? num_threads: 1 total_steps: ??? @@ -112,10 +124,6 @@ cnn_keys: mlp_keys: encoder: [] decoder: ${mlp_keys.encoder} - -# Buffer -buffer: - memmap: True ``` ### Algorithms @@ -305,7 +313,50 @@ The environment configs can be found under the `sheeprl/configs/env` folders. Sh * [MineRL (v0.4.4)](https://minerl.readthedocs.io/en/v0.4.4/) * [MineDojo (v0.1.0)](https://docs.minedojo.org/) -In this way one can easily try out the overall framework with standard RL environments. +In this way one can easily try out the overall framework with standard RL environments. The `default.yaml` config contains all the environment parameters shared by (possibly) all the environments: + +```yaml +id: ??? +num_envs: 4 +frame_stack: 1 +sync_env: False +screen_size: 64 +action_repeat: 1 +grayscale: False +clip_rewards: False +capture_video: True +frame_stack_dilation: 1 +max_episode_steps: null +reward_as_observation: False +``` + +Every custom environment must then "inherit" from this default config, override the particular parameters and define the the `wrapper` field, which is the one that will be directly instantiated at runtime. The `wrapper` field must define all the specific parameters to be passed to the `_target_` function when the wrapper will be instantiated. 
Take for example the `atari.yaml` config: + +```yaml +defaults: + - default + - _self_ + +# Override from `default` config +action_repeat: 4 +id: PongNoFrameskip-v4 +max_episode_steps: 27000 + +# Wrapper to be instantiated +wrapper: + _target_: gymnasium.wrappers.AtariPreprocessing # https://gymnasium.farama.org/api/wrappers/misc_wrappers/#gymnasium.wrappers.AtariPreprocessing + env: + _target_: gymnasium.make + id: ${env.id} + render_mode: rgb_array + noop_max: 30 + terminal_on_life_loss: False + frame_skip: ${env.action_repeat} + screen_size: ${env.screen_size} + grayscale_obs: ${env.grayscale} + scale_obs: False + grayscale_newaxis: True +``` > **Warning** > @@ -362,9 +413,13 @@ algo: Given this config, one can easily run an experiment to test the Dreamer-V3 algorithm on the Ms-PacMan environment with the following simple CLI command: ```bash -lightning run model sheeprl.py dreamer_v3 exp=dreamer_v3_100k_ms_pacman +python sheeprl.py exp=dreamer_v3_100k_ms_pacman ``` +### Fabric + +These configurations control the parameters to be passed to the [Fabric object](https://lightning.ai/docs/fabric/stable/api/generated/lightning.fabric.fabric.Fabric.html#lightning.fabric.fabric.Fabric). With these you can control whether the experiment runs on multiple devices, on which accelerator, and with which precision. For more information, please have a look at the [Lightning documentation page](https://lightning.ai/docs/fabric/stable/api/fabric_args.html#). + ### Hydra This configuration file manages where and how to create folders or subfolders for experiments. For more information please visit the [hydra documentation](https://hydra.cc/docs/configure_hydra/intro/). Our default hydra config is the following: @@ -387,7 +442,6 @@ log_every: 5000 sync_on_compute: False ``` - ### Optimizer Each optimizer file defines how we initialize the training optimizer with its parameters. For a better understanding of PyTorch optimizers, one should have a look at [https://pytorch.org/docs/stable/optim.html](https://pytorch.org/docs/stable/optim.html). An example config is the following: diff --git a/howto/learn_in_atari.md b/howto/learn_in_atari.md index 77931a21..b4eb2d2e 100644 --- a/howto/learn_in_atari.md +++ b/howto/learn_in_atari.md @@ -8,29 +8,6 @@ pip install .[atari] For more information: https://gymnasium.farama.org/environments/atari/ ## Train your agent -First you need to select which agent you want to train. The list of the trainable agent can be retrieved as follows: - -```bash -Usage: sheeprl.py [OPTIONS] COMMAND [ARGS]... - - SheepRL zero-code command line utility. - -Options: - --sheeprl_help Show this message and exit. - -Commands: - dreamer_v1 - dreamer_v2 - droq - p2e_dv1 - p2e_dv2 - ppo - ppo_decoupled - ppo_recurrent - sac - sac_ae - sac_decoupled -``` It is important to remember that not all algorithms can work with images, so it is necessary to check the first table in the [README](../README.md) and select a proper algorithm. 
The list of selectable algorithms is given below: @@ -46,5 +23,5 @@ The list of selectable algorithms is given below: Once you have chosen the algorithm you want to train, you can start training, for instance, the PPO agent by running: ```bash -lightning run model --accelerator=cpu --strategy=ddp --devices=2 sheeprl.py ppo exp=ppo env=atari env.id=PongNoFrameskip-v4 cnn_keys.encoder=[rgb] +python sheeprl.py exp=ppo env=atari env.id=PongNoFrameskip-v4 cnn_keys.encoder=[rgb] fabric.accelerator=cpu fabric.strategy=ddp fabric.devices=2 ``` \ No newline at end of file diff --git a/howto/learn_in_diambra.md b/howto/learn_in_diambra.md index ad77cb87..a4313960 100644 --- a/howto/learn_in_diambra.md +++ b/howto/learn_in_diambra.md @@ -54,14 +54,14 @@ The observation space is slightly modified to be compatible with our algorithms, ## Multi-environments / Distributed training In order to train your agent with multiple environments or to perform a distributed training, you have to specify to the `diambra run` command the number of environments you want to instantiate (through the `-s` cli argument). So, you have to multiply the number of environments per single process by the number of processes you want to launch (the number of *player* processes for decoupled algorithms). Thus, in the case of a coupled algorithm (e.g., `dreamer_v2`), if you want to distribute your training among $2$ processes, each one containing $4$ environments, the total number of environments will be: $2 \cdot 4 = 8$. The command will be: ```bash -diambra run -s=8 lightning run model --devices=2 sheeprl.py dreamer_v3 exp=dreamer_v3 env=diambra env.id=doapp num_envs=4 env.sync_env=True cnn_keys.encoder=[frame] +diambra run -s=8 python sheeprl.py exp=dreamer_v3 env=diambra env.id=doapp env.num_envs=4 env.sync_env=True cnn_keys.encoder=[frame] fabric.devices=2 ``` ## Args The IDs of the DIAMBRA environments are specified [here](https://docs.diambra.ai/envs/games/). To train your agent on a DIAMBRA environment you have to select the diambra configs with the argument `env=diambra`, then set the `env.id` argument to the environment ID, e.g., to train your agent on the *Dead Or Alive ++* game, you have to set the `env.id` argument to `doapp` (i.e., `env.id=doapp`). ```bash -diambra run -s=4 lightning run model sheeprl.py dreamer_v3 exp=dreamer_v3 env=diambra env.id=doapp num_envs=4 +diambra run -s=4 python sheeprl.py exp=dreamer_v3 env=diambra env.id=doapp env.num_envs=4 ``` Another possibility is to create a new config file in the `sheeprl/configs/exp` folder, where you specify all the configs you want to use in your experiment. An example of custom configuration file is available [here](../sheeprl/configs/exp/dreamer_v3_L_doapp.yaml). @@ -94,7 +94,7 @@ env: Now, to run your experiment, you have to execute the following command: ```bash -diambra run -s=4 lightning run model sheeprl.py dreamer_v3 exp=custom_exp num_envs=4 +diambra run -s=4 python sheeprl.py exp=custom_exp env.num_envs=4 ``` > **Note** @@ -118,5 +118,5 @@ diambra run -s=4 lightning run model sheeprl.py dreamer_v3 exp=custom_exp num_en ## Headless machines If you work on a headless machine, you need a software renderer. We recommend adopting one of the following solutions: -1. Install the `xvfb` software with the `sudo apt install xvfb` command and prefix the train command with `xvfb-run`. 
For instance, to train DreamerV2 on the navigate task on an headless machine, you need to run the following command: `xvfb-run diambra run lightning run model --devices=1 sheeprl.py dreamer_v3 exp=dreamer_v3 env=diambra env.id=doapp env.sync_env=True num_envs=1 cnn_keys.encoder=[frame]` +1. Install the `xvfb` software with the `sudo apt install xvfb` command and prefix the train command with `xvfb-run`. For instance, to train DreamerV3 on the *doapp* environment on a headless machine, you need to run the following command: `xvfb-run diambra run python sheeprl.py exp=dreamer_v3 env=diambra env.id=doapp env.sync_env=True env.num_envs=1 cnn_keys.encoder=[frame] fabric.devices=1` 2. Exploit the [PyVirtualDisplay](https://github.com/ponty/PyVirtualDisplay) package. \ No newline at end of file diff --git a/howto/learn_in_dmc.md b/howto/learn_in_dmc.md index cce60dd7..a417116c 100644 --- a/howto/learn_in_dmc.md +++ b/howto/learn_in_dmc.md @@ -19,12 +19,12 @@ For more information: [https://github.com/deepmind/dm_control](https://github.co In order to train your agents on the [MuJoCo environments](https://gymnasium.farama.org/environments/mujoco/) provided by Gymnasium, it is sufficient to select the *GYM* environment (`env=gym`) and set the `env.id` to the name of the environment you want to use. For instance, `"Walker2d-v4"` if you want to train your agent on the *walker walk* environment. ```bash -lightning run model sheeprl.py dreamer_v3 exp=dreamer_v3 env=gym env.id=Walker2d-v4 cnn_keys.encoder=[rgb] +python sheeprl.py exp=dreamer_v3 env=gym env.id=Walker2d-v4 cnn_keys.encoder=[rgb] ``` ## DeepMind Control In order to train your agents on the [DeepMind control suite](https://github.com/deepmind/dm_control/blob/main/dm_control/suite/README.md), you have to select the *DMC* environment (`env=dmc`) and to set the id of the environment you want to use. A list of the available environments can be found [here](https://arxiv.org/abs/1801.00690). For instance, if you want to train your agent on the *walker walk* environment, you need to set the `env.id` to `"walker_walk"`. ```bash -lightning run model sheeprl.py dreamer_v3 exp=dreamer_v3 env=dmc env.id=walker_walk cnn_keys.encoder=[rgb] +python sheeprl.py exp=dreamer_v3 env=dmc env.id=walker_walk cnn_keys.encoder=[rgb] ``` \ No newline at end of file diff --git a/howto/learn_in_minedojo.md b/howto/learn_in_minedojo.md index 6ebadfba..8d3dca54 100644 --- a/howto/learn_in_minedojo.md +++ b/howto/learn_in_minedojo.md @@ -29,7 +29,7 @@ It is possible to train your agents on all the tasks provided by MineDojo. You n For instance, you can use the following command to select the MineDojo open-ended environment. ```bash -lightning run model sheeprl.py p2e_dv2 exp=p2e_dv2 env=minedojo env.id=open-ened algo.actor.cls=sheeprl.algos.p2e_dv2.agent.MinedojoActor cnn_keys.encoder=[rgb] +python sheeprl.py exp=p2e_dv2 env=minedojo env.id=open-ended algo.actor.cls=sheeprl.algos.p2e_dv2.agent.MinedojoActor cnn_keys.encoder=[rgb] ``` ### Observation Space @@ -67,5 +67,5 @@ For more information about the MineDojo action space, check [here](https://docs. ## Headless machines If you work on a headless machine, you need a software renderer. We recommend adopting one of the following solutions: -1. Install the `xvfb` software with the `sudo apt install xvfb` command and prefix the train command with `xvfb-run`. 
For instance, to train DreamerV2 on the navigate task on an headless machine, you need to run the following command: `xvfb-run lightning run model --devices=1 sheeprl.py p2e_dv2 exp=p2e_dv2 env=minedojo env.id=open-ended cnn_keys.encoder=[rgb] algo.actor.cls=sheeprl.algos.p2e_dv2.agent.MinedojoActor`, or `MINEDOJO_HEADLESS=1 lightning run model --devices=1 sheeprl.py p2e_dv2 exp=p2e_dv2 env=minedojo env.id=open-ended cnn_keys.encoder=[rgb] algo.actor.cls=sheeprl.algos.p2e_dv2.agent.MinedojoActor`. +1. Install the `xvfb` software with the `sudo apt install xvfb` command and prefix the train command with `xvfb-run`. For instance, to train the `p2e_dv2` agent on the open-ended task on a headless machine, you need to run the following command: `xvfb-run python sheeprl.py exp=p2e_dv2 fabric.devices=1 env=minedojo env.id=open-ended cnn_keys.encoder=[rgb] algo.actor.cls=sheeprl.algos.p2e_dv2.agent.MinedojoActor`, or `MINEDOJO_HEADLESS=1 python sheeprl.py exp=p2e_dv2 fabric.devices=1 env=minedojo env.id=open-ended cnn_keys.encoder=[rgb] algo.actor.cls=sheeprl.algos.p2e_dv2.agent.MinedojoActor`. 2. Exploit the [PyVirtualDisplay](https://github.com/ponty/PyVirtualDisplay) package. \ No newline at end of file diff --git a/howto/learn_in_minerl.md b/howto/learn_in_minerl.md index 3df0790c..dfea41cd 100644 --- a/howto/learn_in_minerl.md +++ b/howto/learn_in_minerl.md @@ -54,5 +54,5 @@ Finally we added sticky action for the `jump` and `attack` actions. You can set ## Headless machines If you work on a headless machine, you need a software renderer. We recommend adopting one of the following solutions: -1. Install the `xvfb` software with the `sudo apt install xvfb` command and prefix the train command with `xvfb-run`. For instance, to train DreamerV2 on the navigate task on an headless machine, you need to run the following command: `xvfb-run lightning run model --devices=1 sheeprl.py dreamer_v3 exp=dreamer_v3 env=minerl env.id=custom_navigate cnn_keys.encoder=[rgb]`. +1. Install the `xvfb` software with the `sudo apt install xvfb` command and prefix the train command with `xvfb-run`. For instance, to train DreamerV3 on the navigate task on a headless machine, you need to run the following command: `xvfb-run python sheeprl.py exp=dreamer_v3 fabric.devices=1 env=minerl env.id=custom_navigate cnn_keys.encoder=[rgb]`. 2. Exploit the [PyVirtualDisplay](https://github.com/ponty/PyVirtualDisplay) package. \ No newline at end of file diff --git a/howto/register_new_algorithm.md b/howto/register_new_algorithm.md index d8f9e4ef..ce13ee78 100644 --- a/howto/register_new_algorithm.md +++ b/howto/register_new_algorithm.md @@ -1,5 +1,5 @@ # Register a new algorithm -Suppose that we want to add a new SoTA algorithm to sheeprl called `sota`, so that we can train an agent simply with `python sheeprl.py sota exp=... env=... env.id=...` or accelerated by fabric with `lightning run model sheeprl.py sota exp=... env=... env.id=...`. +Suppose that we want to add a new SoTA algorithm to sheeprl called `sota`, so that we can train an agent simply with `python sheeprl.py exp=sota env=... env.id=...`. 
We start from creating a new folder called `sota` under `./sheeprl/algos/`, containing the following files: @@ -47,7 +47,7 @@ from omegaconf import DictConfig, OmegaConf from tensordict import TensorDict, make_tensordict from tensordict.tensordict import TensorDictBase from torch.optim import Adam -from torchmetrics import MeanMetric +from torchmetrics import MeanMetric, SumMetric from sheeprl.algos.sota.loss import loss1, loss2 from sheeprl.algos.sota.utils import test @@ -57,6 +57,7 @@ from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.env import make_dict_env from sheeprl.utils.logger import create_tensorboard_logger +from sheeprl.utils.timer import timer def train( @@ -80,13 +81,8 @@ def train( aggregator.update("Loss/loss2", l2.detach()) -@register_algorithm() -@hydra.main(version_base=None, config_path="../../configs", config_name="config") -def main(cfg: DictConfig): - # Initialize Fabric - fabric = Fabric() - if not _is_using_cli(): - fabric.launch() +@register_algorithm(decoupled=False) +def sota_main(fabric: Fabric, cfg: DictConfig): rank = fabric.global_rank world_size = fabric.world_size device = fabric.device @@ -94,7 +90,7 @@ def main(cfg: DictConfig): # Create TensorBoardLogger. This will create the logger only on the # rank-0 process - logger, log_dir = create_tensorboard_logger(fabric, cfg, "sota") + logger, log_dir = create_tensorboard_logger(fabric, cfg) if fabric.is_global_zero: fabric._loggers = [logger] fabric.logger.log_hyperparams(OmegaConf.to_container(cfg, resolve=True)) @@ -116,6 +112,9 @@ def main(cfg: DictConfig): ) # Create the agent model: this should be a torch.nn.Module to be acceleratesd with Fabric + # Given that the environment has been created with the `make_dict_env` method, the agent + # forward method must accept as input a dictionary like {"obs1_name": obs1, "obs2_name": obs2, ...}. + # The agent should be able to process both image and vector-like observations. agent = ... # Define the agent and the optimizer and setup them with Fabric @@ -140,17 +139,31 @@ def main(cfg: DictConfig): step_data = TensorDict({}, batch_size=[cfg.env.num_envs], device=device) # Global variables - global_step = 0 - start_time = time.perf_counter() - single_global_rollout = int(cfg.env.num_envs * cfg.algo.rollout_steps * world_size) - num_updates = cfg.total_steps // single_global_rollout if not cfg.dry_run else 1 - - # Linear learning rate scheduler - if cfg.algo.anneal_lr: - from torch.optim.lr_scheduler import PolynomialLR - - scheduler = PolynomialLR(optimizer=optimizer, total_iters=num_updates, power=1.0) + last_log = 0 + last_train = 0 + train_step = 0 + policy_step = 0 + last_checkpoint = 0 + policy_steps_per_update = int(cfg.env.num_envs * cfg.algo.rollout_steps * world_size) + num_updates = cfg.total_steps // policy_steps_per_update if not cfg.dry_run else 1 + + # Warning for log and checkpoint every + if cfg.metric.log_every % policy_steps_per_update != 0: + warnings.warn( + f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the " + f"policy_steps_per_update value ({policy_steps_per_update}), so " + "the metrics will be logged at the nearest greater multiple of the " + "policy_steps_per_update value." 
+ ) + if cfg.checkpoint.every % policy_steps_per_update != 0: + warnings.warn( + f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the " + f"policy_steps_per_update value ({policy_steps_per_update}), so " + "the checkpoint will be saved at the nearest greater multiple of the " + "policy_steps_per_update value." + ) + # Get the first environment observation and start the optimization o = envs.reset(seed=cfg.seed)[0] # [N_envs, N_obs] next_obs = {} for k in o.keys(): @@ -162,20 +175,23 @@ def main(cfg: DictConfig): torch_obs = torch_obs.float() step_data[k] = torch_obs next_obs[k] = torch_obs - next_done = torch.zeros(cfg.env.num_envs, 1, dtype=torch.float32) # [N_envs, 1] + next_done = torch.zeros(cfg.env.num_envs, 1, dtype=torch.float32).to(fabric.device) # [N_envs, 1] for update in range(1, num_updates + 1): for _ in range(0, cfg.algo.rollout_steps): - global_step += cfg.env.num_envs * world_size + policy_step += cfg.env.num_envs * world_size - with torch.no_grad(): - # Sample an action given the observation received by the environment - # This calls the `forward` method of the PyTorch module, escaping from Fabric - # because we don't want this to be a synchronization point - action = agent.module(next_obs) + # Measure environment interaction time: this considers both the model forward + # to get the action given the observation and the time taken into the environment + with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): + with torch.no_grad(): + # Sample an action given the observation received by the environment + # This calls the `forward` method of the PyTorch module, escaping from Fabric + # because we don't want this to be a synchronization point + action = agent.module(next_obs) - # Single environment step - o, reward, done, truncated, info = envs.step(action.cpu().numpy().reshape(envs.action_space.shape)) + # Single environment step + o, reward, done, truncated, info = envs.step(action.cpu().numpy().reshape(envs.action_space.shape)) with device: rewards = torch.tensor(reward).view(cfg.env.num_envs, -1) # [N_envs, 1] @@ -205,13 +221,13 @@ def main(cfg: DictConfig): next_done = done if "final_info" in info: - for i, agent_final_info in enumerate(info["final_info"]): - if agent_final_info is not None and "episode" in agent_final_info: - fabric.print( - f"Rank-0: global_step={global_step}, reward_env_{i}={agent_final_info['episode']['r'][0]}" - ) - aggregator.update("Rewards/rew_avg", agent_final_info["episode"]["r"][0]) - aggregator.update("Game/ep_len_avg", agent_final_info["episode"]["l"][0]) + for i, agent_ep_info in enumerate(info["final_info"]): + if agent_ep_info is not None: + ep_rew = agent_ep_info["episode"]["r"] + ep_len = agent_ep_info["episode"]["l"] + aggregator.update("Rewards/rew_avg", ep_rew) + aggregator.update("Game/ep_len_avg", ep_len) + fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") # Flatten the batch local_data = rb.buffer.view(-1) @@ -220,17 +236,49 @@ def main(cfg: DictConfig): train(fabric, agent, optimizer, local_data, aggregator, cfg) # Log metrics - metrics_dict = aggregator.compute() - fabric.log_dict(metrics_dict, global_step) - aggregator.reset() + if policy_step - last_log >= cfg.metric.log_every or update == num_updates or cfg.dry_run: + # Sync distributed metrics + metrics_dict = aggregator.compute() + fabric.log_dict(metrics_dict, policy_step) + aggregator.reset() + + # Sync distributed timers + timer_metrics = timer.compute() + fabric.log( + "Time/sps_train", + 
(train_step - last_train) / timer_metrics["Time/train_time"], + policy_step, + ) + fabric.log( + "Time/sps_env_interaction", + ((policy_step - last_log) / world_size * cfg.env.action_repeat) + / timer_metrics["Time/env_interaction_time"], + policy_step, + ) + timer.reset() + + # Reset counters + last_log = policy_step + last_train = train_step + + # Checkpoint model + if ( + (cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every) + or cfg.dry_run + or update == num_updates + ): + last_checkpoint = policy_step + state = { + "agent": agent.state_dict(), + "optimizer": optimizer.state_dict(), + "update_step": update, + } + ckpt_path = os.path.join(log_dir, f"checkpoint/ckpt_{policy_step}_{fabric.global_rank}.ckpt") + fabric.call("on_checkpoint_coupled", fabric=fabric, ckpt_path=ckpt_path, state=state) envs.close() if fabric.is_global_zero: test(actor.module, envs, fabric, cfg) - - -if __name__ == "__main__": - main() ``` ## Config files @@ -265,26 +313,19 @@ defaults: - /optim@optimizer: adam - _self_ -name: sota - -anneal_lr: False -gamma: 0.99 -gae_lambda: 0.95 -update_epochs: 10 -loss_reduction: mean -normalize_advantages: False -clip_coef: 0.2 -anneal_clip_coef: False -clip_vloss: False -ent_coef: 0.0 -anneal_ent_coef: False -vf_coef: 1.0 +name: sota # This must be set! -dense_units: 64 +# Algorithm-related paramters +... + +# Agent model parameters +# This is just an example where we suppose we have an Actor-Critic agent +# with both an MLP and a CNN encoder mlp_layers: 2 -dense_act: torch.nn.Tanh +dense_units: 64 layer_norm: False max_grad_norm: 0.0 +dense_act: torch.nn.Tanh encoder: cnn_features_dim: 512 mlp_features_dim: 64 @@ -303,6 +344,7 @@ critic: dense_act: ${algo.dense_act} layer_norm: ${algo.layer_norm} +# Override parameters coming from `adam.yaml` config optimizer: lr: 1e-3 eps: 1e-4 @@ -319,7 +361,11 @@ defaults: - /optim@encoder.optimizer: adam - /optim@actor.optimizer: adam ``` -Will add two optimizers, one accesible with `algo.encoder.optimizer`, the other with `algo.actor.optimizer`. +will add two optimizers, one accesible with `algo.encoder.optimizer`, the other with `algo.actor.optimizer`. + +> **Note** +> +> The field `algo.name` **must** be set and **must** be equal to the name of the file.py, found under the `sheeprl/algos/sota` folder, where the implementation of the algorithm is defined. For example, if your implementation is defined in a python file named `my_sota.py`, i.e. `sheeprl/algos/sota/my_sota.py`, then `algo.name="my_sota"` #### Experiment config In the second file you have to specify all the elements you want in your experiment and you can override all the parameters you want. 
@@ -333,11 +379,10 @@ defaults: - override /env: atari - _self_ -per_rank_batch_size: 64 total_steps: 65536 +per_rank_batch_size: 64 buffer: share_data: False -rollout_steps: 128 # override environment id env: @@ -352,45 +397,62 @@ With `override /algo: sota` in `defaults` you are specifing you want to use the To let the `register_algorithm` decorator add our new `sota` algorithm to the available algorithms registry we need to import it in `./sheeprl/__init__.py`: ```diff -from dotenv import load_dotenv +import os -from sheeprl.algos.ppo import ppo, ppo_decoupled -from sheeprl.algos.ppo_recurrent import ppo_recurrent -from sheeprl.algos.sac import sac, sac_decoupled -from sheeprl.algos.sac_ae import sac_ae -+from sheeprl.algos.sota import sota +ROOT_DIR = os.path.dirname(__file__) -try: - from sheeprl.algos.ppo import ppo_atari -except ModuleNotFoundError: - pass +from dotenv import load_dotenv load_dotenv() + +from sheeprl.utils.imports import _IS_TORCH_GREATER_EQUAL_2_0 + +if not _IS_TORCH_GREATER_EQUAL_2_0: + raise ModuleNotFoundError(_IS_TORCH_GREATER_EQUAL_2_0) + +# Needed because MineRL 0.4.4 is not compatible with the latest version of numpy +import numpy as np + +from sheeprl.algos.dreamer_v1 import dreamer_v1 as dreamer_v1 +from sheeprl.algos.dreamer_v2 import dreamer_v2 as dreamer_v2 +from sheeprl.algos.dreamer_v3 import dreamer_v3 as dreamer_v3 +from sheeprl.algos.droq import droq as droq +from sheeprl.algos.p2e_dv1 import p2e_dv1 as p2e_dv1 +from sheeprl.algos.p2e_dv2 import p2e_dv2 as p2e_dv2 +from sheeprl.algos.ppo import ppo as ppo +from sheeprl.algos.ppo import ppo_decoupled as ppo_decoupled +from sheeprl.algos.ppo_recurrent import ppo_recurrent as ppo_recurrent +from sheeprl.algos.sac import sac as sac +from sheeprl.algos.sac import sac_decoupled as sac_decoupled +from sheeprl.algos.sac_ae import sac_ae as sac_ae ++from sheeprl.algos.sota import sota as sota + +np.float = np.float32 +np.int = np.int64 +np.bool = bool + +__version__ = "0.3.2" ``` -After doing that, when we run `python sheeprl.py` we should see `sota` under the `Commands` section: +Then if you run `python sheeprl/available_agents.py` you should see that `sota` appears in the list of all the available agents: ```bash -(sheeprl) โžœ sheeprl git:(main) โœ— python sheeprl.py -Usage: sheeprl.py [OPTIONS] COMMAND [ARGS]... - - SheepRL zero-code command line utility. - -Options: - --sheeprl_help Show this message and exit. 
- -Commands: - dreamer_v1 - dreamer_v2 - dreamer_v3 - droq - p2e_dv1 - p2e_dv2 - ppo - ppo_decoupled - ppo_recurrent - sac - sac_ae - sac_decoupled - sota +SheepRL Agents +โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”ณโ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”ณโ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”ณโ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”“ +โ”ƒ Module โ”ƒ Algorithm โ”ƒ Entrypoint โ”ƒ Decoupled โ”ƒ +โ”กโ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ•‡โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ•‡โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ•‡โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”ฉ +โ”‚ sheeprl.algos.dreamer_v1 โ”‚ dreamer_v1 โ”‚ main โ”‚ False โ”‚ +โ”‚ sheeprl.algos.dreamer_v2 โ”‚ dreamer_v2 โ”‚ main โ”‚ False โ”‚ +โ”‚ sheeprl.algos.dreamer_v3 โ”‚ dreamer_v3 โ”‚ main โ”‚ False โ”‚ +โ”‚ sheeprl.algos.sac โ”‚ sac โ”‚ main โ”‚ False โ”‚ +โ”‚ sheeprl.algos.sac โ”‚ sac_decoupled โ”‚ main โ”‚ True โ”‚ +โ”‚ sheeprl.algos.droq โ”‚ droq โ”‚ main โ”‚ False โ”‚ +โ”‚ sheeprl.algos.p2e_dv1 โ”‚ p2e_dv1 โ”‚ main โ”‚ False โ”‚ +โ”‚ sheeprl.algos.p2e_dv2 โ”‚ p2e_dv2 โ”‚ main โ”‚ False โ”‚ +โ”‚ sheeprl.algos.ppo โ”‚ ppo โ”‚ main โ”‚ False โ”‚ +โ”‚ sheeprl.algos.ppo โ”‚ ppo_decoupled โ”‚ main โ”‚ True โ”‚ +โ”‚ sheeprl.algos.ppo_recurrent โ”‚ ppo_recurrent โ”‚ main โ”‚ False โ”‚ +โ”‚ sheeprl.algos.sac_ae โ”‚ sac_ae โ”‚ main โ”‚ False โ”‚ +โ”‚ sheeprl.algos.sota โ”‚ sota โ”‚ sota_main โ”‚ False โ”‚ +โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ ``` \ No newline at end of file diff --git a/howto/run_experiments.md b/howto/run_experiments.md index c1daa649..21a7fe16 100644 --- a/howto/run_experiments.md +++ b/howto/run_experiments.md @@ -8,8 +8,8 @@ In this document we give the user some advices to execute its experiments. Now that you are familiar with [hydra](https://hydra.cc/docs/intro/) and the organization of the configs of this repository, we can introduce few constraints to launch experiments: -1. When you launch an experiment you **must** specify the command of the agent you want to train: `python sheeprl.py ...`. The list of the available commands can be retrieved with the following command: `python sheeprl.py --sheeprl_help` -2. Then you have to specify the hyper-parameters of your experiment: you can override the hyper-parameters by specifing them as cli arguments (e.g., `env=dmc env.id=walker_walk algo=dreamer_v3 env.action_repeat=2 ...`) or you can write your custom experiment file (you must put it in the `./sheeprl/configs/exp` folder) and call your script with the command `python sheeprl.py exp=custom_experiment` (the last option is recommended). There are some available examples, just check the [exp folder](../sheeprl/configs/exp/). +1. When you launch an experiment you **must** specify the experiment config of the agent you want to train: `python sheeprl.py exp=...`. The list of the available experiment configs can be retrieved with the following command: `python sheeprl.py --help` +2. 
Then you have to specify the hyper-parameters of your experiment: you can override the hyper-parameters by specifying them as CLI arguments (e.g., `exp=dreamer_v3 algo=dreamer_v3 env=dmc env.id=walker_walk env.action_repeat=2 ...`) or you can write your custom experiment file (you must put it in the `./sheeprl/configs/exp` folder) and call your script with the command `python sheeprl.py exp=custom_experiment` (the last option is recommended). There are some available examples, just check the [exp folder](../sheeprl/configs/exp/). 3. You **cannot mix the configs of different algorithms**, as this might raise an error or create anomalous behaviors. So if you want to train the `dreamer_v3` agent, be sure to select the correct algorithm configuration (in our case `algo=dreamer_v3`) -4. To change the optimizer of an algorithm through the CLI you must do the following: suppose that you want to run an experiment with Dreamer-V3 and want to change the world model optimizer from Adam (default in the `sheeprl/configs/algo/dreamer_v3.yaml` config) with SGD, then in the CLI you must type `python sheeprl.py dreamer_v3 algo=dreamer_v3 ... optim@algo.world_model.optimizer=sgd`, where `optim@algo.world_model.optimizer=sgd` means that the the `optimizer` field of the `world_model` of the `algo` config choosen (the dreamer_v3.yaml one) will be equal to the config `sgd.yaml` found under the `sheeprl/configs/optim` folder +4. To change the optimizer of an algorithm through the CLI you must do the following: suppose that you want to run an experiment with Dreamer-V3 and want to change the world model optimizer from Adam (the default in the `sheeprl/configs/algo/dreamer_v3.yaml` config) to SGD, then in the CLI you must type `python sheeprl.py algo=dreamer_v3 ... optim@algo.world_model.optimizer=sgd`, where `optim@algo.world_model.optimizer=sgd` means that the `optimizer` field of the `world_model` of the chosen `algo` config (the dreamer_v3.yaml one) will be equal to the config `sgd.yaml` found under the `sheeprl/configs/optim` folder diff --git a/howto/select_observations.md b/howto/select_observations.md index e84cbe08..52088272 100644 --- a/howto/select_observations.md +++ b/howto/select_observations.md @@ -28,7 +28,7 @@ You just need to pass the `mlp_keys` and `cnn_keys` of the encoder and the decod For instance, to train the ppo algorithm on the *walker walk* task provided by *DMC* using image observations and only the `orientations` and `velocity` as vector observations, you have to run the following command: ```bash -lightning run model sheeprl.py ppo exp=ppo env=dmc env.id=walker_walk cnn_keys.encoder=[rgb] mlp_keys.encoder=[orientations,velocity] +python sheeprl.py exp=ppo env=dmc env.id=walker_walk cnn_keys.encoder=[rgb] mlp_keys.encoder=[orientations,velocity] ``` > **Note** @@ -39,13 +39,13 @@ It is important to know the observations the environment provides, for instance, > **Note** > > For some environments provided by gymnasium, e.g. `LunarLander-v2` or `CartPole-v1`, only vector observations are returned, but it is possible to extract the image observation from the render. To do this, it is sufficient to specify the `rgb` key to the `cnn_keys` args: -> `lightning run model sheeprl.py ppo cnn_keys.encoder=[rgb]` +> `python sheeprl.py cnn_keys.encoder=[rgb]` #### Frame Stack For image observations it is possible to stack the last $n$ observations with the argument `frame_stack`. All the observations specified in the `cnn_keys` argument are stacked. 
```bash -lightning run model sheeprl.py ppo --env_id=dmc_walker_walk "cnn_keys.encoder=[rgb]" env.frame_stack=3 +python sheeprl.py exp=... env=dmc "cnn_keys.encoder=[rgb]" env.frame_stack=3 ``` #### How to choose the correct keys @@ -69,7 +69,7 @@ You can specify different observations for the encoder and the decoder, but ther You can specify the *mlp* and *cnn* keys of the decoder as follows: ```bash -lightning run model sheeprl.py dreamer_v3 exp=dreamer_v3 env=minerl env.id=custom_navigate mlp_keys.encoder=[life_stats,inventory,max_inventory] mlp_keys.decoder=[life_stats,inventory] +python sheeprl.py exp=dreamer_v3 env=minerl env.id=custom_navigate mlp_keys.encoder=[life_stats,inventory,max_inventory] mlp_keys.decoder=[life_stats,inventory] ``` ### Vector observations algorithms diff --git a/howto/work_with_steps.md b/howto/work_with_steps.md index b17b9d0e..cad4f885 100644 --- a/howto/work_with_steps.md +++ b/howto/work_with_steps.md @@ -12,7 +12,7 @@ We start from the concept of *policy step*: a policy step is the particular step Now that we have introduced the concept of policy step, it is necessary to clarify some aspects: 1. When there are multiple parallel environments, the policy step is proportional to the number of parallel environments. E.g., if there are $m$ environments, then the actor has to choose $m$ actions and each environment performs an environment step: this means that $\bold{m}$ **policy steps** are performed. -2. When there are multiple parallel processes (i.e. the script has been run with `lightning run model --devices>=2 ...`), the policy step it is proportional to the number of parallel processes. E.g., let us assume that there are $n$ processes each one containing one single environment: the $n$ actors select an action and a (per-process) step in the environment is performed. In this case $\bold{n}$ **policy steps** are performed. +2. When there are multiple parallel processes (i.e. the script has been run with `python sheeprl.py fabric.devices>=2 ...`), the policy step is proportional to the number of parallel processes. E.g., let us assume that there are $n$ processes, each one containing one single environment: the $n$ actors select an action and a (per-process) step in the environment is performed. In this case $\bold{n}$ **policy steps** are performed. In general, if we have $n$ parallel processes, each one with $m$ independent environments, the policy step increases **globally** by $n \cdot m$ at each iteration. diff --git a/sheeprl/algos/dreamer_v1/README.md b/sheeprl/algos/dreamer_v1/README.md index 1640fe69..21b81379 100644 --- a/sheeprl/algos/dreamer_v1/README.md +++ b/sheeprl/algos/dreamer_v1/README.md @@ -82,7 +82,7 @@ For more information see the official documentation of [Gymnasium Atari environm ## DMC environments It is possible to use the environments provided by the [DeepMind Control suite](https://www.deepmind.com/open-source/deepmind-control-suite). To use such environments it is necessary to specify "dmc" and the name of the environment in the `env` and `env.id` hyper-parameters respectively, e.g., `env=dmc env.id=walker_walk` will create an instance of the walker walk environment. For more information about all the environments, check their [paper](https://arxiv.org/abs/1801.00690). -When running DreamerV1 in a DMC environment on a server (or a PC without a video terminal) it could be necessary to add two variables to the command to launch the script: `PYOPENGL_PLATFORM="" MUJOCO_GL=osmesa `. 
For instance, to run walker walk with DreamerV1 on two gpus (0 and 1) it is necessary to runthe following command: `PYOPENGL_PLATFORM="" MUJOCO_GL=osmesa lightning run model --devices=2 --accelerator=gpu sheeprl.py dreamer_v1 exp=dreamer_v1 env=dmc env.id=walker_walk env.action_repeat=2 env.capture_video=True checkpoint.every=100000 cnn_keys.encoder=[rgb]`. +When running DreamerV1 in a DMC environment on a server (or a PC without a video terminal) it could be necessary to add two variables to the command to launch the script: `PYOPENGL_PLATFORM="" MUJOCO_GL=osmesa `. For instance, to run walker walk with DreamerV1 on two gpus (0 and 1) it is necessary to runthe following command: `PYOPENGL_PLATFORM="" MUJOCO_GL=osmesa python sheeprl.py exp=dreamer_v1 fabric.devices=2 fabric.accelerator=gpu env=dmc env.id=walker_walk env.action_repeat=2 env.capture_video=True checkpoint.every=100000 cnn_keys.encoder=[rgb]`. Other possibitities for the variable `MUJOCO_GL` are: `GLFW` for rendering to an X11 window or and `EGL` for hardware accelerated headless. (For more information, click [here](https://mujoco.readthedocs.io/en/stable/programming/index.html#using-opengl)). Moreover, it could be necessary to decomment two rows in the `sheeprl.algos.dreamer_v1.dreamer_v1.py` file. diff --git a/sheeprl/algos/dreamer_v1/dreamer_v1.py b/sheeprl/algos/dreamer_v1/dreamer_v1.py index 7ea3e16f..d02cca29 100644 --- a/sheeprl/algos/dreamer_v1/dreamer_v1.py +++ b/sheeprl/algos/dreamer_v1/dreamer_v1.py @@ -10,7 +10,6 @@ import torch import torch.nn.functional as F from lightning.fabric import Fabric -from lightning.fabric.fabric import _is_using_cli from lightning.fabric.wrappers import _FabricModule, _FabricOptimizer from omegaconf import DictConfig, OmegaConf from tensordict import TensorDict @@ -23,13 +22,12 @@ from sheeprl.algos.dreamer_v1.loss import actor_loss, critic_loss, reconstruction_loss from sheeprl.algos.dreamer_v2.utils import test from sheeprl.data.buffers import AsyncReplayBuffer -from sheeprl.utils.callback import CheckpointCallback from sheeprl.utils.env import make_dict_env from sheeprl.utils.logger import create_tensorboard_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import compute_lambda_values, polynomial_decay, print_config +from sheeprl.utils.utils import compute_lambda_values, polynomial_decay # Decomment the following two lines if you cannot start an experiment with DMC environments # os.environ["PYOPENGL_PLATFORM"] = "" @@ -357,18 +355,11 @@ def train( @register_algorithm() -@hydra.main(version_base=None, config_path="../../configs", config_name="config") -def main(cfg: DictConfig): - print_config(cfg) - +def main(fabric: Fabric, cfg: DictConfig): # These arguments cannot be changed cfg.env.screen_size = 64 cfg.env.frame_stack = 1 - # Initialize Fabric - fabric = Fabric(callbacks=[CheckpointCallback()]) - if not _is_using_cli(): - fabric.launch() device = fabric.device rank = fabric.global_rank world_size = fabric.world_size @@ -388,7 +379,7 @@ def main(cfg: DictConfig): # Create TensorBoardLogger. 
This will create the logger only on the # rank-0 process - logger, log_dir = create_tensorboard_logger(fabric, cfg, "dreamer_v1") + logger, log_dir = create_tensorboard_logger(fabric, cfg) if fabric.is_global_zero: fabric._loggers = [logger] fabric.logger.log_hyperparams(OmegaConf.to_container(cfg, resolve=True)) @@ -762,7 +753,3 @@ def main(cfg: DictConfig): envs.close() if fabric.is_global_zero: test(player, fabric, cfg) - - -if __name__ == "__main__": - main() diff --git a/sheeprl/algos/dreamer_v2/README.md b/sheeprl/algos/dreamer_v2/README.md index 6e25d3cc..b22f2d76 100644 --- a/sheeprl/algos/dreamer_v2/README.md +++ b/sheeprl/algos/dreamer_v2/README.md @@ -111,8 +111,9 @@ For more information see the official documentation of [Gymnasium Atari environm The standard hyperparameters to learn in the Atari environments are: ```bash -lightning run model --devices=1 sheeprl.py dreamer_v2 \ +python sheeprl.py \ exp=dreamer_v2 \ +fabric.devices=1 \ env=atari \ env.id=AssaultNoFrameskip-v0 \ env.capture_video=True \ @@ -146,15 +147,16 @@ buffer.prioritize_ends=True ## DMC environments It is possible to use the environments provided by the [DeepMind Control suite](https://www.deepmind.com/open-source/deepmind-control-suite). To use such environments it is necessary to specify "dmc" and the name of the environment in the `env` and `env.id` hyper-parameters respectively, e.g., `env=dmc env.id=walker_walk` will create an instance of the walker walk environment. For more information about all the environments, check their [paper](https://arxiv.org/abs/1801.00690). -When running DreamerV2 in a DMC environment on a server (or a PC without a video terminal) it could be necessary to add two variables to the command to launch the script: `PYOPENGL_PLATFORM="" MUJOCO_GL=osmesa `. For instance, to run walker walk with DreamerV2 on two gpus (0 and 1) it is necessary to runthe following command: `PYOPENGL_PLATFORM="" MUJOCO_GL=osmesa CUDA_VISIBLE_DEVICES="2,3" lightning run model --devices=2 --accelerator=gpu sheeprl.py dreamer_v2 exp=dreamer_v2 env=dmc env.id=walker_walk env.action_repeat=2 env.capture_video=True checkpoint.every=80000 cnn_keys.encoder=[rgb]`. +When running DreamerV2 in a DMC environment on a server (or a PC without a video terminal) it could be necessary to add two variables to the command to launch the script: `PYOPENGL_PLATFORM="" MUJOCO_GL=osmesa `. For instance, to run walker walk with DreamerV2 on two gpus (0 and 1) it is necessary to runthe following command: `PYOPENGL_PLATFORM="" MUJOCO_GL=osmesa CUDA_VISIBLE_DEVICES="2,3" python sheeprl.py exp=dreamer_v2 fabric.devices=2 fabric.accelerator=gpu env=dmc env.id=walker_walk env.action_repeat=2 env.capture_video=True checkpoint.every=80000 cnn_keys.encoder=[rgb]`. Other possibitities for the variable `MUJOCO_GL` are: `GLFW` for rendering to an X11 window or and `EGL` for hardware accelerated headless. (For more information, click [here](https://mujoco.readthedocs.io/en/stable/programming/index.html#using-opengl)). Moreover, it could be necessary to decomment two rows in the `sheeprl.algos.dreamer_v1.dreamer_v1.py` file. 
The standard hyperparameters used for the DMC environment are the following: ```bash -PYOPENGL_PLATFORM="" MUJOCO_GL=osmesa lightning run model --devices=1 sheeprl.py dreamer_v2 \ +PYOPENGL_PLATFORM="" MUJOCO_GL=osmesa python sheeprl.py \ exp=dreamer_v2 \ +fabric.devices=1 \ env=dmc \ env.env.id=dmc_walker_walk \ env.capture_video=True \ diff --git a/sheeprl/algos/dreamer_v2/dreamer_v2.py b/sheeprl/algos/dreamer_v2/dreamer_v2.py index 2ce31283..34d1d4d3 100644 --- a/sheeprl/algos/dreamer_v2/dreamer_v2.py +++ b/sheeprl/algos/dreamer_v2/dreamer_v2.py @@ -14,7 +14,6 @@ import torch import torch.nn.functional as F from lightning.fabric import Fabric -from lightning.fabric.fabric import _is_using_cli from lightning.fabric.wrappers import _FabricModule from omegaconf import DictConfig, OmegaConf from tensordict import TensorDict @@ -29,13 +28,12 @@ from sheeprl.algos.dreamer_v2.loss import reconstruction_loss from sheeprl.algos.dreamer_v2.utils import compute_lambda_values, test from sheeprl.data.buffers import AsyncReplayBuffer, EpisodeBuffer -from sheeprl.utils.callback import CheckpointCallback from sheeprl.utils.env import make_dict_env from sheeprl.utils.logger import create_tensorboard_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import polynomial_decay, print_config +from sheeprl.utils.utils import polynomial_decay # Decomment the following two lines if you cannot start an experiment with DMC environments # os.environ["PYOPENGL_PLATFORM"] = "" @@ -377,18 +375,11 @@ def train( @register_algorithm() -@hydra.main(version_base=None, config_path="../../configs", config_name="config") -def main(cfg: DictConfig): - print_config(cfg) - +def main(fabric: Fabric, cfg: DictConfig): # These arguments cannot be changed cfg.env.screen_size = 64 cfg.env.frame_stack = 1 - # Initialize Fabric - fabric = Fabric(callbacks=[CheckpointCallback()]) - if not _is_using_cli(): - fabric.launch() device = fabric.device rank = fabric.global_rank world_size = fabric.world_size @@ -408,7 +399,7 @@ def main(cfg: DictConfig): # Create TensorBoardLogger. 
This will create the logger only on the # rank-0 process - logger, log_dir = create_tensorboard_logger(fabric, cfg, "dreamer_v2") + logger, log_dir = create_tensorboard_logger(fabric, cfg) if fabric.is_global_zero: fabric._loggers = [logger] fabric.logger.log_hyperparams(OmegaConf.to_container(cfg, resolve=True)) @@ -841,7 +832,3 @@ def main(cfg: DictConfig): envs.close() if fabric.is_global_zero: test(player, fabric, cfg) - - -if __name__ == "__main__": - main() diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py index 75834ff1..37650af7 100644 --- a/sheeprl/algos/dreamer_v3/dreamer_v3.py +++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py @@ -15,7 +15,6 @@ import torch import torch.nn.functional as F from lightning.fabric import Fabric -from lightning.fabric.fabric import _is_using_cli from lightning.fabric.wrappers import _FabricModule from omegaconf import DictConfig, OmegaConf from tensordict import TensorDict @@ -31,14 +30,13 @@ from sheeprl.algos.dreamer_v3.utils import Moments, compute_lambda_values, test from sheeprl.data.buffers import AsyncReplayBuffer from sheeprl.envs.wrappers import RestartOnException -from sheeprl.utils.callback import CheckpointCallback from sheeprl.utils.distribution import MSEDistribution, SymlogDistribution, TwoHotEncodingDistribution from sheeprl.utils.env import make_dict_env from sheeprl.utils.logger import create_tensorboard_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import polynomial_decay, print_config +from sheeprl.utils.utils import polynomial_decay # Decomment the following two lines if you cannot start an experiment with DMC environments # os.environ["PYOPENGL_PLATFORM"] = "" @@ -330,19 +328,12 @@ def train( @register_algorithm() -@hydra.main(version_base=None, config_path="../../configs", config_name="config") -def main(cfg: DictConfig): - print_config(cfg) - +def main(fabric: Fabric, cfg: DictConfig): # These arguments cannot be changed cfg.env.frame_stack = -1 if 2 ** int(np.log2(cfg.env.screen_size)) != cfg.env.screen_size: raise ValueError(f"The screen size must be a power of 2, got: {cfg.env.screen_size}") - # Initialize Fabric - fabric = Fabric(callbacks=[CheckpointCallback()]) - if not _is_using_cli(): - fabric.launch() device = fabric.device rank = fabric.global_rank world_size = fabric.world_size @@ -362,7 +353,7 @@ def main(cfg: DictConfig): # Create TensorBoardLogger. 
This will create the logger only on the # rank-0 process - logger, log_dir = create_tensorboard_logger(fabric, cfg, "dreamer_v3") + logger, log_dir = create_tensorboard_logger(fabric, cfg) if fabric.is_global_zero: fabric._loggers = [logger] fabric.logger.log_hyperparams(OmegaConf.to_container(cfg, resolve=True)) @@ -780,7 +771,3 @@ def main(cfg: DictConfig): envs.close() if fabric.is_global_zero: test(player, fabric, cfg, sample_actions=True) - - -if __name__ == "__main__": - main() diff --git a/sheeprl/algos/droq/droq.py b/sheeprl/algos/droq/droq.py index b2917e25..119f75af 100644 --- a/sheeprl/algos/droq/droq.py +++ b/sheeprl/algos/droq/droq.py @@ -9,7 +9,6 @@ import torch import torch.nn.functional as F from lightning.fabric import Fabric -from lightning.fabric.fabric import _is_using_cli from omegaconf import DictConfig, OmegaConf from tensordict import TensorDict, make_tensordict from torch.optim import Optimizer @@ -22,13 +21,11 @@ from sheeprl.algos.sac.loss import entropy_loss, policy_loss from sheeprl.algos.sac.sac import test from sheeprl.data.buffers import ReplayBuffer -from sheeprl.utils.callback import CheckpointCallback from sheeprl.utils.env import make_env from sheeprl.utils.logger import create_tensorboard_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import print_config def train( @@ -126,14 +123,7 @@ def train( @register_algorithm() -@hydra.main(version_base=None, config_path="../../configs", config_name="config") -def main(cfg: DictConfig): - print_config(cfg) - - # Initialize Fabric - fabric = Fabric(callbacks=[CheckpointCallback()]) - if not _is_using_cli(): - fabric.launch() +def main(fabric: Fabric, cfg: DictConfig): device = fabric.device rank = fabric.global_rank world_size = fabric.world_size @@ -154,7 +144,7 @@ def main(cfg: DictConfig): # Create TensorBoardLogger. 
This will create the logger only on the # rank-0 process - logger, log_dir = create_tensorboard_logger(fabric, cfg, "droq") + logger, log_dir = create_tensorboard_logger(fabric, cfg) if fabric.is_global_zero: fabric._loggers = [logger] fabric.logger.log_hyperparams(OmegaConf.to_container(cfg, resolve=True)) @@ -404,7 +394,3 @@ def main(cfg: DictConfig): vector_env_idx=0, )() test(agent.actor.module, test_env, fabric, cfg) - - -if __name__ == "__main__": - main() diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1.py b/sheeprl/algos/p2e_dv1/p2e_dv1.py index eb301f40..83073bbe 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1.py @@ -10,7 +10,6 @@ import torch import torch.nn.functional as F from lightning.fabric import Fabric -from lightning.fabric.fabric import _is_using_cli from lightning.fabric.wrappers import _FabricModule, _FabricOptimizer from lightning.pytorch.utilities.seed import isolate_rng from omegaconf import DictConfig, OmegaConf @@ -27,13 +26,12 @@ from sheeprl.algos.p2e_dv1.agent import build_models from sheeprl.data.buffers import AsyncReplayBuffer from sheeprl.models.models import MLP -from sheeprl.utils.callback import CheckpointCallback from sheeprl.utils.env import make_dict_env from sheeprl.utils.logger import create_tensorboard_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import compute_lambda_values, init_weights, polynomial_decay, print_config +from sheeprl.utils.utils import compute_lambda_values, init_weights, polynomial_decay # Decomment the following line if you are using MineDojo on an headless machine # os.environ["MINEDOJO_HEADLESS"] = "1" @@ -362,18 +360,11 @@ def train( @register_algorithm() -@hydra.main(version_base=None, config_path="../../configs", config_name="config") -def main(cfg: DictConfig): - print_config(cfg) - +def main(fabric: Fabric, cfg: DictConfig): # These arguments cannot be changed cfg.env.screen_size = 64 cfg.env.frame_stack = 1 - # Initialize Fabric - fabric = Fabric(callbacks=[CheckpointCallback()]) - if not _is_using_cli(): - fabric.launch() device = fabric.device rank = fabric.global_rank world_size = fabric.world_size @@ -393,7 +384,7 @@ def main(cfg: DictConfig): # Create TensorBoardLogger. 
This will create the logger only on the # rank-0 process - logger, log_dir = create_tensorboard_logger(fabric, cfg, "p2e_dv1") + logger, log_dir = create_tensorboard_logger(fabric, cfg) if fabric.is_global_zero: fabric._loggers = [logger] fabric.logger.log_hyperparams(OmegaConf.to_container(cfg, resolve=True)) @@ -849,7 +840,3 @@ def main(cfg: DictConfig): if fabric.is_global_zero: player.actor = actor_task.module test(player, fabric, cfg, "few-shot") - - -if __name__ == "__main__": - main() diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2.py b/sheeprl/algos/p2e_dv2/p2e_dv2.py index c8e10803..5ec5911d 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2.py @@ -10,7 +10,6 @@ import torch import torch.nn.functional as F from lightning.fabric import Fabric -from lightning.fabric.fabric import _is_using_cli from lightning.fabric.wrappers import _FabricModule, _FabricOptimizer from lightning.pytorch.utilities.seed import isolate_rng from omegaconf import DictConfig, OmegaConf @@ -27,13 +26,12 @@ from sheeprl.algos.p2e_dv2.agent import build_models from sheeprl.data.buffers import AsyncReplayBuffer, EpisodeBuffer from sheeprl.models.models import MLP -from sheeprl.utils.callback import CheckpointCallback from sheeprl.utils.env import make_dict_env from sheeprl.utils.logger import create_tensorboard_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import polynomial_decay, print_config +from sheeprl.utils.utils import polynomial_decay # Decomment the following line if you are using MineDojo on an headless machine # os.environ["MINEDOJO_HEADLESS"] = "1" @@ -464,18 +462,11 @@ def train( @register_algorithm() -@hydra.main(version_base=None, config_path="../../configs", config_name="config") -def main(cfg: DictConfig): - print_config(cfg) - +def main(fabric: Fabric, cfg: DictConfig): # These arguments cannot be changed cfg.env.screen_size = 64 cfg.env.frame_stack = 1 - # Initialize Fabric - fabric = Fabric(callbacks=[CheckpointCallback()]) - if not _is_using_cli(): - fabric.launch() device = fabric.device rank = fabric.global_rank world_size = fabric.world_size @@ -495,7 +486,7 @@ def main(cfg: DictConfig): # Create TensorBoardLogger. 
This will create the logger only on the # rank-0 process - logger, log_dir = create_tensorboard_logger(fabric, cfg, "p2e_dv2") + logger, log_dir = create_tensorboard_logger(fabric, cfg) if fabric.is_global_zero: fabric._loggers = [logger] fabric.logger.log_hyperparams(OmegaConf.to_container(cfg, resolve=True)) @@ -1038,7 +1029,3 @@ def main(cfg: DictConfig): if fabric.is_global_zero: player.actor = actor_task.module test(player, fabric, cfg, "few-shot") - - -if __name__ == "__main__": - main() diff --git a/sheeprl/algos/ppo/ppo.py b/sheeprl/algos/ppo/ppo.py index 516b1f84..4742d010 100644 --- a/sheeprl/algos/ppo/ppo.py +++ b/sheeprl/algos/ppo/ppo.py @@ -9,7 +9,6 @@ import numpy as np import torch from lightning.fabric import Fabric -from lightning.fabric.fabric import _is_using_cli from lightning.fabric.wrappers import _FabricModule from omegaconf import DictConfig, OmegaConf from tensordict import TensorDict, make_tensordict @@ -22,13 +21,12 @@ from sheeprl.algos.ppo.loss import entropy_loss, policy_loss, value_loss from sheeprl.algos.ppo.utils import test from sheeprl.data import ReplayBuffer -from sheeprl.utils.callback import CheckpointCallback from sheeprl.utils.env import make_dict_env from sheeprl.utils.logger import create_tensorboard_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import gae, normalize_tensor, polynomial_decay, print_config +from sheeprl.utils.utils import gae, normalize_tensor, polynomial_decay def train( @@ -107,11 +105,8 @@ def train( @register_algorithm() -@hydra.main(version_base=None, config_path="../../configs", config_name="config") -def main(cfg: DictConfig): - print_config(cfg) - - if "minedojo" in cfg.env.env._target_.lower(): +def main(fabric: Fabric, cfg: DictConfig): + if "minedojo" in cfg.env.wrapper._target_.lower(): raise ValueError( "MineDojo is not currently supported by PPO agent, since it does not take " "into consideration the action masks provided by the environment, but needed " @@ -123,9 +118,6 @@ def main(cfg: DictConfig): initial_clip_coef = copy.deepcopy(cfg.algo.clip_coef) # Initialize Fabric - fabric = Fabric(callbacks=[CheckpointCallback()]) - if not _is_using_cli(): - fabric.launch() rank = fabric.global_rank world_size = fabric.world_size device = fabric.device @@ -146,7 +138,7 @@ def main(cfg: DictConfig): # Create TensorBoardLogger. 
This will create the logger only on the # rank-0 process - logger, log_dir = create_tensorboard_logger(fabric, cfg, "ppo") + logger, log_dir = create_tensorboard_logger(fabric, cfg) if fabric.is_global_zero: fabric._loggers = [logger] fabric.logger.log_hyperparams(OmegaConf.to_container(cfg, resolve=True)) @@ -458,7 +450,3 @@ def main(cfg: DictConfig): vector_env_idx=0, )() test(agent.module, test_env, fabric, cfg) - - -if __name__ == "__main__": - main() diff --git a/sheeprl/algos/ppo/ppo_decoupled.py b/sheeprl/algos/ppo/ppo_decoupled.py index 8d57cc22..70dbeb92 100644 --- a/sheeprl/algos/ppo/ppo_decoupled.py +++ b/sheeprl/algos/ppo/ppo_decoupled.py @@ -1,9 +1,8 @@ import copy import os import pathlib -import time import warnings -from datetime import datetime, timedelta +from datetime import timedelta import gymnasium as gym import hydra @@ -11,7 +10,6 @@ import torch from lightning.fabric import Fabric from lightning.fabric.fabric import _is_using_cli -from lightning.fabric.loggers import TensorBoardLogger from lightning.fabric.plugins.collectives import TorchCollective from lightning.fabric.plugins.collectives.collective import CollectibleGroup from lightning.fabric.strategies import DDPStrategy @@ -31,17 +29,15 @@ from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import gae, normalize_tensor, polynomial_decay, print_config +from sheeprl.utils.utils import gae, normalize_tensor, polynomial_decay @torch.no_grad() -def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_collective: TorchCollective): - print_config(cfg) - - # Initialize Fabric object - fabric = Fabric(callbacks=[CheckpointCallback()]) - if not _is_using_cli(): - fabric.launch() +def player( + fabric: Fabric, cfg: DictConfig, world_collective: TorchCollective, player_trainer_collective: TorchCollective +): + # Initialize the fabric object + logger = fabric.logger device = fabric.device fabric.seed_everything(cfg.seed) torch.backends.cudnn.deterministic = cfg.torch_deterministic @@ -58,19 +54,6 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co cfg.root_dir = root_dir cfg.run_name = run_name - # Initialize logger - root_dir = ( - os.path.join("logs", "runs", cfg.root_dir) - if cfg.root_dir is not None - else os.path.join("logs", "runs", "ppo_decoupled", datetime.today().strftime("%Y-%m-%d_%H-%M-%S")) - ) - run_name = ( - cfg.run_name if cfg.run_name is not None else f"{cfg.env.id}_{cfg.exp_name}_{cfg.seed}_{int(time.time())}" - ) - logger = TensorBoardLogger(root_dir=root_dir, name=run_name) - fabric._loggers = [logger] - logger.log_hyperparams(OmegaConf.to_container(cfg, resolve=True)) - # Environment setup vectorized_env = gym.vector.SyncVectorEnv if cfg.env.sync_env else gym.vector.AsyncVectorEnv envs = vectorized_env( @@ -593,16 +576,14 @@ def trainer( @register_algorithm(decoupled=True) -@hydra.main(version_base=None, config_path="../../configs", config_name="config") -def main(cfg: DictConfig): - devices = os.environ.get("LT_DEVICES", None) - if devices is None or devices == "1": +def main(fabric: Fabric, cfg: DictConfig): + if fabric.world_size == 1: raise RuntimeError( "Please run the script with the number of devices greater than 1: " - "`lightning run model --devices=2 sheeprl.py ...`" + "`python sheeprl.py exp=ppo_decoupled fabric.devices=2 ...`" ) - if "minedojo" in cfg.env.env._target_.lower(): + if "minedojo" in 
cfg.env.wrapper._target_.lower(): raise ValueError( "MineDojo is not currently supported by PPO agent, since it does not take " "into consideration the action masks provided by the environment, but needed " @@ -639,10 +620,6 @@ def main(cfg: DictConfig): ranks=list(range(1, world_collective.world_size)), timeout=timedelta(days=1) ) if global_rank == 0: - player(cfg, world_collective, player_trainer_collective) + player(fabric, cfg, world_collective, player_trainer_collective) else: trainer(world_collective, player_trainer_collective, optimization_pg) - - -if __name__ == "__main__": - main() diff --git a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py index 0dda0853..bdb1ac12 100644 --- a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py +++ b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py @@ -12,7 +12,6 @@ import numpy as np import torch from lightning.fabric import Fabric -from lightning.fabric.fabric import _is_using_cli from omegaconf import DictConfig, OmegaConf from tensordict import TensorDict from tensordict.tensordict import TensorDictBase, pad_sequence @@ -25,13 +24,12 @@ from sheeprl.algos.ppo_recurrent.agent import RecurrentPPOAgent from sheeprl.algos.ppo_recurrent.utils import test from sheeprl.data import ReplayBuffer -from sheeprl.utils.callback import CheckpointCallback from sheeprl.utils.env import make_env from sheeprl.utils.logger import create_tensorboard_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import gae, normalize_tensor, polynomial_decay, print_config +from sheeprl.utils.utils import gae, normalize_tensor, polynomial_decay def train( @@ -107,20 +105,14 @@ def train( aggregator.update("Loss/entropy_loss", ent_loss.detach()) -@register_algorithm(decoupled=True) -@hydra.main(version_base=None, config_path="../../configs", config_name="config") -def main(cfg: DictConfig): - print_config(cfg) +@register_algorithm() +def main(fabric: Fabric, cfg: DictConfig): initial_ent_coef = copy.deepcopy(cfg.algo.ent_coef) initial_clip_coef = copy.deepcopy(cfg.algo.clip_coef) if cfg.buffer.share_data: warnings.warn("The script has been called with --share-data: with recurrent PPO only gradients are shared") - # Initialize Fabric - fabric = Fabric(callbacks=[CheckpointCallback()]) - if not _is_using_cli(): - fabric.launch() rank = fabric.global_rank world_size = fabric.world_size device = fabric.device @@ -141,7 +133,7 @@ def main(cfg: DictConfig): # Create TensorBoardLogger. 
This will create the logger only on the # rank-0 process - logger, log_dir = create_tensorboard_logger(fabric, cfg, "ppo_recurrent") + logger, log_dir = create_tensorboard_logger(fabric, cfg) if fabric.is_global_zero: fabric._loggers = [logger] fabric.logger.log_hyperparams(OmegaConf.to_container(cfg, resolve=True)) @@ -437,7 +429,3 @@ def main(cfg: DictConfig): vector_env_idx=0, )() test(agent.module, test_env, fabric, cfg) - - -if __name__ == "__main__": - main() diff --git a/sheeprl/algos/sac/sac.py b/sheeprl/algos/sac/sac.py index 4f40a338..725e3fea 100644 --- a/sheeprl/algos/sac/sac.py +++ b/sheeprl/algos/sac/sac.py @@ -10,7 +10,6 @@ import numpy as np import torch from lightning.fabric import Fabric -from lightning.fabric.fabric import _is_using_cli from lightning.fabric.plugins.collectives.collective import CollectibleGroup from omegaconf import DictConfig, OmegaConf from tensordict import TensorDict, make_tensordict @@ -24,13 +23,11 @@ from sheeprl.algos.sac.loss import critic_loss, entropy_loss, policy_loss from sheeprl.algos.sac.utils import test from sheeprl.data.buffers import ReplayBuffer -from sheeprl.utils.callback import CheckpointCallback from sheeprl.utils.env import make_env from sheeprl.utils.logger import create_tensorboard_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import print_config def train( @@ -81,14 +78,7 @@ def train( @register_algorithm() -@hydra.main(version_base=None, config_path="../../configs", config_name="config") -def main(cfg: DictConfig): - print_config(cfg) - - # Initialize Fabric - fabric = Fabric(callbacks=[CheckpointCallback()]) - if not _is_using_cli(): - fabric.launch() +def main(fabric: Fabric, cfg: DictConfig): device = fabric.device rank = fabric.global_rank world_size = fabric.world_size @@ -109,7 +99,7 @@ def main(cfg: DictConfig): # Create TensorBoardLogger. 
This will create the logger only on the
    # rank-0 process
-    logger, log_dir = create_tensorboard_logger(fabric, cfg, "sac")
+    logger, log_dir = create_tensorboard_logger(fabric, cfg)
     if fabric.is_global_zero:
         fabric._loggers = [logger]
         fabric.logger.log_hyperparams(OmegaConf.to_container(cfg, resolve=True))
@@ -399,7 +389,3 @@ def main(cfg: DictConfig):
             vector_env_idx=0,
         )()
         test(agent.actor.module, test_env, fabric, cfg)
-
-
-if __name__ == "__main__":
-    main()
diff --git a/sheeprl/algos/sac/sac_decoupled.py b/sheeprl/algos/sac/sac_decoupled.py
index 7140a73c..23585681 100644
--- a/sheeprl/algos/sac/sac_decoupled.py
+++ b/sheeprl/algos/sac/sac_decoupled.py
@@ -2,7 +2,7 @@
 import pathlib
 import time
 import warnings
-from datetime import datetime, timedelta
+from datetime import timedelta
 from math import prod
 
 import gymnasium as gym
@@ -11,7 +11,6 @@
 import torch
 from lightning.fabric import Fabric
 from lightning.fabric.fabric import _is_using_cli
-from lightning.fabric.loggers import TensorBoardLogger
 from lightning.fabric.plugins.collectives import TorchCollective
 from lightning.fabric.plugins.collectives.collective import CollectibleGroup
 from lightning.fabric.strategies import DDPStrategy
@@ -30,17 +29,13 @@
 from sheeprl.utils.metric import MetricAggregator
 from sheeprl.utils.registry import register_algorithm
 from sheeprl.utils.timer import timer
-from sheeprl.utils.utils import print_config
 
 
 @torch.no_grad()
-def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_collective: TorchCollective):
-    print_config(cfg)
-
-    # Initialize Fabric
-    fabric = Fabric(callbacks=[CheckpointCallback()])
-    if not _is_using_cli():
-        fabric.launch()
+def player(
+    fabric: Fabric, cfg: DictConfig, world_collective: TorchCollective, player_trainer_collective: TorchCollective
+):
+    logger = fabric.logger
     rank = fabric.global_rank
     device = fabric.device
     fabric.seed_everything(cfg.seed)
@@ -58,19 +53,6 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co
         cfg.root_dir = root_dir
         cfg.run_name = run_name
 
-    # Initialize logger
-    root_dir = (
-        os.path.join("logs", "runs", cfg.root_dir)
-        if cfg.root_dir is not None
-        else os.path.join("logs", "runs", "sac_decoupled", datetime.today().strftime("%Y-%m-%d_%H-%M-%S"))
-    )
-    run_name = (
-        cfg.run_name if cfg.run_name is not None else f"{cfg.env.id}_{cfg.exp_name}_{cfg.seed}_{int(time.time())}"
-    )
-    logger = TensorBoardLogger(root_dir=root_dir, name=run_name)
-    fabric._loggers = [logger]
-    logger.log_hyperparams(OmegaConf.to_container(cfg, resolve=True))
-
     # Environment setup
     vectorized_env = gym.vector.SyncVectorEnv if cfg.env.sync_env else gym.vector.AsyncVectorEnv
     envs = vectorized_env(
@@ -505,8 +487,21 @@ def trainer(
 
 
 @register_algorithm(decoupled=True)
-@hydra.main(version_base=None, config_path="../../configs", config_name="config")
-def main(cfg: DictConfig):
+def main(fabric: Fabric, cfg: DictConfig):
+    if fabric.world_size == 1:
+        raise RuntimeError(
+            "Please run the script with the number of devices greater than 1: "
+            "`python sheeprl.py exp=sac_decoupled fabric.devices=2 ...`"
+        )
+
+    if "minedojo" in cfg.env.wrapper._target_.lower():
+        raise ValueError(
+            "MineDojo is not currently supported by SAC agent, since it does not take "
+            "into consideration the action masks provided by the environment, but needed "
+            "in order to play correctly the game. "
+            "As an alternative you can use one of the Dreamers' agents."
+    )
+
     world_collective = TorchCollective()
     player_trainer_collective = TorchCollective()
     world_collective.setup(
@@ -519,12 +514,6 @@ def main(cfg: DictConfig):
         world_collective.create_group(timeout=timedelta(days=1))
     global_rank = world_collective.rank
 
-    if world_collective.world_size == 1:
-        raise RuntimeError(
-            "Please run the script with the number of devices greater than 1: "
-            "`lightning run model --devices=2 sheeprl.py ...`"
-        )
-
     # Create a group between rank-0 (player) and rank-1 (trainer), assigning it to the collective:
     # used by rank-1 to send metrics to be tracked by the rank-0 at the end of a training episode
     player_trainer_collective.create_group(ranks=[0, 1], timeout=timedelta(days=1))
@@ -536,10 +525,6 @@ def main(cfg: DictConfig):
         ranks=list(range(1, world_collective.world_size)), timeout=timedelta(days=1)
     )
     if global_rank == 0:
-        player(cfg, world_collective, player_trainer_collective)
+        player(fabric, cfg, world_collective, player_trainer_collective)
     else:
         trainer(world_collective, player_trainer_collective, optimization_pg)
-
-
-if __name__ == "__main__":
-    main()
diff --git a/sheeprl/algos/sac_ae/README.md b/sheeprl/algos/sac_ae/README.md
index 7507dc88..0cda4f0a 100644
--- a/sheeprl/algos/sac_ae/README.md
+++ b/sheeprl/algos/sac_ae/README.md
@@ -201,7 +201,7 @@ For more information see the official documentation of [Gymnasium Atari environm
 ## DMC environments
 It is possible to use the environments provided by the [DeepMind Control suite](https://www.deepmind.com/open-source/deepmind-control-suite). To use such environments it is necessary to specify "dmc" and the name of the environment in the `env` and `env.id` hyper-parameters respectively, e.g., `env=dmc env.id=walker_walk` will create an instance of the walker walk environment. For more information about all the environments, check their [paper](https://arxiv.org/abs/1801.00690).
-When running DreamerV1 in a DMC environment on a server (or a PC without a video terminal) it could be necessary to add two variables to the command to launch the script: `PYOPENGL_PLATFORM="" MUJOCO_GL=osmesa `. For instance, to run walker walk with DreamerV1 on two gpus (0 and 1) it is necessary to runthe following command: `PYOPENGL_PLATFORM="" MUJOCO_GL=osmesa lightning run model --devices=2 --accelerator=gpu sheeprl.py sac_ae exp=sac_ae env=dmc env.id=walker_walk env.action_repeat=2 env.capture_video=True checkpoint.every=80000 cnn_keys.encoder=[rgb]`.
+When running SAC-AE in a DMC environment on a server (or a PC without a video terminal) it could be necessary to add two variables to the command to launch the script: `PYOPENGL_PLATFORM="" MUJOCO_GL=osmesa `. For instance, to run walker walk with SAC-AE on two GPUs (0 and 1) it is necessary to run the following command: `PYOPENGL_PLATFORM="" MUJOCO_GL=osmesa python sheeprl.py exp=sac_ae fabric.devices=2 fabric.accelerator=gpu env=dmc env.id=walker_walk env.action_repeat=2 env.capture_video=True checkpoint.every=80000 cnn_keys.encoder=[rgb]`.
 Other possibilities for the variable `MUJOCO_GL` are: `GLFW` for rendering to an X11 window or `EGL` for hardware accelerated headless rendering. (For more information, click [here](https://mujoco.readthedocs.io/en/stable/programming/index.html#using-opengl)).
## Recommendations diff --git a/sheeprl/algos/sac_ae/sac_ae.py b/sheeprl/algos/sac_ae/sac_ae.py index 4e7c5bcf..dac579b1 100644 --- a/sheeprl/algos/sac_ae/sac_ae.py +++ b/sheeprl/algos/sac_ae/sac_ae.py @@ -12,10 +12,7 @@ import torch import torch.nn.functional as F from lightning.fabric import Fabric -from lightning.fabric.accelerators import CUDAAccelerator, TPUAccelerator -from lightning.fabric.fabric import _is_using_cli from lightning.fabric.plugins.collectives.collective import CollectibleGroup -from lightning.fabric.strategies import DDPStrategy, SingleDeviceStrategy from lightning.fabric.wrappers import _FabricModule from omegaconf import DictConfig, OmegaConf from tensordict import TensorDict, make_tensordict @@ -39,13 +36,11 @@ from sheeprl.algos.sac_ae.utils import preprocess_obs, test_sac_ae from sheeprl.data.buffers import ReplayBuffer from sheeprl.models.models import MultiDecoder, MultiEncoder -from sheeprl.utils.callback import CheckpointCallback from sheeprl.utils.env import make_dict_env from sheeprl.utils.logger import create_tensorboard_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import print_config def train( @@ -134,11 +129,8 @@ def train( @register_algorithm() -@hydra.main(version_base=None, config_path="../../configs", config_name="config") -def main(cfg: DictConfig): - print_config(cfg) - - if "minedojo" in cfg.env.env._target_.lower(): +def main(fabric: Fabric, cfg: DictConfig): + if "minedojo" in cfg.env.wrapper._target_.lower(): raise ValueError( "MineDojo is not currently supported by SAC-AE agent, since it does not take " "into consideration the action masks provided by the environment, but needed " @@ -149,26 +141,6 @@ def main(cfg: DictConfig): # These arguments cannot be changed cfg.env.screen_size = 64 - # Initialize Fabric - devices = os.environ.get("LT_DEVICES", None) - strategy = os.environ.get("LT_STRATEGY", None) - is_tpu_available = TPUAccelerator.is_available() - if strategy is not None: - warnings.warn( - "You are running the SAC-AE algorithm through the Lightning CLI and you have specified a strategy: " - f"`lightning run model --strategy={strategy}`. This algorithm is run with the " - "`lightning.fabric.strategies.DDPStrategy` strategy, unless a TPU is available." - ) - os.environ.pop("LT_STRATEGY") - if is_tpu_available: - strategy = "auto" - else: - strategy = DDPStrategy(find_unused_parameters=True) - if devices == "1": - strategy = SingleDeviceStrategy(device="cuda:0" if CUDAAccelerator.is_available() else "cpu") - fabric = Fabric(strategy=strategy, callbacks=[CheckpointCallback()]) - if not _is_using_cli(): - fabric.launch() device = fabric.device rank = fabric.global_rank world_size = fabric.world_size @@ -189,7 +161,7 @@ def main(cfg: DictConfig): # Create TensorBoardLogger. 
This will create the logger only on the # rank-0 process - logger, log_dir = create_tensorboard_logger(fabric, cfg, "sac_ae") + logger, log_dir = create_tensorboard_logger(fabric, cfg) if fabric.is_global_zero: fabric._loggers = [logger] fabric.logger.log_hyperparams(OmegaConf.to_container(cfg, resolve=True)) @@ -592,7 +564,3 @@ def main(cfg: DictConfig): if fabric.is_global_zero: test_env = make_dict_env(cfg, cfg.seed, 0, fabric.logger.log_dir, "test", vector_env_idx=0)() test_sac_ae(agent.actor.module, test_env, fabric, cfg) - - -if __name__ == "__main__": - main() diff --git a/sheeprl/available_agents.py b/sheeprl/available_agents.py new file mode 100644 index 00000000..d279e147 --- /dev/null +++ b/sheeprl/available_agents.py @@ -0,0 +1,18 @@ +if __name__ == "__main__": + from rich.console import Console + from rich.table import Table + + from sheeprl.utils.registry import tasks + + table = Table(title="SheepRL Agents") + table.add_column("Module") + table.add_column("Algorithm") + table.add_column("Entrypoint") + table.add_column("Decoupled") + + for module, implementations in tasks.items(): + for algo in implementations: + table.add_row(module, algo["name"], algo["entrypoint"], str(algo["decoupled"])) + + console = Console() + console.print(table) diff --git a/sheeprl/cli.py b/sheeprl/cli.py index 30168f7e..2abaffb8 100644 --- a/sheeprl/cli.py +++ b/sheeprl/cli.py @@ -1,93 +1,84 @@ -"""Adapted from https://github.com/Lightning-Universe/lightning-flash/blob/master/src/flash/__main__.py""" - -import functools +import datetime import importlib import os +import time import warnings -from contextlib import closing -from typing import Optional -from unittest.mock import patch - -import click -from lightning.fabric.fabric import _is_using_cli -from sheeprl.utils.registry import decoupled_tasks, tasks +import hydra +from lightning import Fabric +from lightning.fabric.accelerators.tpu import TPUAccelerator +from lightning.fabric.loggers.tensorboard import TensorBoardLogger +from lightning.fabric.strategies.ddp import DDPStrategy +from omegaconf import DictConfig, OmegaConf, open_dict -CONTEXT_SETTINGS = dict(help_option_names=["--sheeprl_help"]) +from sheeprl.utils.callback import CheckpointCallback +from sheeprl.utils.registry import tasks +from sheeprl.utils.utils import print_config -@click.group(no_args_is_help=True, add_help_option=True, context_settings=CONTEXT_SETTINGS) -def run(): +@hydra.main(version_base=None, config_path="configs", config_name="config") +def run(cfg: DictConfig): """SheepRL zero-code command line utility.""" - if not _is_using_cli(): - warnings.warn( - "This script was launched without the Lightning CLI. Consider to launch the script with " - "`lightning run model ...` to scale it with Fabric" + if cfg.fabric.strategy == "fsdp": + raise ValueError( + "FSDPStrategy is currently not supported. 
Please launch the script with another strategy: "
+            "`python sheeprl.py fabric.strategy=...`"
         )
+    print_config(cfg)
 
-
-def register_command(command, task, name: Optional[str] = None):
-    @run.command(
-        name if name is not None else command.__name__,
-        context_settings=dict(
-            help_option_names=[],
-            ignore_unknown_options=True,
-        ),
-    )
-    @click.argument("cli_args", nargs=-1, type=click.UNPROCESSED)
-    @functools.wraps(command)
-    def wrapper(cli_args):
-        additional_args = [f"algo.name={name}"]
-        if len(list(cli_args)) == 0:
-            additional_args.append(f"exp={name}")
-        with patch("sys.argv", [task.__file__] + list(cli_args) + additional_args) as sys_argv_mock:
-            devices = os.environ.get("LT_DEVICES", None)
-            strategy = os.environ.get("LT_STRATEGY", None)
-            if strategy == "fsdp":
-                raise ValueError(
-                    "FSDPStrategy is currently not supported. Please launch the script with another strategy: "
-                    "`lightning run model --strategy=... sheeprl.py ...`"
-                )
-            if name in decoupled_tasks and strategy is not None:
+    # Given the algorithm's name, retrieve the module where
+    # 'cfg.algo.name'.py is contained; from there retrieve the
+    # `register_algorithm`-decorated entrypoint;
+    # the entrypoint will be launched by Fabric with `fabric.launch(entrypoint)`
+    module = None
+    decoupled = False
+    entrypoint = None
+    algo_name = cfg.algo.name
+    for _module, _algos in tasks.items():
+        for _algo in _algos:
+            if algo_name == _algo["name"]:
+                module = _module
+                entrypoint = _algo["entrypoint"]
+                decoupled = _algo["decoupled"]
+                break
+    if module is None:
+        raise RuntimeError(f"Given the algorithm named `{algo_name}`, no module has been found to be imported.")
+    if entrypoint is None:
+        raise RuntimeError(
+            f"Given the module and algorithm named `{module}` and `{algo_name}` respectively, "
+            "no entrypoint has been found to be imported."
+        )
+    task = importlib.import_module(f"{module}.{algo_name}")
+    command = task.__dict__[entrypoint]
+    if decoupled:
+        root_dir = (
+            os.path.join("logs", "runs", cfg.root_dir)
+            if cfg.root_dir is not None
+            else os.path.join("logs", "runs", algo_name, datetime.datetime.today().strftime("%Y-%m-%d_%H-%M-%S"))
+        )
+        run_name = (
+            cfg.run_name if cfg.run_name is not None else f"{cfg.env.id}_{cfg.exp_name}_{cfg.seed}_{int(time.time())}"
+        )
+        logger = TensorBoardLogger(root_dir=root_dir, name=run_name)
+        logger.log_hyperparams(OmegaConf.to_container(cfg, resolve=True))
+        fabric = Fabric(**cfg.fabric, loggers=logger, callbacks=[CheckpointCallback()])
+    else:
+        if "sac_ae" in module:
+            strategy = cfg.fabric.strategy
+            is_tpu_available = TPUAccelerator.is_available()
+            if strategy is not None:
                 warnings.warn(
-                    "You are running a decoupled algorithm through the Lightning CLI "
-                    "and you have specified a strategy: "
-                    f"`lightning run model --strategy={strategy}`. "
-                    "When a decoupled algorithm is run the default strategy will be "
-                    "a `lightning.fabric.strategies.DDPStrategy`."
+                    "You are running the SAC-AE algorithm and you have specified a strategy other than `ddp`: "
+                    f"`python sheeprl.py fabric.strategy={strategy}`. This algorithm is run with the "
+                    "`lightning.fabric.strategies.DDPStrategy` strategy, unless a TPU is available."
) - os.environ.pop("LT_STRATEGY") - if name in decoupled_tasks and not _is_using_cli(): - import torch.distributed.run as torchrun - from torch.distributed.elastic.utils import get_socket_with_port - - sock = get_socket_with_port() - with closing(sock): - master_port = sock.getsockname()[1] - nproc_per_node = "2" if devices is None else devices - torchrun_args = [ - f"--nproc_per_node={nproc_per_node}", - "--nnodes=1", - "--node-rank=0", - "--start-method=spawn", - "--master-addr=localhost", - f"--master-port={master_port}", - ] + sys_argv_mock - torchrun.main(torchrun_args) + if is_tpu_available: + strategy = "auto" else: - if not _is_using_cli() and devices is None: - os.environ["LT_DEVICES"] = "1" - command() - - -for module, algos in tasks.items(): - for algo in algos: - try: - algo_name = algo - task = importlib.import_module(f"{module}.{algo_name}") - - for command in task.__all__: - command = task.__dict__[command] - register_command(command, task, name=algo_name) - except ImportError: - pass + strategy = DDPStrategy(find_unused_parameters=True) + with open_dict(cfg): + cfg.fabric.pop("strategy", None) + fabric = Fabric(**cfg.fabric, strategy=strategy, callbacks=[CheckpointCallback()]) + else: + fabric = Fabric(**cfg.fabric, callbacks=[CheckpointCallback()]) + fabric.launch(command, cfg) diff --git a/sheeprl/configs/algo/default.yaml b/sheeprl/configs/algo/default.yaml index e69de29b..35298376 100644 --- a/sheeprl/configs/algo/default.yaml +++ b/sheeprl/configs/algo/default.yaml @@ -0,0 +1 @@ +name: ??? diff --git a/sheeprl/configs/algo/ppo_decoupled.yaml b/sheeprl/configs/algo/ppo_decoupled.yaml new file mode 100644 index 00000000..ba3c0479 --- /dev/null +++ b/sheeprl/configs/algo/ppo_decoupled.yaml @@ -0,0 +1,6 @@ +defaults: + - ppo + - _self_ + +# Training receipe +name: ppo_decoupled diff --git a/sheeprl/configs/algo/sac_decoupled.yaml b/sheeprl/configs/algo/sac_decoupled.yaml new file mode 100644 index 00000000..2bf88e0e --- /dev/null +++ b/sheeprl/configs/algo/sac_decoupled.yaml @@ -0,0 +1,5 @@ +defaults: + - sac + - _self_ + +name: sac_decoupled diff --git a/sheeprl/configs/config.yaml b/sheeprl/configs/config.yaml index 607ca96e..24cfc2b5 100644 --- a/sheeprl/configs/config.yaml +++ b/sheeprl/configs/config.yaml @@ -7,9 +7,10 @@ defaults: - buffer: default.yaml - checkpoint: default.yaml - env: default.yaml + - fabric: default.yaml - metric: default.yaml - - exp: null - hydra: default.yaml + - exp: ??? num_threads: 1 total_steps: ??? diff --git a/sheeprl/configs/env/atari.yaml b/sheeprl/configs/env/atari.yaml index 66933688..38dfa3cc 100644 --- a/sheeprl/configs/env/atari.yaml +++ b/sheeprl/configs/env/atari.yaml @@ -2,11 +2,14 @@ defaults: - default - _self_ -id: PongNoFrameskip-v4 +# Override from `default` config action_repeat: 4 +id: PongNoFrameskip-v4 max_episode_steps: 27000 -env: - _target_: gymnasium.wrappers.AtariPreprocessing + +# Wrapper to be instantiated +wrapper: + _target_: gymnasium.wrappers.AtariPreprocessing # https://gymnasium.farama.org/api/wrappers/misc_wrappers/#gymnasium.wrappers.AtariPreprocessing env: _target_: gymnasium.make id: ${env.id} diff --git a/sheeprl/configs/env/default.yaml b/sheeprl/configs/env/default.yaml index 6635a685..552c545a 100644 --- a/sheeprl/configs/env/default.yaml +++ b/sheeprl/configs/env/default.yaml @@ -1,6 +1,7 @@ +id: ??? 
num_envs: 4 -sync_env: False frame_stack: 1 +sync_env: False screen_size: 64 action_repeat: 1 grayscale: False diff --git a/sheeprl/configs/env/diambra.yaml b/sheeprl/configs/env/diambra.yaml index d13d87ef..385d2059 100644 --- a/sheeprl/configs/env/diambra.yaml +++ b/sheeprl/configs/env/diambra.yaml @@ -2,12 +2,14 @@ defaults: - default - _self_ +# Override from `default` config id: doapp -action_repeat: 1 frame_stack: 4 sync_env: True +action_repeat: 1 -env: +# Wrapper to be instantiated +wrapper: _target_: sheeprl.envs.diambra.DiambraWrapper id: ${env.id} action_space: discrete diff --git a/sheeprl/configs/env/dmc.yaml b/sheeprl/configs/env/dmc.yaml index 2dcf3eb8..bda7e6f5 100644 --- a/sheeprl/configs/env/dmc.yaml +++ b/sheeprl/configs/env/dmc.yaml @@ -2,11 +2,13 @@ defaults: - default - _self_ +# Override from `default` config id: walker_walk action_repeat: 1 max_episode_steps: 1000 -env: +# Wrapper to be instantiated +wrapper: _target_: sheeprl.envs.dmc.DMCWrapper id: ${env.id} width: ${env.screen_size} diff --git a/sheeprl/configs/env/dummy.yaml b/sheeprl/configs/env/dummy.yaml index c3751640..d7285378 100644 --- a/sheeprl/configs/env/dummy.yaml +++ b/sheeprl/configs/env/dummy.yaml @@ -2,8 +2,10 @@ defaults: - default - _self_ +# Override from `default` config id: discrete_dummy -env: +# Wrapper to be instantiated +wrapper: _target_: sheeprl.utils.env.get_dummy_env id: ${env.id} diff --git a/sheeprl/configs/env/gym.yaml b/sheeprl/configs/env/gym.yaml index 3def8c95..c45f9aa8 100644 --- a/sheeprl/configs/env/gym.yaml +++ b/sheeprl/configs/env/gym.yaml @@ -2,10 +2,12 @@ defaults: - default - _self_ +# Override from `default` config id: CartPole-v1 mask_velocities: False -env: +# Wrapper to be instantiated +wrapper: _target_: gymnasium.make id: ${env.id} render_mode: rgb_array diff --git a/sheeprl/configs/env/minecraft.yaml b/sheeprl/configs/env/minecraft.yaml index 21f6bd17..0da9fd48 100644 --- a/sheeprl/configs/env/minecraft.yaml +++ b/sheeprl/configs/env/minecraft.yaml @@ -2,8 +2,8 @@ defaults: - default - _self_ -min_pitch: -60 max_pitch: 60 -break_speed_multiplier: 100 +min_pitch: -60 +sticky_jump: 10 sticky_attack: 30 -sticky_jump: 10 \ No newline at end of file +break_speed_multiplier: 100 diff --git a/sheeprl/configs/env/minedojo.yaml b/sheeprl/configs/env/minedojo.yaml index 26b1d801..fb79138b 100644 --- a/sheeprl/configs/env/minedojo.yaml +++ b/sheeprl/configs/env/minedojo.yaml @@ -2,10 +2,12 @@ defaults: - minecraft - _self_ +# Override from `minecraft` config id: open-ended action_repeat: 1 -env: +# Wrapper to be instantiated +wrapper: _target_: sheeprl.envs.minedojo.MineDojoWrapper id: ${env.id} height: ${env.screen_size} diff --git a/sheeprl/configs/env/minerl.yaml b/sheeprl/configs/env/minerl.yaml index fff9c772..74e564a6 100644 --- a/sheeprl/configs/env/minerl.yaml +++ b/sheeprl/configs/env/minerl.yaml @@ -2,10 +2,12 @@ defaults: - minecraft - _self_ +# Override from `minecraft` config id: custom_navigate action_repeat: 1 -env: +# Wrapper to be instantiated +wrapper: _target_: sheeprl.envs.minerl.MineRLWrapper id: ${env.id} height: ${env.screen_size} diff --git a/sheeprl/configs/exp/dreamer_v3_L_doapp.yaml b/sheeprl/configs/exp/dreamer_v3_L_doapp.yaml index f4e844db..ed40f8fe 100644 --- a/sheeprl/configs/exp/dreamer_v3_L_doapp.yaml +++ b/sheeprl/configs/exp/dreamer_v3_L_doapp.yaml @@ -14,7 +14,7 @@ env: id: doapp num_envs: 8 frame_stack: 1 - env: + wrapper: diambra_settings: characters: Kasumi diff --git 
a/sheeprl/configs/exp/dreamer_v3_L_doapp_128px_gray_combo_discrete.yaml b/sheeprl/configs/exp/dreamer_v3_L_doapp_128px_gray_combo_discrete.yaml index 70905a27..6f912a9c 100644 --- a/sheeprl/configs/exp/dreamer_v3_L_doapp_128px_gray_combo_discrete.yaml +++ b/sheeprl/configs/exp/dreamer_v3_L_doapp_128px_gray_combo_discrete.yaml @@ -18,7 +18,7 @@ env: frame_stack: 1 screen_size: 128 reward_as_observation: True - env: + wrapper: attack_but_combination: True diambra_settings: characters: Kasumi @@ -81,4 +81,4 @@ algo: # Metric metric: - log_every: 10000 \ No newline at end of file + log_every: 10000 diff --git a/sheeprl/configs/exp/dreamer_v3_L_navigate.yaml b/sheeprl/configs/exp/dreamer_v3_L_navigate.yaml index 10927613..7680f147 100644 --- a/sheeprl/configs/exp/dreamer_v3_L_navigate.yaml +++ b/sheeprl/configs/exp/dreamer_v3_L_navigate.yaml @@ -14,7 +14,7 @@ env: num_envs: 4 id: custom_navigate reward_as_observation: True - env: + wrapper: multihot_inventory: False # Checkpoint diff --git a/sheeprl/configs/exp/dreamer_v3_dmc_walker_walk.yaml b/sheeprl/configs/exp/dreamer_v3_dmc_walker_walk.yaml index a683a28e..66bcc776 100644 --- a/sheeprl/configs/exp/dreamer_v3_dmc_walker_walk.yaml +++ b/sheeprl/configs/exp/dreamer_v3_dmc_walker_walk.yaml @@ -20,7 +20,7 @@ env: num_envs: 1 max_episode_steps: 1000 id: walker_walk - env: + wrapper: from_vectors: True from_pixels: True @@ -52,4 +52,4 @@ algo: # Metric metric: - log_every: 5000 \ No newline at end of file + log_every: 5000 diff --git a/sheeprl/configs/exp/ppo_decoupled.yaml b/sheeprl/configs/exp/ppo_decoupled.yaml new file mode 100644 index 00000000..55914cb0 --- /dev/null +++ b/sheeprl/configs/exp/ppo_decoupled.yaml @@ -0,0 +1,7 @@ +# @package _global_ + +defaults: + - ppo + - override /algo: ppo_decoupled + - override /env: gym + - _self_ diff --git a/sheeprl/configs/exp/sac_decoupled.yaml b/sheeprl/configs/exp/sac_decoupled.yaml new file mode 100644 index 00000000..7e1714fb --- /dev/null +++ b/sheeprl/configs/exp/sac_decoupled.yaml @@ -0,0 +1,7 @@ +# @package _global_ + +defaults: + - sac + - override /algo: sac_decoupled + - override /env: gym + - _self_ diff --git a/sheeprl/configs/fabric/ddp-cpu.yaml b/sheeprl/configs/fabric/ddp-cpu.yaml new file mode 100644 index 00000000..8ab4824f --- /dev/null +++ b/sheeprl/configs/fabric/ddp-cpu.yaml @@ -0,0 +1,7 @@ +defaults: + - default + - _self_ + +devices: 0 +strategy: "ddp" +accelerator: "cpu" diff --git a/sheeprl/configs/fabric/ddp-cuda.yaml b/sheeprl/configs/fabric/ddp-cuda.yaml new file mode 100644 index 00000000..3a533ec9 --- /dev/null +++ b/sheeprl/configs/fabric/ddp-cuda.yaml @@ -0,0 +1,7 @@ +defaults: + - default + - _self_ + +devices: 0 +strategy: "ddp" +accelerator: "cuda" diff --git a/sheeprl/configs/fabric/default.yaml b/sheeprl/configs/fabric/default.yaml new file mode 100644 index 00000000..ca406b2a --- /dev/null +++ b/sheeprl/configs/fabric/default.yaml @@ -0,0 +1,5 @@ +devices: 1 +num_nodes: 1 +strategy: "auto" +accelerator: "cpu" +precision: "32-true" diff --git a/sheeprl/utils/env.py b/sheeprl/utils/env.py index fe25d4ae..22a13093 100644 --- a/sheeprl/utils/env.py +++ b/sheeprl/utils/env.py @@ -81,11 +81,11 @@ def thunk() -> gym.Env: env_spec = "" instantiate_kwargs = {} - if "seed" in cfg.env.env: + if "seed" in cfg.env.wrapper: instantiate_kwargs["seed"] = seed - if "rank" in cfg.env.env: + if "rank" in cfg.env.wrapper: instantiate_kwargs["rank"] = rank + vector_env_idx - env = hydra.utils.instantiate(cfg.env.env, **instantiate_kwargs) + env = 
hydra.utils.instantiate(cfg.env.wrapper, **instantiate_kwargs) # action repeat if ( diff --git a/sheeprl/utils/logger.py b/sheeprl/utils/logger.py index 7245745b..5698ba99 100644 --- a/sheeprl/utils/logger.py +++ b/sheeprl/utils/logger.py @@ -9,9 +9,7 @@ from omegaconf import DictConfig -def create_tensorboard_logger( - fabric: Fabric, cfg: DictConfig, algo_name: str -) -> Tuple[Optional[TensorBoardLogger], str]: +def create_tensorboard_logger(fabric: Fabric, cfg: DictConfig) -> Tuple[Optional[TensorBoardLogger], str]: # Set logger only on rank-0 but share the logger directory: since we don't know # what is happening during the `fabric.save()` method, at least we assure that all # ranks save under the same named folder. @@ -24,7 +22,7 @@ def create_tensorboard_logger( root_dir = ( os.path.join("logs", "runs", cfg.root_dir) if cfg.root_dir is not None - else os.path.join("logs", "runs", algo_name, datetime.today().strftime("%Y-%m-%d_%H-%M-%S")) + else os.path.join("logs", "runs", cfg.algo.name, datetime.today().strftime("%Y-%m-%d_%H-%M-%S")) ) run_name = ( cfg.run_name if cfg.run_name is not None else f"{cfg.env.id}_{cfg.exp_name}_{cfg.seed}_{int(time.time())}" diff --git a/sheeprl/utils/registry.py b/sheeprl/utils/registry.py index af4995c1..9bd3122e 100644 --- a/sheeprl/utils/registry.py +++ b/sheeprl/utils/registry.py @@ -2,33 +2,30 @@ from typing import Any, Callable, Dict, List # Mapping of tasks with their relative algorithms. -# A new task can be added as: tasks[module] = [..., algorithm] -# where `module` and `algorithm` are respectively taken from sheeprl/algos/{module}/{algorithm}.py -tasks: Dict[str, List[str]] = {} - -# A list representing the `decoupled` algorithms -decoupled_tasks: List[str] = [] +# A new task can be added as: +# tasks[module] = [..., {"name": algorithm, "entrypoint": entrypoint, "decoupled": decoupled}] +# where `module` and `algorithm` are respectively taken from sheeprl/algos/{module}/{algorithm}.py, +# while `entrypoint` is the decorated function +tasks: Dict[str, List[Dict[str, str]]] = {} def _register(fn: Callable[..., Any], decoupled: bool = False) -> Callable[..., Any]: # lookup containing module if fn.__module__ == "__main__": return fn + entrypoint = fn.__name__ module_split = fn.__module__.split(".") algorithm = module_split[-1] module = ".".join(module_split[:-1]) algos = tasks.get(module, None) if algos is None: - tasks[module] = [algorithm] + tasks[module] = [{"name": algorithm, "entrypoint": entrypoint, "decoupled": decoupled}] else: if algorithm in algos: raise ValueError(f"The algorithm `{algorithm}` has already been registered!") - tasks[module].append(algorithm) - if decoupled or "decoupled" in algorithm: - decoupled_tasks.append(algorithm) + tasks[module].append({"name": algorithm, "entrypoint": entrypoint, "decoupled": decoupled}) # add the decorated function to __all__ in algorithm - entrypoint = fn.__name__ mod = sys.modules[fn.__module__] if hasattr(mod, "__all__"): mod.__all__.append(entrypoint) diff --git a/sheeprl/utils/utils.py b/sheeprl/utils/utils.py index aeb72d04..d580febf 100644 --- a/sheeprl/utils/utils.py +++ b/sheeprl/utils/utils.py @@ -141,7 +141,7 @@ def symexp(x: Tensor) -> Tensor: @rank_zero_only def print_config( config: DictConfig, - fields: Sequence[str] = ("algo", "buffer", "checkpoint", "env", "exp", "hydra", "metric", "optim"), + fields: Sequence[str] = ("algo", "buffer", "checkpoint", "env", "fabric", "metric"), resolve: bool = True, cfg_save_path: Optional[Union[str, os.PathLike]] = None, ) -> None: diff --git 
a/tests/test_algos/test_algos.py b/tests/test_algos/test_algos.py index 797341d3..35e8e50e 100644 --- a/tests/test_algos/test_algos.py +++ b/tests/test_algos/test_algos.py @@ -1,10 +1,9 @@ -import importlib import os import shutil import sys import time import warnings -from contextlib import closing, nullcontext +from contextlib import nullcontext from pathlib import Path from unittest import mock @@ -12,6 +11,8 @@ import torch.distributed as dist from lightning import Fabric +from sheeprl import ROOT_DIR +from sheeprl.cli import run from sheeprl.utils.imports import _IS_WINDOWS @@ -23,10 +24,12 @@ def devices(request): @pytest.fixture() def standard_args(): return [ + os.path.join(ROOT_DIR, "__main__.py"), "hydra/job_logging=disabled", "hydra/hydra_logging=disabled", "dry_run=True", "env.num_envs=1", + "fabric.devices=auto", f"env.sync_env={_IS_WINDOWS}", ] @@ -75,7 +78,6 @@ def remove_test_dir(path: str) -> None: @pytest.mark.timeout(60) @pytest.mark.parametrize("checkpoint_buffer", [True, False]) def test_droq(standard_args, checkpoint_buffer, start_time): - task = importlib.import_module("sheeprl.algos.droq.droq") root_dir = os.path.join(f"pytest_{start_time}", "droq", os.environ["LT_DEVICES"]) run_name = "checkpoint_buffer" if checkpoint_buffer else "no_checkpoint_buffer" ckpt_path = os.path.join(root_dir, run_name) @@ -93,10 +95,8 @@ def test_droq(standard_args, checkpoint_buffer, start_time): "env.capture_video=False", ] - with mock.patch.object(sys, "argv", [task.__file__] + args): - for command in task.__all__: - if command == "main": - task.__dict__[command]() + with mock.patch.object(sys, "argv", args): + run() keys = { "agent", @@ -117,7 +117,6 @@ def test_droq(standard_args, checkpoint_buffer, start_time): @pytest.mark.timeout(60) @pytest.mark.parametrize("checkpoint_buffer", [True, False]) def test_sac(standard_args, checkpoint_buffer, start_time): - task = importlib.import_module("sheeprl.algos.sac.sac") root_dir = os.path.join(f"pytest_{start_time}", "sac", os.environ["LT_DEVICES"]) run_name = "checkpoint_buffer" if checkpoint_buffer else "no_checkpoint_buffer" ckpt_path = os.path.join(root_dir, run_name) @@ -135,10 +134,8 @@ def test_sac(standard_args, checkpoint_buffer, start_time): "env.capture_video=False", ] - with mock.patch.object(sys, "argv", [task.__file__] + args): - for command in task.__all__: - if command == "main": - task.__dict__[command]() + with mock.patch.object(sys, "argv", args): + run() keys = { "agent", @@ -159,7 +156,6 @@ def test_sac(standard_args, checkpoint_buffer, start_time): @pytest.mark.timeout(60) @pytest.mark.parametrize("checkpoint_buffer", [True, False]) def test_sac_ae(standard_args, checkpoint_buffer, start_time): - task = importlib.import_module("sheeprl.algos.sac_ae.sac_ae") root_dir = os.path.join(f"pytest_{start_time}", "sac_ae", os.environ["LT_DEVICES"]) run_name = "checkpoint_buffer" if checkpoint_buffer else "no_checkpoint_buffer" ckpt_path = os.path.join(root_dir, run_name) @@ -185,10 +181,8 @@ def test_sac_ae(standard_args, checkpoint_buffer, start_time): "env.capture_video=False", ] - with mock.patch.object(sys, "argv", [task.__file__] + args): - for command in task.__all__: - if command == "main": - task.__dict__[command]() + with mock.patch.object(sys, "argv", args): + run() keys = { "agent", @@ -213,44 +207,26 @@ def test_sac_ae(standard_args, checkpoint_buffer, start_time): @pytest.mark.timeout(60) @pytest.mark.parametrize("checkpoint_buffer", [True, False]) def test_sac_decoupled(standard_args, checkpoint_buffer, 
start_time): - task = importlib.import_module("sheeprl.algos.sac.sac_decoupled") root_dir = os.path.join(f"pytest_{start_time}", "sac_decoupled", os.environ["LT_DEVICES"]) run_name = "checkpoint_buffer" if checkpoint_buffer else "no_checkpoint_buffer" ckpt_path = os.path.join(root_dir, run_name) version = 0 if not os.path.isdir(ckpt_path) else len(os.listdir(ckpt_path)) ckpt_path = os.path.join(ckpt_path, f"version_{version}", "checkpoint") args = standard_args + [ - "exp=sac", + "exp=sac_decoupled", "per_rank_batch_size=1", "algo.learning_starts=0", "algo.per_rank_gradient_steps=1", + f"fabric.devices={os.environ['LT_DEVICES']}", f"root_dir={root_dir}", f"run_name={run_name}", f"buffer.checkpoint={checkpoint_buffer}", "env.capture_video=False", ] - with mock.patch.object(sys, "argv", [task.__file__] + args): - import torch.distributed.run as torchrun - from torch.distributed.elastic.multiprocessing.errors import ChildFailedError - from torch.distributed.elastic.utils import get_socket_with_port - - sock = get_socket_with_port() - with closing(sock): - master_port = sock.getsockname()[1] - - for command in task.__all__: - if command == "main": - with pytest.raises(ChildFailedError) if os.environ["LT_DEVICES"] == "1" else nullcontext(): - torchrun_args = [ - f"--nproc_per_node={os.environ['LT_DEVICES']}", - "--nnodes=1", - "--node-rank=0", - "--start-method=spawn", - "--master-addr=localhost", - f"--master-port={master_port}", - ] + sys.argv - torchrun.main(torchrun_args) + with mock.patch.object(sys, "argv", args): + with pytest.raises(RuntimeError) if os.environ["LT_DEVICES"] == "1" else nullcontext(): + run() if os.environ["LT_DEVICES"] != "1": keys = { @@ -272,7 +248,6 @@ def test_sac_decoupled(standard_args, checkpoint_buffer, start_time): @pytest.mark.timeout(60) @pytest.mark.parametrize("env_id", ["discrete_dummy", "multidiscrete_dummy", "continuous_dummy"]) def test_ppo(standard_args, start_time, env_id): - task = importlib.import_module("sheeprl.algos.ppo.ppo") root_dir = os.path.join(f"pytest_{start_time}", "ppo", os.environ["LT_DEVICES"]) run_name = "test_ppo" ckpt_path = os.path.join(root_dir, run_name) @@ -288,10 +263,9 @@ def test_ppo(standard_args, start_time, env_id): f"env.id={env_id}", "env.capture_video=False", ] - with mock.patch.object(sys, "argv", [task.__file__] + args): - for command in task.__all__: - if command == "main": - task.__dict__[command]() + + with mock.patch.object(sys, "argv", args): + run() check_checkpoint( Path(os.path.join("logs", "runs", ckpt_path)), @@ -303,15 +277,15 @@ def test_ppo(standard_args, start_time, env_id): @pytest.mark.timeout(60) @pytest.mark.parametrize("env_id", ["discrete_dummy", "multidiscrete_dummy", "continuous_dummy"]) def test_ppo_decoupled(standard_args, start_time, env_id): - task = importlib.import_module("sheeprl.algos.ppo.ppo_decoupled") root_dir = os.path.join(f"pytest_{start_time}", "ppo_decoupled", os.environ["LT_DEVICES"]) run_name = "test_ppo_decoupled" ckpt_path = os.path.join(root_dir, run_name) version = 0 if not os.path.isdir(ckpt_path) else len(os.listdir(ckpt_path)) ckpt_path = os.path.join(ckpt_path, f"version_{version}", "checkpoint") args = standard_args + [ - "exp=ppo", + "exp=ppo_decoupled", "env=dummy", + f"fabric.devices={os.environ['LT_DEVICES']}", f"algo.rollout_steps={os.environ['LT_DEVICES']}", "per_rank_batch_size=1", "algo.update_epochs=1", @@ -320,27 +294,10 @@ def test_ppo_decoupled(standard_args, start_time, env_id): f"env.id={env_id}", "env.capture_video=False", ] - with mock.patch.object(sys, 
"argv", [task.__file__] + args): - import torch.distributed.run as torchrun - from torch.distributed.elastic.multiprocessing.errors import ChildFailedError - from torch.distributed.elastic.utils import get_socket_with_port - - sock = get_socket_with_port() - with closing(sock): - master_port = sock.getsockname()[1] - - for command in task.__all__: - if command == "main": - with pytest.raises(ChildFailedError) if os.environ["LT_DEVICES"] == "1" else nullcontext(): - torchrun_args = [ - f"--nproc_per_node={os.environ['LT_DEVICES']}", - "--nnodes=1", - "--node-rank=0", - "--start-method=spawn", - "--master-addr=localhost", - f"--master-port={master_port}", - ] + sys.argv - torchrun.main(torchrun_args) + + with mock.patch.object(sys, "argv", args): + with pytest.raises(RuntimeError) if os.environ["LT_DEVICES"] == "1" else nullcontext(): + run() if os.environ["LT_DEVICES"] != "1": check_checkpoint( @@ -350,9 +307,8 @@ def test_ppo_decoupled(standard_args, start_time, env_id): remove_test_dir(os.path.join("logs", "runs", f"pytest_{start_time}")) -@pytest.mark.timeout(60) +# @pytest.mark.timeout(60) def test_ppo_recurrent(standard_args, start_time): - task = importlib.import_module("sheeprl.algos.ppo_recurrent.ppo_recurrent") root_dir = os.path.join(f"pytest_{start_time}", "ppo_recurrent", os.environ["LT_DEVICES"]) run_name = "test_ppo_recurrent" ckpt_path = os.path.join(root_dir, run_name) @@ -366,10 +322,9 @@ def test_ppo_recurrent(standard_args, start_time): f"run_name={run_name}", "env.capture_video=False", ] - with mock.patch.object(sys, "argv", [task.__file__] + args): - for command in task.__all__: - if command == "main": - task.__dict__[command]() + + with mock.patch.object(sys, "argv", args): + run() check_checkpoint( Path(os.path.join("logs", "runs", ckpt_path)), @@ -382,7 +337,6 @@ def test_ppo_recurrent(standard_args, start_time): @pytest.mark.parametrize("env_id", ["discrete_dummy", "multidiscrete_dummy", "continuous_dummy"]) @pytest.mark.parametrize("checkpoint_buffer", [True, False]) def test_dreamer_v1(standard_args, env_id, checkpoint_buffer, start_time): - task = importlib.import_module("sheeprl.algos.dreamer_v1.dreamer_v1") root_dir = os.path.join(f"pytest_{start_time}", "dreamer_v1", os.environ["LT_DEVICES"]) run_name = "checkpoint_buffer" if checkpoint_buffer else "no_checkpoint_buffer" ckpt_path = os.path.join(root_dir, run_name) @@ -407,10 +361,8 @@ def test_dreamer_v1(standard_args, env_id, checkpoint_buffer, start_time): "env.capture_video=False", ] - with mock.patch.object(sys, "argv", [task.__file__] + args): - for command in task.__all__: - if command == "main": - task.__dict__[command]() + with mock.patch.object(sys, "argv", args): + run() keys = { "world_model", @@ -429,14 +381,13 @@ def test_dreamer_v1(standard_args, env_id, checkpoint_buffer, start_time): keys.add("rb") check_checkpoint(Path(os.path.join("logs", "runs", ckpt_path)), keys, checkpoint_buffer) - # shutil.rmtree(f"logs/runs/pytest_{start_time}") + remove_test_dir(os.path.join("logs", "runs", f"pytest_{start_time}")) @pytest.mark.timeout(60) @pytest.mark.parametrize("env_id", ["discrete_dummy", "multidiscrete_dummy", "continuous_dummy"]) @pytest.mark.parametrize("checkpoint_buffer", [True, False]) def test_p2e_dv1(standard_args, env_id, checkpoint_buffer, start_time): - task = importlib.import_module("sheeprl.algos.p2e_dv1.p2e_dv1") root_dir = os.path.join(f"pytest_{start_time}", "p2e_dv1", os.environ["LT_DEVICES"]) run_name = "checkpoint_buffer" if checkpoint_buffer else "no_checkpoint_buffer" 
ckpt_path = os.path.join(root_dir, run_name) @@ -461,10 +412,8 @@ def test_p2e_dv1(standard_args, env_id, checkpoint_buffer, start_time): "env.capture_video=False", ] - with mock.patch.object(sys, "argv", [task.__file__] + args): - for command in task.__all__: - if command == "main": - task.__dict__[command]() + with mock.patch.object(sys, "argv", args): + run() keys = { "world_model", @@ -495,7 +444,6 @@ def test_p2e_dv1(standard_args, env_id, checkpoint_buffer, start_time): @pytest.mark.parametrize("env_id", ["discrete_dummy", "multidiscrete_dummy", "continuous_dummy"]) @pytest.mark.parametrize("checkpoint_buffer", [True, False]) def test_p2e_dv2(standard_args, env_id, checkpoint_buffer, start_time): - task = importlib.import_module("sheeprl.algos.p2e_dv2.p2e_dv2") root_dir = os.path.join(f"pytest_{start_time}", "p2e_dv2", os.environ["LT_DEVICES"]) run_name = "checkpoint_buffer" if checkpoint_buffer else "no_checkpoint_buffer" ckpt_path = os.path.join(root_dir, run_name) @@ -524,10 +472,8 @@ def test_p2e_dv2(standard_args, env_id, checkpoint_buffer, start_time): "env.capture_video=False", ] - with mock.patch.object(sys, "argv", [task.__file__] + args): - for command in task.__all__: - if command == "main": - task.__dict__[command]() + with mock.patch.object(sys, "argv", args): + run() keys = { "world_model", @@ -560,7 +506,6 @@ def test_p2e_dv2(standard_args, env_id, checkpoint_buffer, start_time): @pytest.mark.parametrize("env_id", ["discrete_dummy", "multidiscrete_dummy", "continuous_dummy"]) @pytest.mark.parametrize("checkpoint_buffer", [True, False]) def test_dreamer_v2(standard_args, env_id, checkpoint_buffer, start_time): - task = importlib.import_module("sheeprl.algos.dreamer_v2.dreamer_v2") root_dir = os.path.join(f"pytest_{start_time}", "dreamer_v2", os.environ["LT_DEVICES"]) run_name = "checkpoint_buffer" if checkpoint_buffer else "no_checkpoint_buffer" ckpt_path = os.path.join(root_dir, run_name) @@ -590,10 +535,8 @@ def test_dreamer_v2(standard_args, env_id, checkpoint_buffer, start_time): "env.capture_video=False", ] - with mock.patch.object(sys, "argv", [task.__file__] + args): - for command in task.__all__: - if command == "main": - task.__dict__[command]() + with mock.patch.object(sys, "argv", args): + run() keys = { "world_model", @@ -619,7 +562,6 @@ def test_dreamer_v2(standard_args, env_id, checkpoint_buffer, start_time): @pytest.mark.parametrize("env_id", ["discrete_dummy", "multidiscrete_dummy", "continuous_dummy"]) @pytest.mark.parametrize("checkpoint_buffer", [True, False]) def test_dreamer_v3(standard_args, env_id, checkpoint_buffer, start_time): - task = importlib.import_module("sheeprl.algos.dreamer_v3.dreamer_v3") root_dir = os.path.join("pytest_" + start_time, "dreamer_v3", os.environ["LT_DEVICES"]) run_name = "checkpoint_buffer" if checkpoint_buffer else "no_checkpoint_buffer" ckpt_path = os.path.join(root_dir, run_name) @@ -649,10 +591,8 @@ def test_dreamer_v3(standard_args, env_id, checkpoint_buffer, start_time): "env.capture_video=False", ] - with mock.patch.object(sys, "argv", [task.__file__] + args): - for command in task.__all__: - if command == "main": - task.__dict__[command]() + with mock.patch.object(sys, "argv", args): + run() keys = { "world_model", diff --git a/tests/test_algos/test_cli.py b/tests/test_algos/test_cli.py index 1931ef4c..1ccc188d 100644 --- a/tests/test_algos/test_cli.py +++ b/tests/test_algos/test_cli.py @@ -10,7 +10,7 @@ def test_fsdp_strategy_fail(): with pytest.raises(subprocess.CalledProcessError): subprocess.run( - "lightning 
run model --strategy=fsdp --devices=1 sheeprl.py ppo",
+            sys.executable + " sheeprl.py exp=ppo fabric.strategy=fsdp",
             shell=True,
             check=True,
         )
@@ -18,7 +18,7 @@
 
 def test_run_decoupled_algo():
     subprocess.run(
-        "lightning run model --strategy=ddp --devices=2 sheeprl.py ppo_decoupled "
+        sys.executable + " sheeprl.py exp=ppo_decoupled fabric.strategy=ddp fabric.devices=2 "
-        "exp=ppo dry_run=True algo.rollout_steps=1 cnn_keys.encoder=[rgb] mlp_keys.encoder=[state] "
+        "dry_run=True algo.rollout_steps=1 cnn_keys.encoder=[rgb] mlp_keys.encoder=[state] "
         "env.capture_video=False",
         shell=True,
@@ -29,7 +29,7 @@ def test_run_algo():
     subprocess.run(
         sys.executable
-        + " sheeprl.py ppo exp=ppo dry_run=True algo.rollout_steps=1 cnn_keys.encoder=[rgb] mlp_keys.encoder=[state] "
+        + " sheeprl.py exp=ppo dry_run=True algo.rollout_steps=1 cnn_keys.encoder=[rgb] mlp_keys.encoder=[state] "
         "env.capture_video=False",
         shell=True,
         check=True,
@@ -41,7 +41,7 @@ def test_resume_from_checkpoint():
     run_name = "test_ckpt"
     subprocess.run(
         sys.executable
-        + " sheeprl.py dreamer_v3 exp=dreamer_v3 env=dummy dry_run=True "
+        + " sheeprl.py exp=dreamer_v3 env=dummy dry_run=True "
         + "env.capture_video=False algo.dense_units=8 algo.horizon=8 "
         + "algo.world_model.encoder.cnn_channels_multiplier=2 algo.per_rank_gradient_steps=1 "
         + "algo.world_model.recurrent_model.recurrent_state_size=8 "
@@ -58,7 +58,7 @@ def test_resume_from_checkpoint():
     ckpt_path = os.path.join(ckpt_path, ckpt_file_name)
     subprocess.run(
         sys.executable
-        + f" sheeprl.py dreamer_v3 exp=dreamer_v3 checkpoint.resume_from={ckpt_path} "
+        + f" sheeprl.py exp=dreamer_v3 checkpoint.resume_from={ckpt_path} "
         + "root_dir=pytest_resume_ckpt run_name=test_resume",
         shell=True,
         check=True,
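As a quick orientation for reviewers, the sketch below illustrates how an algorithm plugs into the entrypoint flow introduced by this patch: the `@register_algorithm()`-decorated function now receives the already-launched `Fabric` object and the Hydra config, while `sheeprl/cli.py` builds `Fabric` from `cfg.fabric` and dispatches with `fabric.launch(entrypoint, cfg)`. This is a minimal, hypothetical example (its module placement and body are placeholders, not code from the repository).

```python
# Illustrative sketch only: an algorithm entrypoint under the refactored CLI.
# It is assumed to live in sheeprl/algos/<module>/<algo_name>.py so that the
# registry in sheeprl/utils/registry.py can associate it with `cfg.algo.name`.
from lightning.fabric import Fabric
from omegaconf import DictConfig

from sheeprl.utils.registry import register_algorithm


@register_algorithm()
def main(fabric: Fabric, cfg: DictConfig):
    # No Fabric/Hydra/logger boilerplate here anymore: sheeprl/cli.py creates the
    # Fabric object (and, for decoupled algorithms, the TensorBoardLogger) and
    # then calls `fabric.launch(main, cfg)`.
    fabric.print(f"Running `{cfg.algo.name}` on {fabric.world_size} device(s)")
```

Such an entrypoint would then be started with a command of the form `python sheeprl.py exp=<exp_name> fabric.accelerator=cpu fabric.devices=2`, and the table of registered agents added in `sheeprl/available_agents.py` can be printed by running that module directly.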