diff --git a/examples/architecture_template.py b/examples/architecture_template.py index e4179dd1..bf3fdf37 100644 --- a/examples/architecture_template.py +++ b/examples/architecture_template.py @@ -139,7 +139,7 @@ def main(): if devices is None or devices in ("1", "2"): raise RuntimeError( "Please run the script with the number of devices greater than 2: " - "`lightning run model --devices=3 sheeprl.py ...`" + "`lightning run model --devices=3 examples/architecture_template.py ...`" ) world_collective = TorchCollective() diff --git a/howto/register_new_algorithm.md b/howto/register_new_algorithm.md index 3ee1534c..d8f9e4ef 100644 --- a/howto/register_new_algorithm.md +++ b/howto/register_new_algorithm.md @@ -129,7 +129,6 @@ def main(cfg: DictConfig): { "Rewards/rew_avg": MeanMetric(), "Game/ep_len_avg": MeanMetric(), - "Time/step_per_second": MeanMetric(), "Loss/value_loss": MeanMetric(), "Loss/policy_loss": MeanMetric(), "Loss/entropy_loss": MeanMetric(), @@ -222,7 +221,6 @@ def main(cfg: DictConfig): # Log metrics metrics_dict = aggregator.compute() - fabric.log("Time/step_per_second", int(global_step / (time.perf_counter() - start_time)), global_step) fabric.log_dict(metrics_dict, global_step) aggregator.reset() diff --git a/pyproject.toml b/pyproject.toml index 8d116d81..82ab0312 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,9 +26,10 @@ dependencies = [ "tensorboard>=2.10", "python-dotenv>=1.0.0", "lightning==2.0.*", - "lightning-utilities<0.9", + "lightning-utilities<=0.9", "hydra-core==1.3.0", "torchmetrics==1.1.*", + "rich==13.5.*", "opencv-python==4.8.0.*" ] dynamic = ["version"] diff --git a/sheeprl/algos/dreamer_v1/dreamer_v1.py b/sheeprl/algos/dreamer_v1/dreamer_v1.py index d28f5ad6..1eea1737 100644 --- a/sheeprl/algos/dreamer_v1/dreamer_v1.py +++ b/sheeprl/algos/dreamer_v1/dreamer_v1.py @@ -1,7 +1,6 @@ import copy import os import pathlib -import time import warnings from typing import Dict @@ -18,7 +17,7 @@ from tensordict.tensordict import TensorDictBase from torch.distributions import Bernoulli, Independent, Normal from torch.utils.data import BatchSampler -from torchmetrics import MeanMetric +from torchmetrics import MeanMetric, SumMetric from sheeprl.algos.dreamer_v1.agent import PlayerDV1, WorldModel, build_models from sheeprl.algos.dreamer_v1.loss import actor_loss, critic_loss, reconstruction_loss @@ -29,7 +28,8 @@ from sheeprl.utils.logger import create_tensorboard_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm -from sheeprl.utils.utils import compute_lambda_values, polynomial_decay +from sheeprl.utils.timer import timer +from sheeprl.utils.utils import compute_lambda_values, polynomial_decay, print_config # Decomment the following two lines if you cannot start an experiment with DMC environments # os.environ["PYOPENGL_PLATFORM"] = "" @@ -212,7 +212,7 @@ def train( ) world_optimizer.step() aggregator.update("Grads/world_model", world_model_grads.mean().detach()) - aggregator.update("Loss/reconstruction_loss", rec_loss.detach()) + aggregator.update("Loss/world_model_loss", rec_loss.detach()) aggregator.update("Loss/observation_loss", observation_loss.detach()) aggregator.update("Loss/reward_loss", reward_loss.detach()) aggregator.update("Loss/state_loss", state_loss.detach()) @@ -359,6 +359,8 @@ def train( @register_algorithm() @hydra.main(version_base=None, config_path="../../configs", config_name="config") def main(cfg: DictConfig): + print_config(cfg) + # These arguments cannot be changed cfg.env.screen_size = 64 cfg.env.frame_stack = 1 @@ -367,8 +369,9 @@ def main(cfg: DictConfig): fabric = Fabric(callbacks=[CheckpointCallback()]) if not _is_using_cli(): fabric.launch() - rank = fabric.global_rank device = fabric.device + rank = fabric.global_rank + world_size = fabric.world_size fabric.seed_everything(cfg.seed) torch.backends.cudnn.deterministic = cfg.torch_deterministic @@ -379,7 +382,7 @@ def main(cfg: DictConfig): ckpt_path = pathlib.Path(cfg.checkpoint.resume_from) cfg = OmegaConf.load(ckpt_path.parent.parent.parent / ".hydra" / "config.yaml") cfg.checkpoint.resume_from = str(ckpt_path) - cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size + cfg.per_rank_batch_size = state["batch_size"] // world_size cfg.root_dir = root_dir cfg.run_name = f"resume_from_checkpoint_{run_name}" @@ -479,32 +482,29 @@ def main(cfg: DictConfig): ) # Metrics - with device: - aggregator = MetricAggregator( - { - "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Time/step_per_second": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/reconstruction_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/observation_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/reward_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/state_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/continue_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "State/post_entropy": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "State/prior_entropy": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "State/kl": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Params/exploration_amout": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Grads/world_model": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Grads/actor": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Grads/critic": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - } - ) - aggregator.to(fabric.device) + aggregator = MetricAggregator( + { + "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/world_model_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/observation_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/reward_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/state_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/continue_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "State/post_entropy": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "State/prior_entropy": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "State/kl": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Params/exploration_amout": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Grads/world_model": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Grads/actor": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Grads/critic": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + } + ).to(fabric.device) # Local data - buffer_size = cfg.buffer.size // int(cfg.env.num_envs * fabric.world_size) if not cfg.dry_run else 2 + buffer_size = cfg.buffer.size // int(cfg.env.num_envs * world_size) if not cfg.dry_run else 2 rb = AsyncReplayBuffer( buffer_size, cfg.env.num_envs, @@ -514,28 +514,29 @@ def main(cfg: DictConfig): sequential=True, ) if cfg.checkpoint.resume_from and cfg.buffer.checkpoint: - if isinstance(state["rb"], list) and fabric.world_size == len(state["rb"]): + if isinstance(state["rb"], list) and world_size == len(state["rb"]): rb = state["rb"][fabric.global_rank] elif isinstance(state["rb"], AsyncReplayBuffer): rb = state["rb"] else: - raise RuntimeError(f"Given {len(state['rb'])}, but {fabric.world_size} processes are instantiated") + raise RuntimeError(f"Given {len(state['rb'])}, but {world_size} processes are instantiated") step_data = TensorDict({}, batch_size=[cfg.env.num_envs], device=fabric.device if cfg.buffer.memmap else "cpu") expl_decay_steps = state["expl_decay_steps"] if cfg.checkpoint.resume_from else 0 # Global variables - start_step = state["update"] // fabric.world_size if cfg.checkpoint.resume_from else 1 + train_step = 0 + last_train = 0 + start_step = state["update"] // world_size if cfg.checkpoint.resume_from else 1 policy_step = state["update"] * cfg.env.num_envs if cfg.checkpoint.resume_from else 0 last_log = state["last_log"] if cfg.checkpoint.resume_from else 0 last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 - start_time = time.perf_counter() - policy_steps_per_update = int(cfg.env.num_envs * fabric.world_size) + policy_steps_per_update = int(cfg.env.num_envs * world_size) updates_before_training = cfg.algo.train_every // policy_steps_per_update if not cfg.dry_run else 0 num_updates = int(cfg.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 learning_starts = (cfg.algo.learning_starts // policy_steps_per_update) if not cfg.dry_run else 0 if cfg.checkpoint.resume_from and not cfg.buffer.checkpoint: learning_starts += start_step - max_step_expl_decay = cfg.algo.player.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * fabric.world_size) + max_step_expl_decay = cfg.algo.player.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * world_size) if cfg.checkpoint.resume_from: player.expl_amount = polynomial_decay( expl_decay_steps, @@ -547,14 +548,14 @@ def main(cfg: DictConfig): # Warning for log and checkpoint every if cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( - f"The log every parameter ({cfg.metric.log_every}) is not a multiple of the " + f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the metrics will be logged at the nearest greater multiple of the " "policy_steps_per_update value." ) if cfg.checkpoint.every % policy_steps_per_update != 0: warnings.warn( - f"The checkpoint every parameter ({cfg.checkpoint.every}) is not a multiple of the " + f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the checkpoint will be saved at the nearest greater multiple of the " "policy_steps_per_update value." @@ -576,47 +577,50 @@ def main(cfg: DictConfig): player.init_states() for update in range(start_step, num_updates + 1): - # Sample an action given the observation received by the environment - if update <= learning_starts and cfg.checkpoint.resume_from is None and "minedojo" not in cfg.env.id: - real_actions = actions = np.array(envs.action_space.sample()) - if not is_continuous: - actions = np.concatenate( - [ - F.one_hot(torch.tensor(act), act_dim).numpy() - for act, act_dim in zip(actions.reshape(len(actions_dim), -1), actions_dim) - ], - axis=-1, - ) - else: - with torch.no_grad(): - preprocessed_obs = {} - for k, v in obs.items(): - if k in cfg.cnn_keys.encoder: - preprocessed_obs[k] = v[None, ...].to(device) / 255 - 0.5 + policy_step += cfg.env.num_envs * world_size + + # Measure environment interaction time: this considers both the model forward + # to get the action given the observation and the time taken into the environment + with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): + # Sample an action given the observation received by the environment + if update <= learning_starts and cfg.checkpoint.resume_from is None and "minedojo" not in cfg.env.id: + real_actions = actions = np.array(envs.action_space.sample()) + if not is_continuous: + actions = np.concatenate( + [ + F.one_hot(torch.tensor(act), act_dim).numpy() + for act, act_dim in zip(actions.reshape(len(actions_dim), -1), actions_dim) + ], + axis=-1, + ) + else: + with torch.no_grad(): + preprocessed_obs = {} + for k, v in obs.items(): + if k in cfg.cnn_keys.encoder: + preprocessed_obs[k] = v[None, ...].to(device) / 255 - 0.5 + else: + preprocessed_obs[k] = v[None, ...].to(device) + mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} + if len(mask) == 0: + mask = None + real_actions = actions = player.get_exploration_action(preprocessed_obs, is_continuous, mask) + actions = torch.cat(actions, -1).cpu().numpy() + if is_continuous: + real_actions = torch.cat(real_actions, -1).cpu().numpy() else: - preprocessed_obs[k] = v[None, ...].to(device) - mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} - if len(mask) == 0: - mask = None - real_actions = actions = player.get_exploration_action(preprocessed_obs, is_continuous, mask) - actions = torch.cat(actions, -1).cpu().numpy() - if is_continuous: - real_actions = torch.cat(real_actions, -1).cpu().numpy() - else: - real_actions = np.array([real_act.cpu().argmax(dim=-1).numpy() for real_act in real_actions]) - o, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated) - - policy_step += cfg.env.num_envs * fabric.world_size + real_actions = np.array([real_act.cpu().argmax(dim=-1).numpy() for real_act in real_actions]) + o, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) + dones = np.logical_or(dones, truncated) if "final_info" in infos: - for i, agent_final_info in enumerate(infos["final_info"]): - if agent_final_info is not None and "episode" in agent_final_info: - fabric.print( - f"Rank-0: policy_step={policy_step}, reward_env_{i}={agent_final_info['episode']['r'][0]}" - ) - aggregator.update("Rewards/rew_avg", agent_final_info["episode"]["r"][0]) - aggregator.update("Game/ep_len_avg", agent_final_info["episode"]["l"][0]) + for i, agent_ep_info in enumerate(infos["final_info"]): + if agent_ep_info is not None: + ep_rew = agent_ep_info["episode"]["r"] + ep_len = agent_ep_info["episode"]["l"] + aggregator.update("Rewards/rew_avg", ep_rew) + aggregator.update("Game/ep_len_avg", ep_len) + fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") # Save the real next observation real_next_obs = copy.deepcopy(o) @@ -673,19 +677,22 @@ def main(cfg: DictConfig): n_samples=cfg.algo.per_rank_gradient_steps, ).to(device) distributed_sampler = BatchSampler(range(local_data.shape[0]), batch_size=1, drop_last=False) - for i in distributed_sampler: - train( - fabric, - world_model, - actor, - critic, - world_optimizer, - actor_optimizer, - critic_optimizer, - local_data[i].view(cfg.per_rank_sequence_length, cfg.per_rank_batch_size), - aggregator, - cfg, - ) + # Start training + with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): + for i in distributed_sampler: + train( + fabric, + world_model, + actor, + critic, + world_optimizer, + actor_optimizer, + critic_optimizer, + local_data[i].view(cfg.per_rank_sequence_length, cfg.per_rank_batch_size), + aggregator, + cfg, + ) + train_step += world_size updates_before_training = cfg.algo.train_every // policy_steps_per_update if cfg.algo.player.expl_decay: expl_decay_steps += 1 @@ -696,12 +703,33 @@ def main(cfg: DictConfig): max_decay_steps=max_step_expl_decay, ) aggregator.update("Params/exploration_amout", player.expl_amount) - aggregator.update("Time/step_per_second", int(policy_step / (time.perf_counter() - start_time))) - if policy_step - last_log >= cfg.metric.log_every or cfg.dry_run: - last_log = policy_step - fabric.log_dict(aggregator.compute(), policy_step) + + # Log metrics + if policy_step - last_log >= cfg.metric.log_every or update == num_updates or cfg.dry_run: + # Sync distributed metrics + metrics_dict = aggregator.compute() + fabric.log_dict(metrics_dict, policy_step) aggregator.reset() + # Sync distributed timers + timer_metrics = timer.compute() + fabric.log( + "Time/sps_train", + (train_step - last_train) / timer_metrics["Time/train_time"], + policy_step, + ) + fabric.log( + "Time/sps_env_interaction", + ((policy_step - last_log) / world_size * cfg.env.action_repeat) + / timer_metrics["Time/env_interaction_time"], + policy_step, + ) + timer.reset() + + # Reset counters + last_log = policy_step + last_train = train_step + # Checkpoint Model if ( (cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every) @@ -717,8 +745,8 @@ def main(cfg: DictConfig): "actor_optimizer": actor_optimizer.state_dict(), "critic_optimizer": critic_optimizer.state_dict(), "expl_decay_steps": expl_decay_steps, - "update": update * fabric.world_size, - "batch_size": cfg.per_rank_batch_size * fabric.world_size, + "update": update * world_size, + "batch_size": cfg.per_rank_batch_size * world_size, "last_log": last_log, "last_checkpoint": last_checkpoint, } diff --git a/sheeprl/algos/dreamer_v2/dreamer_v2.py b/sheeprl/algos/dreamer_v2/dreamer_v2.py index 6fcafb2b..0d7217eb 100644 --- a/sheeprl/algos/dreamer_v2/dreamer_v2.py +++ b/sheeprl/algos/dreamer_v2/dreamer_v2.py @@ -5,7 +5,6 @@ import copy import os import pathlib -import time import warnings from typing import Dict, Sequence @@ -24,7 +23,7 @@ from torch.distributions import Bernoulli, Distribution, Independent, Normal, OneHotCategorical from torch.optim import Optimizer from torch.utils.data import BatchSampler -from torchmetrics import MeanMetric +from torchmetrics import MeanMetric, SumMetric from sheeprl.algos.dreamer_v2.agent import PlayerDV2, WorldModel, build_models from sheeprl.algos.dreamer_v2.loss import reconstruction_loss @@ -35,7 +34,8 @@ from sheeprl.utils.logger import create_tensorboard_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm -from sheeprl.utils.utils import polynomial_decay +from sheeprl.utils.timer import timer +from sheeprl.utils.utils import polynomial_decay, print_config # Decomment the following two lines if you cannot start an experiment with DMC environments # os.environ["PYOPENGL_PLATFORM"] = "" @@ -207,7 +207,7 @@ def train( ) world_optimizer.step() aggregator.update("Grads/world_model", world_model_grads.mean().detach()) - aggregator.update("Loss/reconstruction_loss", rec_loss.detach()) + aggregator.update("Loss/world_model_loss", rec_loss.detach()) aggregator.update("Loss/observation_loss", observation_loss.detach()) aggregator.update("Loss/reward_loss", reward_loss.detach()) aggregator.update("Loss/state_loss", state_loss.detach()) @@ -379,6 +379,8 @@ def train( @register_algorithm() @hydra.main(version_base=None, config_path="../../configs", config_name="config") def main(cfg: DictConfig): + print_config(cfg) + # These arguments cannot be changed cfg.env.screen_size = 64 cfg.env.frame_stack = 1 @@ -387,8 +389,9 @@ def main(cfg: DictConfig): fabric = Fabric(callbacks=[CheckpointCallback()]) if not _is_using_cli(): fabric.launch() - rank = fabric.global_rank device = fabric.device + rank = fabric.global_rank + world_size = fabric.world_size fabric.seed_everything(cfg.seed) torch.backends.cudnn.deterministic = cfg.torch_deterministic @@ -399,7 +402,7 @@ def main(cfg: DictConfig): ckpt_path = pathlib.Path(cfg.checkpoint.resume_from) cfg = OmegaConf.load(ckpt_path.parent.parent.parent / ".hydra" / "config.yaml") cfg.checkpoint.resume_from = str(ckpt_path) - cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size + cfg.per_rank_batch_size = state["batch_size"] // world_size cfg.root_dir = root_dir cfg.run_name = f"resume_from_checkpoint_{run_name}" @@ -500,32 +503,29 @@ def main(cfg: DictConfig): ) # Metrics - with device: - aggregator = MetricAggregator( - { - "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Time/step_per_second": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/reconstruction_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/observation_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/reward_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/state_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/continue_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "State/post_entropy": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "State/prior_entropy": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "State/kl": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Params/exploration_amout": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Grads/world_model": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Grads/actor": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Grads/critic": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - } - ) - aggregator.to(fabric.device) + aggregator = MetricAggregator( + { + "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/world_model_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/observation_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/reward_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/state_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/continue_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "State/post_entropy": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "State/prior_entropy": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "State/kl": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Params/exploration_amout": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Grads/world_model": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Grads/actor": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Grads/critic": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + } + ).to(fabric.device) # Local data - buffer_size = cfg.buffer.size // int(cfg.env.num_envs * fabric.world_size) if not cfg.dry_run else 2 + buffer_size = cfg.buffer.size // int(cfg.env.num_envs * world_size) if not cfg.dry_run else 2 buffer_type = cfg.buffer.type.lower() if buffer_type == "sequential": rb = AsyncReplayBuffer( @@ -547,28 +547,29 @@ def main(cfg: DictConfig): else: raise ValueError(f"Unrecognized buffer type: must be one of `sequential` or `episode`, received: {buffer_type}") if cfg.checkpoint.resume_from and cfg.buffer.checkpoint: - if isinstance(state["rb"], list) and fabric.world_size == len(state["rb"]): + if isinstance(state["rb"], list) and world_size == len(state["rb"]): rb = state["rb"][fabric.global_rank] elif isinstance(state["rb"], (AsyncReplayBuffer, EpisodeBuffer)): rb = state["rb"] else: - raise RuntimeError(f"Given {len(state['rb'])}, but {fabric.world_size} processes are instantiated") + raise RuntimeError(f"Given {len(state['rb'])}, but {world_size} processes are instantiated") step_data = TensorDict({}, batch_size=[cfg.env.num_envs], device="cpu") expl_decay_steps = state["expl_decay_steps"] if cfg.checkpoint.resume_from else 0 # Global variables - start_step = state["update"] // fabric.world_size if cfg.checkpoint.resume_from else 1 + train_step = 0 + last_train = 0 + start_step = state["update"] // world_size if cfg.checkpoint.resume_from else 1 policy_step = state["update"] * cfg.env.num_envs if cfg.checkpoint.resume_from else 0 last_log = state["last_log"] if cfg.checkpoint.resume_from else 0 last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 - start_time = time.perf_counter() - policy_steps_per_update = int(cfg.env.num_envs * fabric.world_size) + policy_steps_per_update = int(cfg.env.num_envs * world_size) updates_before_training = cfg.algo.train_every // policy_steps_per_update if not cfg.dry_run else 0 num_updates = cfg.total_steps // policy_steps_per_update if not cfg.dry_run else 1 learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 if cfg.checkpoint.resume_from and not cfg.buffer.checkpoint: learning_starts += start_step - max_step_expl_decay = cfg.algo.player.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * fabric.world_size) + max_step_expl_decay = cfg.algo.player.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * world_size) if cfg.checkpoint.resume_from: player.expl_amount = polynomial_decay( expl_decay_steps, @@ -580,14 +581,14 @@ def main(cfg: DictConfig): # Warning for log and checkpoint every if cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( - f"The log every parameter ({cfg.metric.log_every}) is not a multiple of the " + f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the metrics will be logged at the nearest greater multiple of the " "policy_steps_per_update value." ) if cfg.checkpoint.every % policy_steps_per_update != 0: warnings.warn( - f"The checkpoint every parameter ({cfg.checkpoint.every}) is not a multiple of the " + f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the checkpoint will be saved at the nearest greater multiple of the " "policy_steps_per_update value." @@ -618,55 +619,58 @@ def main(cfg: DictConfig): per_rank_gradient_steps = 0 for update in range(start_step, num_updates + 1): - # Sample an action given the observation received by the environment - if ( - update <= learning_starts - and cfg.checkpoint.resume_from is None - and "minedojo" not in cfg.algo.actor.cls.lower() - ): - real_actions = actions = np.array(envs.action_space.sample()) - if not is_continuous: - actions = np.concatenate( - [ - F.one_hot(torch.tensor(act), act_dim).numpy() - for act, act_dim in zip(actions.reshape(len(actions_dim), -1), actions_dim) - ], - axis=-1, - ) - else: - with torch.no_grad(): - preprocessed_obs = {} - for k, v in obs.items(): - if k in cfg.cnn_keys.encoder: - preprocessed_obs[k] = v[None, ...].to(device) / 255 - 0.5 + policy_step += cfg.env.num_envs * world_size + + # Measure environment interaction time: this considers both the model forward + # to get the action given the observation and the time taken into the environment + with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): + # Sample an action given the observation received by the environment + if ( + update <= learning_starts + and cfg.checkpoint.resume_from is None + and "minedojo" not in cfg.algo.actor.cls.lower() + ): + real_actions = actions = np.array(envs.action_space.sample()) + if not is_continuous: + actions = np.concatenate( + [ + F.one_hot(torch.tensor(act), act_dim).numpy() + for act, act_dim in zip(actions.reshape(len(actions_dim), -1), actions_dim) + ], + axis=-1, + ) + else: + with torch.no_grad(): + preprocessed_obs = {} + for k, v in obs.items(): + if k in cfg.cnn_keys.encoder: + preprocessed_obs[k] = v[None, ...].to(device) / 255 - 0.5 + else: + preprocessed_obs[k] = v[None, ...].to(device) + mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} + if len(mask) == 0: + mask = None + real_actions = actions = player.get_exploration_action(preprocessed_obs, is_continuous, mask) + actions = torch.cat(actions, -1).cpu().numpy() + if is_continuous: + real_actions = torch.cat(real_actions, -1).cpu().numpy() else: - preprocessed_obs[k] = v[None, ...].to(device) - mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} - if len(mask) == 0: - mask = None - real_actions = actions = player.get_exploration_action(preprocessed_obs, is_continuous, mask) - actions = torch.cat(actions, -1).cpu().numpy() - if is_continuous: - real_actions = torch.cat(real_actions, -1).cpu().numpy() - else: - real_actions = np.array([real_act.cpu().argmax(dim=-1).numpy() for real_act in real_actions]) - - step_data["is_first"] = copy.deepcopy(step_data["dones"]) - o, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated) - if cfg.dry_run and buffer_type == "episode": - dones = np.ones_like(dones) - - policy_step += cfg.env.num_envs * fabric.world_size + real_actions = np.array([real_act.cpu().argmax(dim=-1).numpy() for real_act in real_actions]) + + step_data["is_first"] = copy.deepcopy(step_data["dones"]) + o, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) + dones = np.logical_or(dones, truncated) + if cfg.dry_run and buffer_type == "episode": + dones = np.ones_like(dones) if "final_info" in infos: - for i, agent_final_info in enumerate(infos["final_info"]): - if agent_final_info is not None and "episode" in agent_final_info: - fabric.print( - f"Rank-0: policy_step={policy_step}, reward_env_{i}={agent_final_info['episode']['r'][0]}" - ) - aggregator.update("Rewards/rew_avg", agent_final_info["episode"]["r"][0]) - aggregator.update("Game/ep_len_avg", agent_final_info["episode"]["l"][0]) + for i, agent_ep_info in enumerate(infos["final_info"]): + if agent_ep_info is not None: + ep_rew = agent_ep_info["episode"]["r"] + ep_len = agent_ep_info["episode"]["l"] + aggregator.update("Rewards/rew_avg", ep_rew) + aggregator.update("Game/ep_len_avg", ep_len) + fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") # Save the real next observation real_next_obs = copy.deepcopy(o) @@ -746,25 +750,27 @@ def main(cfg: DictConfig): prioritize_ends=cfg.buffer.prioritize_ends, ).to(device) distributed_sampler = BatchSampler(range(local_data.shape[0]), batch_size=1, drop_last=False) - for i in distributed_sampler: - if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: - for cp, tcp in zip(critic.module.parameters(), target_critic.parameters()): - tcp.data.copy_(cp.data) - train( - fabric, - world_model, - actor, - critic, - target_critic, - world_optimizer, - actor_optimizer, - critic_optimizer, - local_data[i].view(cfg.per_rank_sequence_length, cfg.per_rank_batch_size), - aggregator, - cfg, - actions_dim, - ) - per_rank_gradient_steps += 1 + with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): + for i in distributed_sampler: + if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: + for cp, tcp in zip(critic.module.parameters(), target_critic.parameters()): + tcp.data.copy_(cp.data) + train( + fabric, + world_model, + actor, + critic, + target_critic, + world_optimizer, + actor_optimizer, + critic_optimizer, + local_data[i].view(cfg.per_rank_sequence_length, cfg.per_rank_batch_size), + aggregator, + cfg, + actions_dim, + ) + per_rank_gradient_steps += 1 + train_step += world_size updates_before_training = cfg.algo.train_every // policy_steps_per_update if cfg.algo.player.expl_decay: expl_decay_steps += 1 @@ -775,12 +781,33 @@ def main(cfg: DictConfig): max_decay_steps=max_step_expl_decay, ) aggregator.update("Params/exploration_amout", player.expl_amount) - aggregator.update("Time/step_per_second", int(policy_step / (time.perf_counter() - start_time))) - if policy_step - last_log >= cfg.metric.log_every or cfg.dry_run: - last_log = policy_step - fabric.log_dict(aggregator.compute(), policy_step) + + # Log metrics + if policy_step - last_log >= cfg.metric.log_every or update == num_updates or cfg.dry_run: + # Sync distributed metrics + metrics_dict = aggregator.compute() + fabric.log_dict(metrics_dict, policy_step) aggregator.reset() + # Sync distributed timers + timer_metrics = timer.compute() + fabric.log( + "Time/sps_train", + (train_step - last_train) / timer_metrics["Time/train_time"], + policy_step, + ) + fabric.log( + "Time/sps_env_interaction", + ((policy_step - last_log) / world_size * cfg.env.action_repeat) + / timer_metrics["Time/env_interaction_time"], + policy_step, + ) + timer.reset() + + # Reset counters + last_log = policy_step + last_train = train_step + # Checkpoint Model if ( (cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every) @@ -797,8 +824,8 @@ def main(cfg: DictConfig): "actor_optimizer": actor_optimizer.state_dict(), "critic_optimizer": critic_optimizer.state_dict(), "expl_decay_steps": expl_decay_steps, - "update": update * fabric.world_size, - "batch_size": cfg.per_rank_batch_size * fabric.world_size, + "update": update * world_size, + "batch_size": cfg.per_rank_batch_size * world_size, "last_log": last_log, "last_checkpoint": last_checkpoint, } diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py index 07f76a13..e1abe5c4 100644 --- a/sheeprl/algos/dreamer_v3/dreamer_v3.py +++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py @@ -5,7 +5,6 @@ import copy import os import pathlib -import time import warnings from functools import partial from typing import Dict, Sequence @@ -25,7 +24,7 @@ from torch.distributions import Bernoulli, Distribution, Independent, OneHotCategorical from torch.optim import Optimizer from torch.utils.data import BatchSampler -from torchmetrics import MeanMetric +from torchmetrics import MeanMetric, SumMetric from sheeprl.algos.dreamer_v3.agent import PlayerDV3, WorldModel, build_models from sheeprl.algos.dreamer_v3.loss import reconstruction_loss @@ -38,7 +37,8 @@ from sheeprl.utils.logger import create_tensorboard_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm -from sheeprl.utils.utils import polynomial_decay +from sheeprl.utils.timer import timer +from sheeprl.utils.utils import polynomial_decay, print_config # Decomment the following two lines if you cannot start an experiment with DMC environments # os.environ["PYOPENGL_PLATFORM"] = "" @@ -177,7 +177,7 @@ def train( ) world_optimizer.step() aggregator.update("Grads/world_model", world_model_grads.mean().detach()) - aggregator.update("Loss/reconstruction_loss", rec_loss.detach()) + aggregator.update("Loss/world_model_loss", rec_loss.detach()) aggregator.update("Loss/observation_loss", observation_loss.detach()) aggregator.update("Loss/reward_loss", reward_loss.detach()) aggregator.update("Loss/state_loss", state_loss.detach()) @@ -332,6 +332,8 @@ def train( @register_algorithm() @hydra.main(version_base=None, config_path="../../configs", config_name="config") def main(cfg: DictConfig): + print_config(cfg) + # These arguments cannot be changed cfg.env.frame_stack = -1 if 2 ** int(np.log2(cfg.env.screen_size)) != cfg.env.screen_size: @@ -341,8 +343,9 @@ def main(cfg: DictConfig): fabric = Fabric(callbacks=[CheckpointCallback()]) if not _is_using_cli(): fabric.launch() - rank = fabric.global_rank device = fabric.device + rank = fabric.global_rank + world_size = fabric.world_size fabric.seed_everything(cfg.seed) torch.backends.cudnn.deterministic = cfg.torch_deterministic @@ -463,29 +466,26 @@ def main(cfg: DictConfig): ) # Metrics - with device: - aggregator = MetricAggregator( - { - "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Time/step_per_second": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/reconstruction_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/observation_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/reward_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/state_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/continue_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "State/kl": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "State/post_entropy": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "State/prior_entropy": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Params/exploration_amout": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Grads/world_model": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Grads/actor": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Grads/critic": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - } - ) - aggregator.to(fabric.device) + aggregator = MetricAggregator( + { + "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/world_model_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/observation_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/reward_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/state_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/continue_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "State/kl": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "State/post_entropy": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "State/prior_entropy": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Params/exploration_amout": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Grads/world_model": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Grads/actor": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Grads/critic": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + } + ).to(fabric.device) # Local data buffer_size = cfg.buffer.size // int(cfg.env.num_envs * fabric.world_size) if not cfg.dry_run else 2 @@ -508,11 +508,12 @@ def main(cfg: DictConfig): expl_decay_steps = state["expl_decay_steps"] if cfg.checkpoint.resume_from else 0 # Global variables + train_step = 0 + last_train = 0 start_step = state["update"] // fabric.world_size if cfg.checkpoint.resume_from else 1 policy_step = state["update"] * cfg.env.num_envs if cfg.checkpoint.resume_from else 0 last_log = state["last_log"] if cfg.checkpoint.resume_from else 0 last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 - start_time = time.perf_counter() policy_steps_per_update = int(cfg.env.num_envs * fabric.world_size) updates_before_training = cfg.algo.train_every // policy_steps_per_update num_updates = int(cfg.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 @@ -531,14 +532,14 @@ def main(cfg: DictConfig): # Warning for log and checkpoint every if cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( - f"The log every parameter ({cfg.metric.log_every}) is not a multiple of the " + f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the metrics will be logged at the nearest greater multiple of the " "policy_steps_per_update value." ) if cfg.checkpoint.every % policy_steps_per_update != 0: warnings.warn( - f"The checkpoint every parameter ({cfg.checkpoint.every}) is not a multiple of the " + f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the checkpoint will be saved at the nearest greater multiple of the " "policy_steps_per_update value." @@ -561,46 +562,49 @@ def main(cfg: DictConfig): per_rank_gradient_steps = 0 for update in range(start_step, num_updates + 1): - # Sample an action given the observation received by the environment - if ( - update <= learning_starts - and cfg.checkpoint.resume_from is None - and "minedojo" not in cfg.algo.actor.cls.lower() - ): - real_actions = actions = np.array(envs.action_space.sample()) - if not is_continuous: - actions = np.concatenate( - [ - F.one_hot(torch.tensor(act), act_dim).numpy() - for act, act_dim in zip(actions.reshape(len(actions_dim), -1), actions_dim) - ], - axis=-1, - ) - else: - with torch.no_grad(): - preprocessed_obs = {} - for k, v in obs.items(): - if k in cfg.cnn_keys.encoder: - preprocessed_obs[k] = v[None, ...].to(device) / 255.0 + policy_step += cfg.env.num_envs * world_size + + # Measure environment interaction time: this considers both the model forward + # to get the action given the observation and the time taken into the environment + with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): + # Sample an action given the observation received by the environment + if ( + update <= learning_starts + and cfg.checkpoint.resume_from is None + and "minedojo" not in cfg.algo.actor.cls.lower() + ): + real_actions = actions = np.array(envs.action_space.sample()) + if not is_continuous: + actions = np.concatenate( + [ + F.one_hot(torch.tensor(act), act_dim).numpy() + for act, act_dim in zip(actions.reshape(len(actions_dim), -1), actions_dim) + ], + axis=-1, + ) + else: + with torch.no_grad(): + preprocessed_obs = {} + for k, v in obs.items(): + if k in cfg.cnn_keys.encoder: + preprocessed_obs[k] = v[None, ...].to(device) / 255.0 + else: + preprocessed_obs[k] = v[None, ...].to(device) + mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} + if len(mask) == 0: + mask = None + real_actions = actions = player.get_exploration_action(preprocessed_obs, is_continuous, mask) + actions = torch.cat(actions, -1).cpu().numpy() + if is_continuous: + real_actions = torch.cat(real_actions, dim=-1).cpu().numpy() else: - preprocessed_obs[k] = v[None, ...].to(device) - mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} - if len(mask) == 0: - mask = None - real_actions = actions = player.get_exploration_action(preprocessed_obs, is_continuous, mask) - actions = torch.cat(actions, -1).cpu().numpy() - if is_continuous: - real_actions = torch.cat(real_actions, dim=-1).cpu().numpy() - else: - real_actions = np.array([real_act.cpu().argmax(dim=-1).numpy() for real_act in real_actions]) + real_actions = np.array([real_act.cpu().argmax(dim=-1).numpy() for real_act in real_actions]) - step_data["actions"] = torch.from_numpy(actions).view(cfg.env.num_envs, -1).float() - rb.add(step_data[None, ...]) + step_data["actions"] = torch.from_numpy(actions).view(cfg.env.num_envs, -1).float() + rb.add(step_data[None, ...]) - o, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated) - - policy_step += cfg.env.num_envs * fabric.world_size + o, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) + dones = np.logical_or(dones, truncated) step_data["is_first"] = torch.zeros_like(step_data["dones"]) if "restart_on_exception" in infos: @@ -614,13 +618,13 @@ def main(cfg: DictConfig): step_data["is_first"][i] = torch.ones_like(step_data["is_first"][i]) if "final_info" in infos: - for i, agent_final_info in enumerate(infos["final_info"]): - if agent_final_info is not None and "episode" in agent_final_info: - fabric.print( - f"Rank-0: policy_step={policy_step}, reward_env_{i}={agent_final_info['episode']['r'][0]}" - ) - aggregator.update("Rewards/rew_avg", agent_final_info["episode"]["r"][0]) - aggregator.update("Game/ep_len_avg", agent_final_info["episode"]["l"][0]) + for i, agent_ep_info in enumerate(infos["final_info"]): + if agent_ep_info is not None: + ep_rew = agent_ep_info["episode"]["r"] + ep_len = agent_ep_info["episode"]["l"] + aggregator.update("Rewards/rew_avg", ep_rew) + aggregator.update("Game/ep_len_avg", ep_len) + fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") # Save the real next observation real_next_obs = copy.deepcopy(o) @@ -681,28 +685,30 @@ def main(cfg: DictConfig): else cfg.algo.per_rank_gradient_steps, ).to(device) distributed_sampler = BatchSampler(range(local_data.shape[0]), batch_size=1, drop_last=False) - for i in distributed_sampler: - if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: - tau = 1 if per_rank_gradient_steps == 0 else cfg.algo.critic.tau - for cp, tcp in zip(critic.module.parameters(), target_critic.parameters()): - tcp.data.copy_(tau * cp.data + (1 - tau) * tcp.data) - train( - fabric, - world_model, - actor, - critic, - target_critic, - world_optimizer, - actor_optimizer, - critic_optimizer, - local_data[i].view(cfg.per_rank_sequence_length, cfg.per_rank_batch_size), - aggregator, - cfg, - is_continuous, - actions_dim, - moments, - ) - per_rank_gradient_steps += 1 + with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): + for i in distributed_sampler: + if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: + tau = 1 if per_rank_gradient_steps == 0 else cfg.algo.critic.tau + for cp, tcp in zip(critic.module.parameters(), target_critic.parameters()): + tcp.data.copy_(tau * cp.data + (1 - tau) * tcp.data) + train( + fabric, + world_model, + actor, + critic, + target_critic, + world_optimizer, + actor_optimizer, + critic_optimizer, + local_data[i].view(cfg.per_rank_sequence_length, cfg.per_rank_batch_size), + aggregator, + cfg, + is_continuous, + actions_dim, + moments, + ) + per_rank_gradient_steps += 1 + train_step += world_size updates_before_training = cfg.algo.train_every // policy_steps_per_update if cfg.algo.player.expl_decay: expl_decay_steps += 1 @@ -713,12 +719,33 @@ def main(cfg: DictConfig): max_decay_steps=max_step_expl_decay, ) aggregator.update("Params/exploration_amout", player.expl_amount) - aggregator.update("Time/step_per_second", int(policy_step / (time.perf_counter() - start_time))) - if policy_step - last_log >= cfg.metric.log_every or cfg.dry_run: - last_log = policy_step - fabric.log_dict(aggregator.compute(), policy_step) + + # Log metrics + if policy_step - last_log >= cfg.metric.log_every or update == num_updates or cfg.dry_run: + # Sync distributed metrics + metrics_dict = aggregator.compute() + fabric.log_dict(metrics_dict, policy_step) aggregator.reset() + # Sync distributed timers + timer_metrics = timer.compute() + fabric.log( + "Time/sps_train", + (train_step - last_train) / timer_metrics["Time/train_time"], + policy_step, + ) + fabric.log( + "Time/sps_env_interaction", + ((policy_step - last_log) / world_size * cfg.env.action_repeat) + / timer_metrics["Time/env_interaction_time"], + policy_step, + ) + timer.reset() + + # Reset counters + last_log = policy_step + last_train = train_step + # Checkpoint Model if ( (cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every) diff --git a/sheeprl/algos/droq/droq.py b/sheeprl/algos/droq/droq.py index 2eb91434..fdd263cd 100644 --- a/sheeprl/algos/droq/droq.py +++ b/sheeprl/algos/droq/droq.py @@ -1,5 +1,4 @@ import os -import time import warnings from math import prod @@ -15,7 +14,7 @@ from torch.optim import Optimizer from torch.utils.data.distributed import DistributedSampler from torch.utils.data.sampler import BatchSampler -from torchmetrics import MeanMetric +from torchmetrics import MeanMetric, SumMetric from sheeprl.algos.droq.agent import DROQAgent, DROQCritic from sheeprl.algos.sac.agent import SACActor @@ -27,6 +26,8 @@ from sheeprl.utils.logger import create_tensorboard_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm +from sheeprl.utils.timer import timer +from sheeprl.utils.utils import print_config def train( @@ -44,86 +45,97 @@ def train( sample = rb.sample( cfg.algo.per_rank_gradient_steps * cfg.per_rank_batch_size, sample_next_obs=cfg.buffer.sample_next_obs ) - gathered_data = fabric.all_gather(sample.to_dict()) - gathered_data = make_tensordict(gathered_data).view(-1) + critic_data = fabric.all_gather(sample.to_dict()) + critic_data = make_tensordict(critic_data).view(-1) if fabric.world_size > 1: dist_sampler: DistributedSampler = DistributedSampler( - range(len(gathered_data)), + range(len(critic_data)), num_replicas=fabric.world_size, rank=fabric.global_rank, shuffle=True, seed=cfg.seed, drop_last=False, ) - sampler: BatchSampler = BatchSampler(sampler=dist_sampler, batch_size=cfg.per_rank_batch_size, drop_last=False) + critic_sampler: BatchSampler = BatchSampler( + sampler=dist_sampler, batch_size=cfg.per_rank_batch_size, drop_last=False + ) else: - sampler = BatchSampler(sampler=range(len(gathered_data)), batch_size=cfg.per_rank_batch_size, drop_last=False) - - # Update the soft-critic - for batch_idxes in sampler: - data = gathered_data[batch_idxes] - next_target_qf_value = agent.get_next_target_q_values( - data["next_observations"], - data["rewards"], - data["dones"], - cfg.algo.gamma, + critic_sampler = BatchSampler( + sampler=range(len(critic_data)), batch_size=cfg.per_rank_batch_size, drop_last=False ) - for qf_value_idx in range(agent.num_critics): - # Line 8 - Algorithm 2 - qf_loss = F.mse_loss( - agent.get_ith_q_value(data["observations"], data["actions"], qf_value_idx), next_target_qf_value - ) - qf_optimizer.zero_grad(set_to_none=True) - fabric.backward(qf_loss) - qf_optimizer.step() - aggregator.update("Loss/value_loss", qf_loss) - - # Update the target networks with EMA - agent.qfs_target_ema(critic_idx=qf_value_idx) # Sample a different minibatch in a distributed way to update actor and alpha parameter sample = rb.sample(cfg.per_rank_batch_size) - data = fabric.all_gather(sample.to_dict()) - data = make_tensordict(data).view(-1) + actor_data = fabric.all_gather(sample.to_dict()) + actor_data = make_tensordict(actor_data).view(-1) if fabric.world_size > 1: - sampler: DistributedSampler = DistributedSampler( - range(len(data)), + actor_sampler: DistributedSampler = DistributedSampler( + range(len(actor_data)), num_replicas=fabric.world_size, rank=fabric.global_rank, shuffle=True, seed=cfg.seed, drop_last=False, ) - data = data[next(iter(sampler))] - - # Update the actor - actions, logprobs = agent.get_actions_and_log_probs(data["observations"]) - qf_values = agent.get_q_values(data["observations"], actions) - min_qf_values = torch.mean(qf_values, dim=-1, keepdim=True) - actor_loss = policy_loss(agent.alpha, logprobs, min_qf_values) - actor_optimizer.zero_grad(set_to_none=True) - fabric.backward(actor_loss) - actor_optimizer.step() - aggregator.update("Loss/policy_loss", actor_loss) - - # Update the entropy value - alpha_loss = entropy_loss(agent.log_alpha, logprobs.detach(), agent.target_entropy) - alpha_optimizer.zero_grad(set_to_none=True) - fabric.backward(alpha_loss) - agent.log_alpha.grad = fabric.all_reduce(agent.log_alpha.grad) - alpha_optimizer.step() - aggregator.update("Loss/alpha_loss", alpha_loss) + actor_data = actor_data[next(iter(actor_sampler))] + + with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): + # Update the soft-critic + for batch_idxes in critic_sampler: + critic_batch_data = critic_data[batch_idxes] + next_target_qf_value = agent.get_next_target_q_values( + critic_batch_data["next_observations"], + critic_batch_data["rewards"], + critic_batch_data["dones"], + cfg.algo.gamma, + ) + for qf_value_idx in range(agent.num_critics): + # Line 8 - Algorithm 2 + qf_loss = F.mse_loss( + agent.get_ith_q_value( + critic_batch_data["observations"], critic_batch_data["actions"], qf_value_idx + ), + next_target_qf_value, + ) + qf_optimizer.zero_grad(set_to_none=True) + fabric.backward(qf_loss) + qf_optimizer.step() + aggregator.update("Loss/value_loss", qf_loss) + + # Update the target networks with EMA + agent.qfs_target_ema(critic_idx=qf_value_idx) + + # Update the actor + actions, logprobs = agent.get_actions_and_log_probs(actor_data["observations"]) + qf_values = agent.get_q_values(actor_data["observations"], actions) + min_qf_values = torch.mean(qf_values, dim=-1, keepdim=True) + actor_loss = policy_loss(agent.alpha, logprobs, min_qf_values) + actor_optimizer.zero_grad(set_to_none=True) + fabric.backward(actor_loss) + actor_optimizer.step() + aggregator.update("Loss/policy_loss", actor_loss) + + # Update the entropy value + alpha_loss = entropy_loss(agent.log_alpha, logprobs.detach(), agent.target_entropy) + alpha_optimizer.zero_grad(set_to_none=True) + fabric.backward(alpha_loss) + agent.log_alpha.grad = fabric.all_reduce(agent.log_alpha.grad) + alpha_optimizer.step() + aggregator.update("Loss/alpha_loss", alpha_loss) @register_algorithm() @hydra.main(version_base=None, config_path="../../configs", config_name="config") def main(cfg: DictConfig): + print_config(cfg) + # Initialize Fabric fabric = Fabric(callbacks=[CheckpointCallback()]) if not _is_using_cli(): fabric.launch() - rank = fabric.global_rank device = fabric.device + rank = fabric.global_rank + world_size = fabric.world_size fabric.seed_everything(cfg.seed) torch.backends.cudnn.deterministic = cfg.torch_deterministic @@ -196,17 +208,15 @@ def main(cfg: DictConfig): ) # Metrics - with device: - aggregator = MetricAggregator( - { - "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Time/step_per_second": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/alpha_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - } - ) + aggregator = MetricAggregator( + { + "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/alpha_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + } + ).to(device) # Local data buffer_size = cfg.buffer.size // int(cfg.env.num_envs * fabric.world_size) if not cfg.dry_run else 1 @@ -220,10 +230,11 @@ def main(cfg: DictConfig): step_data = TensorDict({}, batch_size=[cfg.env.num_envs], device=device) # Global variables - policy_step = 0 last_log = 0 + last_train = 0 + train_step = 0 + policy_step = 0 last_checkpoint = 0 - start_time = time.perf_counter() policy_steps_per_update = int(cfg.env.num_envs * fabric.world_size) num_updates = int(cfg.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 @@ -231,14 +242,14 @@ def main(cfg: DictConfig): # Warning for log and checkpoint every if cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( - f"The log every parameter ({cfg.metric.log_every}) is not a multiple of the " + f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the metrics will be logged at the nearest greater multiple of the " "policy_steps_per_update value." ) if cfg.checkpoint.every % policy_steps_per_update != 0: warnings.warn( - f"The checkpoint every parameter ({cfg.checkpoint.every}) is not a multiple of the " + f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the checkpoint will be saved at the nearest greater multiple of the " "policy_steps_per_update value." @@ -249,23 +260,25 @@ def main(cfg: DictConfig): obs = torch.tensor(envs.reset(seed=cfg.seed)[0], dtype=torch.float32) # [N_envs, N_obs] for update in range(1, num_updates + 1): - # Sample an action given the observation received by the environment - with torch.no_grad(): - actions, _ = actor.module(obs) - actions = actions.cpu().numpy() - next_obs, rewards, dones, truncated, infos = envs.step(actions) - dones = np.logical_or(dones, truncated) - policy_step += cfg.env.num_envs * fabric.world_size + # Measure environment interaction time: this considers both the model forward + # to get the action given the observation and the time taken into the environment + with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): + with torch.no_grad(): + actions, _ = actor.module(obs) + actions = actions.cpu().numpy() + next_obs, rewards, dones, truncated, infos = envs.step(actions) + dones = np.logical_or(dones, truncated) + if "final_info" in infos: - for i, agent_final_info in enumerate(infos["final_info"]): - if agent_final_info is not None and "episode" in agent_final_info: - fabric.print( - f"Rank-0: policy_step={policy_step}, reward_env_{i}={agent_final_info['episode']['r'][0]}" - ) - aggregator.update("Rewards/rew_avg", agent_final_info["episode"]["r"][0]) - aggregator.update("Game/ep_len_avg", agent_final_info["episode"]["l"][0]) + for i, agent_ep_info in enumerate(infos["final_info"]): + if agent_ep_info is not None: + ep_rew = agent_ep_info["episode"]["r"] + ep_len = agent_ep_info["episode"]["l"] + aggregator.update("Rewards/rew_avg", ep_rew) + aggregator.update("Game/ep_len_avg", ep_len) + fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") # Save the real next observation real_next_obs = next_obs.copy() @@ -295,12 +308,34 @@ def main(cfg: DictConfig): # Train the agent if update > learning_starts: train(fabric, agent, actor_optimizer, qf_optimizer, alpha_optimizer, rb, aggregator, cfg) - aggregator.update("Time/step_per_second", int(policy_step / (time.perf_counter() - start_time))) - if policy_step - last_log >= cfg.metric.log_every or cfg.dry_run: - last_log = policy_step - fabric.log_dict(aggregator.compute(), policy_step) + train_step += world_size + + # Log metrics + if policy_step - last_log >= cfg.metric.log_every or update == num_updates or cfg.dry_run: + # Sync distributed metrics + metrics_dict = aggregator.compute() + fabric.log_dict(metrics_dict, policy_step) aggregator.reset() + # Sync distributed timers + timer_metrics = timer.compute() + fabric.log( + "Time/sps_train", + (train_step - last_train) / timer_metrics["Time/train_time"], + policy_step, + ) + fabric.log( + "Time/sps_env_interaction", + ((policy_step - last_log) / world_size * cfg.env.action_repeat) + / timer_metrics["Time/env_interaction_time"], + policy_step, + ) + timer.reset() + + # Reset counters + last_log = policy_step + last_train = train_step + # Checkpoint model if ( (cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every) diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1.py b/sheeprl/algos/p2e_dv1/p2e_dv1.py index e5a4c262..f987d558 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1.py @@ -1,7 +1,6 @@ import copy import os import pathlib -import time import warnings from typing import Dict @@ -20,7 +19,7 @@ from torch import nn from torch.distributions import Bernoulli, Independent, Normal from torch.utils.data import BatchSampler -from torchmetrics import MeanMetric +from torchmetrics import MeanMetric, SumMetric from sheeprl.algos.dreamer_v1.agent import PlayerDV1, WorldModel from sheeprl.algos.dreamer_v1.loss import actor_loss, critic_loss, reconstruction_loss @@ -33,7 +32,8 @@ from sheeprl.utils.logger import create_tensorboard_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm -from sheeprl.utils.utils import compute_lambda_values, init_weights, polynomial_decay +from sheeprl.utils.timer import timer +from sheeprl.utils.utils import compute_lambda_values, init_weights, polynomial_decay, print_config # Decomment the following line if you are using MineDojo on an headless machine # os.environ["MINEDOJO_HEADLESS"] = "1" @@ -170,7 +170,7 @@ def train( ) aggregator.update("Grads/world_model", world_grad.detach()) world_optimizer.step() - aggregator.update("Loss/reconstruction_loss", rec_loss.detach()) + aggregator.update("Loss/world_model_loss", rec_loss.detach()) aggregator.update("Loss/observation_loss", observation_loss.detach()) aggregator.update("Loss/reward_loss", reward_loss.detach()) aggregator.update("Loss/state_loss", state_loss.detach()) @@ -364,6 +364,8 @@ def train( @register_algorithm() @hydra.main(version_base=None, config_path="../../configs", config_name="config") def main(cfg: DictConfig): + print_config(cfg) + # These arguments cannot be changed cfg.env.screen_size = 64 cfg.env.frame_stack = 1 @@ -372,8 +374,9 @@ def main(cfg: DictConfig): fabric = Fabric(callbacks=[CheckpointCallback()]) if not _is_using_cli(): fabric.launch() - rank = fabric.global_rank device = fabric.device + rank = fabric.global_rank + world_size = fabric.world_size fabric.seed_everything(cfg.seed) torch.backends.cudnn.deterministic = cfg.torch_deterministic @@ -384,7 +387,7 @@ def main(cfg: DictConfig): ckpt_path = pathlib.Path(cfg.checkpoint.resume_from) cfg = OmegaConf.load(ckpt_path.parent.parent.parent / ".hydra" / "config.yaml") cfg.checkpoint.resume_from = str(ckpt_path) - cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size + cfg.per_rank_batch_size = state["batch_size"] // world_size cfg.root_dir = root_dir cfg.run_name = f"resume_from_checkpoint_{run_name}" @@ -529,41 +532,38 @@ def main(cfg: DictConfig): ) # Metrics - with device: - aggregator = MetricAggregator( - { - "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Time/step_per_second": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/reconstruction_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/value_loss_task": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/policy_loss_task": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/value_loss_exploration": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/policy_loss_exploration": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/observation_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/reward_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/state_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/continue_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/ensemble_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "State/kl": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "State/p_entropy": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "State/q_entropy": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Params/exploration_amout": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Rewards/intrinsic": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Values_exploration/predicted_values": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Values_exploration/lambda_values": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Grads/world_model": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Grads/actor_task": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Grads/critic_task": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Grads/actor_exploration": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Grads/critic_exploration": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Grads/ensemble": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - } - ) - aggregator.to(device) + aggregator = MetricAggregator( + { + "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/world_model_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/value_loss_task": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/policy_loss_task": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/value_loss_exploration": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/policy_loss_exploration": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/observation_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/reward_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/state_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/continue_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/ensemble_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "State/kl": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "State/p_entropy": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "State/q_entropy": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Params/exploration_amout": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Rewards/intrinsic": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Values_exploration/predicted_values": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Values_exploration/lambda_values": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Grads/world_model": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Grads/actor_task": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Grads/critic_task": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Grads/actor_exploration": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Grads/critic_exploration": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Grads/ensemble": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + } + ).to(device) # Local data - buffer_size = cfg.buffer.size // int(cfg.env.num_envs * fabric.world_size) if not cfg.dry_run else 4 + buffer_size = cfg.buffer.size // int(cfg.env.num_envs * world_size) if not cfg.dry_run else 4 rb = AsyncReplayBuffer( buffer_size, cfg.env.num_envs, @@ -573,22 +573,23 @@ def main(cfg: DictConfig): sequential=True, ) if cfg.checkpoint.resume_from and cfg.buffer.checkpoint: - if isinstance(state["rb"], list) and fabric.world_size == len(state["rb"]): + if isinstance(state["rb"], list) and world_size == len(state["rb"]): rb = state["rb"][fabric.global_rank] elif isinstance(state["rb"], AsyncReplayBuffer): rb = state["rb"] else: - raise RuntimeError(f"Given {len(state['rb'])}, but {fabric.world_size} processes are instantiated") + raise RuntimeError(f"Given {len(state['rb'])}, but {world_size} processes are instantiated") step_data = TensorDict({}, batch_size=[cfg.env.num_envs], device="cpu") expl_decay_steps = state["expl_decay_steps"] if cfg.checkpoint.resume_from else 0 # Global variables - start_step = state["update"] // fabric.world_size if cfg.checkpoint.resume_from else 1 + train_step = 0 + last_train = 0 + start_step = state["update"] // world_size if cfg.checkpoint.resume_from else 1 policy_step = state["update"] * cfg.env.num_envs if cfg.checkpoint.resume_from else 0 last_log = state["last_log"] if cfg.checkpoint.resume_from else 0 last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 - start_time = time.perf_counter() - policy_steps_per_update = int(cfg.env.num_envs * fabric.world_size) + policy_steps_per_update = int(cfg.env.num_envs * world_size) updates_before_training = cfg.algo.train_every // policy_steps_per_update if not cfg.dry_run else 0 num_updates = int(cfg.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 learning_starts = (cfg.algo.learning_starts // policy_steps_per_update) if not cfg.dry_run else 0 @@ -596,7 +597,7 @@ def main(cfg: DictConfig): exploration_updates = min(num_updates, exploration_updates) if cfg.checkpoint.resume_from and not cfg.buffer.checkpoint: learning_starts += start_step - max_step_expl_decay = cfg.algo.player.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * fabric.world_size) + max_step_expl_decay = cfg.algo.player.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * world_size) if cfg.checkpoint.resume_from: player.expl_amount = polynomial_decay( expl_decay_steps, @@ -608,14 +609,14 @@ def main(cfg: DictConfig): # Warning for log and checkpoint every if cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( - f"The log every parameter ({cfg.metric.log_every}) is not a multiple of the " + f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the metrics will be logged at the nearest greater multiple of the " "policy_steps_per_update value." ) if cfg.checkpoint.every % policy_steps_per_update != 0: warnings.warn( - f"The checkpoint every parameter ({cfg.checkpoint.every}) is not a multiple of the " + f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the checkpoint will be saved at the nearest greater multiple of the " "policy_steps_per_update value." @@ -638,6 +639,8 @@ def main(cfg: DictConfig): is_exploring = True for update in range(start_step, num_updates + 1): + policy_step += cfg.env.num_envs * world_size + if update == exploration_updates: is_exploring = False player.actor = actor_task.module @@ -645,48 +648,49 @@ def main(cfg: DictConfig): if fabric.is_global_zero: test(copy.deepcopy(player), fabric, cfg, "zero-shot") - # Sample an action given the observation received by the environment - if update <= learning_starts and cfg.checkpoint.resume_from is None and "minedojo" not in cfg.env.id: - real_actions = actions = np.array(envs.action_space.sample()) - if not is_continuous: - actions = np.concatenate( - [ - F.one_hot(torch.tensor(act), act_dim).numpy() - for act, act_dim in zip(actions.reshape(len(actions_dim), -1), actions_dim) - ], - axis=-1, - ) - else: - with torch.no_grad(): - preprocessed_obs = {} - for k, v in obs.items(): - if k in cfg.cnn_keys.encoder: - preprocessed_obs[k] = v[None, ...].to(device) / 255 - 0.5 + # Measure environment interaction time: this considers both the model forward + # to get the action given the observation and the time taken into the environment + with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): + # Sample an action given the observation received by the environment + if update <= learning_starts and cfg.checkpoint.resume_from is None and "minedojo" not in cfg.env.id: + real_actions = actions = np.array(envs.action_space.sample()) + if not is_continuous: + actions = np.concatenate( + [ + F.one_hot(torch.tensor(act), act_dim).numpy() + for act, act_dim in zip(actions.reshape(len(actions_dim), -1), actions_dim) + ], + axis=-1, + ) + else: + with torch.no_grad(): + preprocessed_obs = {} + for k, v in obs.items(): + if k in cfg.cnn_keys.encoder: + preprocessed_obs[k] = v[None, ...].to(device) / 255 - 0.5 + else: + preprocessed_obs[k] = v[None, ...].to(device) + mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} + if len(mask) == 0: + mask = None + real_actions = actions = player.get_exploration_action(preprocessed_obs, is_continuous, mask) + actions = torch.cat(actions, -1).cpu().numpy() + if is_continuous: + real_actions = torch.cat(real_actions, -1).cpu().numpy() else: - preprocessed_obs[k] = v[None, ...].to(device) - mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} - if len(mask) == 0: - mask = None - real_actions = actions = player.get_exploration_action(preprocessed_obs, is_continuous, mask) - actions = torch.cat(actions, -1).cpu().numpy() - if is_continuous: - real_actions = torch.cat(real_actions, -1).cpu().numpy() - else: - real_actions = np.array([real_act.cpu().argmax(dim=-1).numpy() for real_act in real_actions]) - - o, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated) - - policy_step += cfg.env.num_envs * fabric.world_size + real_actions = np.array([real_act.cpu().argmax(dim=-1).numpy() for real_act in real_actions]) + + o, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) + dones = np.logical_or(dones, truncated) if "final_info" in infos: - for i, agent_final_info in enumerate(infos["final_info"]): - if agent_final_info is not None and "episode" in agent_final_info: - fabric.print( - f"Rank-0: policy_step={policy_step}, reward_env_{i}={agent_final_info['episode']['r'][0]}" - ) - aggregator.update("Rewards/rew_avg", agent_final_info["episode"]["r"][0]) - aggregator.update("Game/ep_len_avg", agent_final_info["episode"]["l"][0]) + for i, agent_ep_info in enumerate(infos["final_info"]): + if agent_ep_info is not None: + ep_rew = agent_ep_info["episode"]["r"] + ep_len = agent_ep_info["episode"]["l"] + aggregator.update("Rewards/rew_avg", ep_rew) + aggregator.update("Game/ep_len_avg", ep_len) + fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") # Save the real next observation real_next_obs = copy.deepcopy(o) @@ -745,26 +749,29 @@ def main(cfg: DictConfig): n_samples=cfg.algo.per_rank_gradient_steps, ).to(device) distributed_sampler = BatchSampler(range(local_data.shape[0]), batch_size=1, drop_last=False) - for i in distributed_sampler: - train( - fabric, - world_model, - actor_task, - critic_task, - world_optimizer, - actor_task_optimizer, - critic_task_optimizer, - local_data[i].view(cfg.per_rank_sequence_length, cfg.per_rank_batch_size), - aggregator, - cfg, - ensembles=ensembles, - ensemble_optimizer=ensemble_optimizer, - actor_exploration=actor_exploration, - critic_exploration=critic_exploration, - actor_exploration_optimizer=actor_exploration_optimizer, - critic_exploration_optimizer=critic_exploration_optimizer, - is_exploring=is_exploring, - ) + # Start training + with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): + for i in distributed_sampler: + train( + fabric, + world_model, + actor_task, + critic_task, + world_optimizer, + actor_task_optimizer, + critic_task_optimizer, + local_data[i].view(cfg.per_rank_sequence_length, cfg.per_rank_batch_size), + aggregator, + cfg, + ensembles=ensembles, + ensemble_optimizer=ensemble_optimizer, + actor_exploration=actor_exploration, + critic_exploration=critic_exploration, + actor_exploration_optimizer=actor_exploration_optimizer, + critic_exploration_optimizer=critic_exploration_optimizer, + is_exploring=is_exploring, + ) + train_step += world_size updates_before_training = cfg.algo.train_every // policy_steps_per_update if cfg.algo.player.expl_decay: expl_decay_steps += 1 @@ -775,12 +782,33 @@ def main(cfg: DictConfig): max_decay_steps=max_step_expl_decay, ) aggregator.update("Params/exploration_amout", player.expl_amount) - aggregator.update("Time/step_per_second", int(policy_step / (time.perf_counter() - start_time))) - if policy_step - last_log >= cfg.metric.log_every or cfg.dry_run: - last_log = policy_step - fabric.log_dict(aggregator.compute(), policy_step) + + # Log metrics + if policy_step - last_log >= cfg.metric.log_every or update == num_updates or cfg.dry_run: + # Sync distributed metrics + metrics_dict = aggregator.compute() + fabric.log_dict(metrics_dict, policy_step) aggregator.reset() + # Sync distributed timers + timer_metrics = timer.compute() + fabric.log( + "Time/sps_train", + (train_step - last_train) / timer_metrics["Time/train_time"], + policy_step, + ) + fabric.log( + "Time/sps_env_interaction", + ((policy_step - last_log) / world_size * cfg.env.action_repeat) + / timer_metrics["Time/env_interaction_time"], + policy_step, + ) + timer.reset() + + # Reset counters + last_log = policy_step + last_train = train_step + # Checkpoint Model if ( (cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every) @@ -798,8 +826,8 @@ def main(cfg: DictConfig): "critic_task_optimizer": critic_task_optimizer.state_dict(), "ensemble_optimizer": ensemble_optimizer.state_dict(), "expl_decay_steps": expl_decay_steps, - "update": update * fabric.world_size, - "batch_size": cfg.per_rank_batch_size * fabric.world_size, + "update": update * world_size, + "batch_size": cfg.per_rank_batch_size * world_size, "actor_exploration": actor_exploration.state_dict(), "critic_exploration": critic_exploration.state_dict(), "actor_exploration_optimizer": actor_exploration_optimizer.state_dict(), diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2.py b/sheeprl/algos/p2e_dv2/p2e_dv2.py index 61fec5a6..d1d48a0b 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2.py @@ -1,7 +1,6 @@ import copy import os import pathlib -import time import warnings from typing import Dict, Sequence @@ -20,7 +19,7 @@ from torch import Tensor, nn from torch.distributions import Bernoulli, Distribution, Independent, Normal, OneHotCategorical from torch.utils.data import BatchSampler -from torchmetrics import MeanMetric +from torchmetrics import MeanMetric, SumMetric from sheeprl.algos.dreamer_v2.agent import PlayerDV2, WorldModel from sheeprl.algos.dreamer_v2.loss import reconstruction_loss @@ -33,7 +32,8 @@ from sheeprl.utils.logger import create_tensorboard_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm -from sheeprl.utils.utils import polynomial_decay +from sheeprl.utils.timer import timer +from sheeprl.utils.utils import polynomial_decay, print_config # Decomment the following line if you are using MineDojo on an headless machine # os.environ["MINEDOJO_HEADLESS"] = "1" @@ -192,7 +192,7 @@ def train( error_if_nonfinite=False, ) world_optimizer.step() - aggregator.update("Loss/reconstruction_loss", rec_loss.detach()) + aggregator.update("Loss/world_model_loss", rec_loss.detach()) aggregator.update("Loss/observation_loss", observation_loss.detach()) aggregator.update("Loss/reward_loss", reward_loss.detach()) aggregator.update("Loss/state_loss", state_loss.detach()) @@ -466,6 +466,8 @@ def train( @register_algorithm() @hydra.main(version_base=None, config_path="../../configs", config_name="config") def main(cfg: DictConfig): + print_config(cfg) + # These arguments cannot be changed cfg.env.screen_size = 64 cfg.env.frame_stack = 1 @@ -474,8 +476,9 @@ def main(cfg: DictConfig): fabric = Fabric(callbacks=[CheckpointCallback()]) if not _is_using_cli(): fabric.launch() - rank = fabric.global_rank device = fabric.device + rank = fabric.global_rank + world_size = fabric.world_size fabric.seed_everything(cfg.seed) torch.backends.cudnn.deterministic = cfg.torch_deterministic @@ -486,7 +489,7 @@ def main(cfg: DictConfig): ckpt_path = pathlib.Path(cfg.checkpoint.resume_from) cfg = OmegaConf.load(ckpt_path.parent.parent.parent / ".hydra" / "config.yaml") cfg.checkpoint.resume_from = str(ckpt_path) - cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size + cfg.per_rank_batch_size = state["batch_size"] // world_size cfg.root_dir = root_dir cfg.run_name = f"resume_from_checkpoint_{run_name}" @@ -661,8 +664,7 @@ def main(cfg: DictConfig): { "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Time/step_per_second": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/reconstruction_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/world_model_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), "Loss/value_loss_task": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), "Loss/policy_loss_task": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), "Loss/value_loss_exploration": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), @@ -690,7 +692,7 @@ def main(cfg: DictConfig): aggregator.to(device) # Local data - buffer_size = cfg.buffer.size // int(cfg.env.num_envs * fabric.world_size) if not cfg.dry_run else 4 + buffer_size = cfg.buffer.size // int(cfg.env.num_envs * world_size) if not cfg.dry_run else 4 buffer_type = cfg.buffer.type.lower() if buffer_type == "sequential": rb = AsyncReplayBuffer( @@ -712,28 +714,29 @@ def main(cfg: DictConfig): else: raise ValueError(f"Unrecognized buffer type: must be one of `sequential` or `episode`, received: {buffer_type}") if cfg.checkpoint.resume_from and cfg.buffer.checkpoint: - if isinstance(state["rb"], list) and fabric.world_size == len(state["rb"]): + if isinstance(state["rb"], list) and world_size == len(state["rb"]): rb = state["rb"][fabric.global_rank] elif isinstance(state["rb"], AsyncReplayBuffer): rb = state["rb"] else: - raise RuntimeError(f"Given {len(state['rb'])}, but {fabric.world_size} processes are instantiated") + raise RuntimeError(f"Given {len(state['rb'])}, but {world_size} processes are instantiated") step_data = TensorDict({}, batch_size=[cfg.env.num_envs], device="cpu") expl_decay_steps = state["expl_decay_steps"] if cfg.checkpoint.resume_from else 0 # Global variables - start_step = state["update"] // fabric.world_size if cfg.checkpoint.resume_from else 1 + train_step = 0 + last_train = 0 + start_step = state["update"] // world_size if cfg.checkpoint.resume_from else 1 policy_step = state["update"] * cfg.env.num_envs if cfg.checkpoint.resume_from else 0 last_log = state["last_log"] if cfg.checkpoint.resume_from else 0 last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 - start_time = time.perf_counter() - policy_steps_per_update = int(cfg.env.num_envs * fabric.world_size) + policy_steps_per_update = int(cfg.env.num_envs * world_size) updates_before_training = cfg.algo.train_every // policy_steps_per_update if not cfg.dry_run else 0 num_updates = cfg.total_steps // policy_steps_per_update if not cfg.dry_run else 1 learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 if cfg.checkpoint.resume_from and not cfg.buffer.checkpoint: learning_starts += start_step - max_step_expl_decay = cfg.algo.player.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * fabric.world_size) + max_step_expl_decay = cfg.algo.player.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * world_size) if cfg.checkpoint.resume_from: player.expl_amount = polynomial_decay( expl_decay_steps, @@ -749,14 +752,14 @@ def main(cfg: DictConfig): # Warning for log and checkpoint every if cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( - f"The log every parameter ({cfg.metric.log_every}) is not a multiple of the " + f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the metrics will be logged at the nearest greater multiple of the " "policy_steps_per_update value." ) if cfg.checkpoint.every % policy_steps_per_update != 0: warnings.warn( - f"The checkpoint every parameter ({cfg.checkpoint.every}) is not a multiple of the " + f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the checkpoint will be saved at the nearest greater multiple of the " "policy_steps_per_update value." @@ -788,6 +791,8 @@ def main(cfg: DictConfig): per_rank_gradient_steps = 0 is_exploring = True for update in range(start_step, num_updates + 1): + policy_step += cfg.env.num_envs * world_size + if update == exploration_updates: is_exploring = False player.actor = actor_task.module @@ -795,51 +800,52 @@ def main(cfg: DictConfig): if fabric.is_global_zero: test(copy.deepcopy(player), fabric, cfg, "zero-shot") - # Sample an action given the observation received by the environment - if update <= learning_starts and cfg.checkpoint.resume_from is None and "minedojo" not in cfg.env.id: - real_actions = actions = np.array(envs.action_space.sample()) - if not is_continuous: - actions = np.concatenate( - [ - F.one_hot(torch.tensor(act), act_dim).numpy() - for act, act_dim in zip(actions.reshape(len(actions_dim), -1), actions_dim) - ], - axis=-1, - ) - else: - with torch.no_grad(): - preprocessed_obs = {} - for k, v in obs.items(): - if k in cfg.cnn_keys.encoder: - preprocessed_obs[k] = v[None, ...].to(device) / 255 - 0.5 + # Measure environment interaction time: this considers both the model forward + # to get the action given the observation and the time taken into the environment + with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): + # Sample an action given the observation received by the environment + if update <= learning_starts and cfg.checkpoint.resume_from is None and "minedojo" not in cfg.env.id: + real_actions = actions = np.array(envs.action_space.sample()) + if not is_continuous: + actions = np.concatenate( + [ + F.one_hot(torch.tensor(act), act_dim).numpy() + for act, act_dim in zip(actions.reshape(len(actions_dim), -1), actions_dim) + ], + axis=-1, + ) + else: + with torch.no_grad(): + preprocessed_obs = {} + for k, v in obs.items(): + if k in cfg.cnn_keys.encoder: + preprocessed_obs[k] = v[None, ...].to(device) / 255 - 0.5 + else: + preprocessed_obs[k] = v[None, ...].to(device) + mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} + if len(mask) == 0: + mask = None + real_actions = actions = player.get_exploration_action(preprocessed_obs, is_continuous, mask) + actions = torch.cat(actions, -1).cpu().numpy() + if is_continuous: + real_actions = torch.cat(real_actions, -1).cpu().numpy() else: - preprocessed_obs[k] = v[None, ...].to(device) - mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} - if len(mask) == 0: - mask = None - real_actions = actions = player.get_exploration_action(preprocessed_obs, is_continuous, mask) - actions = torch.cat(actions, -1).cpu().numpy() - if is_continuous: - real_actions = torch.cat(real_actions, -1).cpu().numpy() - else: - real_actions = np.array([real_act.cpu().argmax(dim=-1).numpy() for real_act in real_actions]) - - step_data["is_first"] = copy.deepcopy(step_data["dones"]) - o, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated) - if cfg.dry_run and buffer_type == "episode": - dones = np.ones_like(dones) - - policy_step += cfg.env.num_envs * fabric.world_size + real_actions = np.array([real_act.cpu().argmax(dim=-1).numpy() for real_act in real_actions]) + + step_data["is_first"] = copy.deepcopy(step_data["dones"]) + o, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) + dones = np.logical_or(dones, truncated) + if cfg.dry_run and buffer_type == "episode": + dones = np.ones_like(dones) if "final_info" in infos: - for i, agent_final_info in enumerate(infos["final_info"]): - if agent_final_info is not None and "episode" in agent_final_info: - fabric.print( - f"Rank-0: policy_step={policy_step}, reward_env_{i}={agent_final_info['episode']['r'][0]}" - ) - aggregator.update("Rewards/rew_avg", agent_final_info["episode"]["r"][0]) - aggregator.update("Game/ep_len_avg", agent_final_info["episode"]["l"][0]) + for i, agent_ep_info in enumerate(infos["final_info"]): + if agent_ep_info is not None: + ep_rew = agent_ep_info["episode"]["r"] + ep_len = agent_ep_info["episode"]["l"] + aggregator.update("Rewards/rew_avg", ep_rew) + aggregator.update("Game/ep_len_avg", ep_len) + fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") # Save the real next observation real_next_obs = copy.deepcopy(o) @@ -919,35 +925,40 @@ def main(cfg: DictConfig): prioritize_ends=cfg.buffer.prioritize_ends, ).to(device) distributed_sampler = BatchSampler(range(local_data.shape[0]), batch_size=1, drop_last=False) - for i in distributed_sampler: - if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: - for cp, tcp in zip(critic_task.module.parameters(), target_critic_task.parameters()): - tcp.data.copy_(cp.data) - for cp, tcp in zip(critic_exploration.module.parameters(), target_critic_exploration.parameters()): - tcp.data.copy_(cp.data) - train( - fabric, - world_model, - actor_task, - critic_task, - target_critic_task, - world_optimizer, - actor_task_optimizer, - critic_task_optimizer, - local_data[i].view(cfg.per_rank_sequence_length, cfg.per_rank_batch_size), - aggregator, - cfg, - ensembles=ensembles, - ensemble_optimizer=ensemble_optimizer, - actor_exploration=actor_exploration, - critic_exploration=critic_exploration, - target_critic_exploration=target_critic_exploration, - actor_exploration_optimizer=actor_exploration_optimizer, - critic_exploration_optimizer=critic_exploration_optimizer, - is_continuous=is_continuous, - actions_dim=actions_dim, - is_exploring=is_exploring, - ) + # Start training + with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): + for i in distributed_sampler: + if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: + for cp, tcp in zip(critic_task.module.parameters(), target_critic_task.parameters()): + tcp.data.copy_(cp.data) + for cp, tcp in zip( + critic_exploration.module.parameters(), target_critic_exploration.parameters() + ): + tcp.data.copy_(cp.data) + train( + fabric, + world_model, + actor_task, + critic_task, + target_critic_task, + world_optimizer, + actor_task_optimizer, + critic_task_optimizer, + local_data[i].view(cfg.per_rank_sequence_length, cfg.per_rank_batch_size), + aggregator, + cfg, + ensembles=ensembles, + ensemble_optimizer=ensemble_optimizer, + actor_exploration=actor_exploration, + critic_exploration=critic_exploration, + target_critic_exploration=target_critic_exploration, + actor_exploration_optimizer=actor_exploration_optimizer, + critic_exploration_optimizer=critic_exploration_optimizer, + is_continuous=is_continuous, + actions_dim=actions_dim, + is_exploring=is_exploring, + ) + train_step += world_size updates_before_training = cfg.algo.train_every // policy_steps_per_update if cfg.algo.player.expl_decay: expl_decay_steps += 1 @@ -958,12 +969,33 @@ def main(cfg: DictConfig): max_decay_steps=max_step_expl_decay, ) aggregator.update("Params/exploration_amout", player.expl_amount) - aggregator.update("Time/step_per_second", int(policy_step / (time.perf_counter() - start_time))) - if policy_step - last_log >= cfg.metric.log_every or cfg.dry_run: - last_log = policy_step - fabric.log_dict(aggregator.compute(), policy_step) + + # Log metrics + if policy_step - last_log >= cfg.metric.log_every or update == num_updates or cfg.dry_run: + # Sync distributed metrics + metrics_dict = aggregator.compute() + fabric.log_dict(metrics_dict, policy_step) aggregator.reset() + # Sync distributed timers + timer_metrics = timer.compute() + fabric.log( + "Time/sps_train", + (train_step - last_train) / timer_metrics["Time/train_time"], + policy_step, + ) + fabric.log( + "Time/sps_env_interaction", + ((policy_step - last_log) / world_size * cfg.env.action_repeat) + / timer_metrics["Time/env_interaction_time"], + policy_step, + ) + timer.reset() + + # Reset counters + last_log = policy_step + last_train = train_step + # Checkpoint Model if ( (cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every) @@ -982,8 +1014,8 @@ def main(cfg: DictConfig): "critic_task_optimizer": critic_task_optimizer.state_dict(), "ensemble_optimizer": ensemble_optimizer.state_dict(), "expl_decay_steps": expl_decay_steps, - "update": update * fabric.world_size, - "batch_size": cfg.per_rank_batch_size * fabric.world_size, + "update": update * world_size, + "batch_size": cfg.per_rank_batch_size * world_size, "actor_exploration": actor_exploration.state_dict(), "critic_exploration": critic_exploration.state_dict(), "target_critic_exploration": target_critic_exploration.state_dict(), diff --git a/sheeprl/algos/ppo/ppo.py b/sheeprl/algos/ppo/ppo.py index fcf4bd0f..632e8ba3 100644 --- a/sheeprl/algos/ppo/ppo.py +++ b/sheeprl/algos/ppo/ppo.py @@ -1,6 +1,5 @@ import copy import os -import time import warnings from typing import Union @@ -16,7 +15,7 @@ from tensordict.tensordict import TensorDictBase from torch import nn from torch.utils.data import BatchSampler, DistributedSampler, RandomSampler -from torchmetrics import MeanMetric +from torchmetrics import MeanMetric, SumMetric from sheeprl.algos.ppo.agent import PPOAgent from sheeprl.algos.ppo.loss import entropy_loss, policy_loss, value_loss @@ -27,7 +26,8 @@ from sheeprl.utils.logger import create_tensorboard_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm -from sheeprl.utils.utils import gae, normalize_tensor, polynomial_decay +from sheeprl.utils.timer import timer +from sheeprl.utils.utils import gae, normalize_tensor, polynomial_decay, print_config def train( @@ -108,6 +108,8 @@ def train( @register_algorithm() @hydra.main(version_base=None, config_path="../../configs", config_name="config") def main(cfg: DictConfig): + print_config(cfg) + if "minedojo" in cfg.env.env._target_.lower(): raise ValueError( "MineDojo is not currently supported by PPO agent, since it does not take " @@ -190,17 +192,15 @@ def main(cfg: DictConfig): optimizer = fabric.setup_optimizers(optimizer) # Create a metric aggregator to log the metrics - with device: - aggregator = MetricAggregator( - { - "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Time/step_per_second": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/entropy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - } - ) + aggregator = MetricAggregator( + { + "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/entropy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + } + ).to(device) # Local data if cfg.buffer.size < cfg.algo.rollout_steps: @@ -220,23 +220,24 @@ def main(cfg: DictConfig): # Global variables last_log = 0 + last_train = 0 + train_step = 0 policy_step = 0 last_checkpoint = 0 - start_time = time.perf_counter() policy_steps_per_update = int(cfg.env.num_envs * cfg.algo.rollout_steps * world_size) num_updates = cfg.total_steps // policy_steps_per_update if not cfg.dry_run else 1 # Warning for log and checkpoint every if cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( - f"The log every parameter ({cfg.metric.log_every}) is not a multiple of the " + f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the metrics will be logged at the nearest greater multiple of the " "policy_steps_per_update value." ) if cfg.checkpoint.every % policy_steps_per_update != 0: warnings.warn( - f"The checkpoint every parameter ({cfg.checkpoint.every}) is not a multiple of the " + f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the checkpoint will be saved at the nearest greater multiple of the " "policy_steps_per_update value." @@ -266,21 +267,24 @@ def main(cfg: DictConfig): for _ in range(0, cfg.algo.rollout_steps): policy_step += cfg.env.num_envs * world_size - with torch.no_grad(): - # Sample an action given the observation received by the environment - normalized_obs = { - k: next_obs[k] / 255 - 0.5 if k in cfg.cnn_keys.encoder else next_obs[k] for k in obs_keys - } - actions, logprobs, _, value = agent(normalized_obs) - if is_continuous: - real_actions = torch.cat(actions, -1).cpu().numpy() - else: - real_actions = np.concatenate([act.argmax(dim=-1).cpu().numpy() for act in actions], axis=-1) - actions = torch.cat(actions, -1) - - # Single environment step - o, reward, done, truncated, info = envs.step(real_actions) - done = np.logical_or(done, truncated) + # Measure environment interaction time: this considers both the model forward + # to get the action given the observation and the time taken into the environment + with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): + with torch.no_grad(): + # Sample an action given the observation received by the environment + normalized_obs = { + k: next_obs[k] / 255 - 0.5 if k in cfg.cnn_keys.encoder else next_obs[k] for k in obs_keys + } + actions, logprobs, _, value = agent.module(normalized_obs) + if is_continuous: + real_actions = torch.cat(actions, -1).cpu().numpy() + else: + real_actions = np.concatenate([act.argmax(dim=-1).cpu().numpy() for act in actions], axis=-1) + actions = torch.cat(actions, -1) + + # Single environment step + o, reward, done, truncated, info = envs.step(real_actions) + done = np.logical_or(done, truncated) with device: rewards = torch.tensor(reward, dtype=torch.float32).view(cfg.env.num_envs, -1) # [N_envs, 1] @@ -314,20 +318,20 @@ def main(cfg: DictConfig): next_done = done if "final_info" in info: - for i, agent_final_info in enumerate(info["final_info"]): - if agent_final_info is not None and "episode" in agent_final_info: - fabric.print( - f"Rank-0: policy_step={policy_step}, reward_env_{i}={agent_final_info['episode']['r'][0]}" - ) - aggregator.update("Rewards/rew_avg", agent_final_info["episode"]["r"][0]) - aggregator.update("Game/ep_len_avg", agent_final_info["episode"]["l"][0]) + for i, agent_ep_info in enumerate(info["final_info"]): + if agent_ep_info is not None: + ep_rew = agent_ep_info["episode"]["r"] + ep_len = agent_ep_info["episode"]["l"] + aggregator.update("Rewards/rew_avg", ep_rew) + aggregator.update("Game/ep_len_avg", ep_len) + fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") # Estimate returns with GAE (https://arxiv.org/abs/1506.02438) with torch.no_grad(): normalized_obs = { k: next_obs[k] / 255 - 0.5 if k in cfg.cnn_keys.encoder else next_obs[k] for k in obs_keys } - next_values = agent.get_value(normalized_obs) + next_values = agent.module.get_value(normalized_obs) returns, advantages = gae( rb["rewards"], rb["values"], @@ -353,7 +357,9 @@ def main(cfg: DictConfig): else: gathered_data = local_data - train(fabric, agent, optimizer, gathered_data, aggregator, cfg) + with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): + train(fabric, agent, optimizer, gathered_data, aggregator, cfg) + train_step += world_size if cfg.algo.anneal_lr: fabric.log("Info/learning_rate", scheduler.get_last_lr()[0], policy_step) @@ -374,13 +380,31 @@ def main(cfg: DictConfig): ) # Log metrics - if policy_step - last_log >= cfg.metric.log_every or cfg.dry_run: - last_log = policy_step + if policy_step - last_log >= cfg.metric.log_every or update == num_updates or cfg.dry_run: + # Sync distributed metrics metrics_dict = aggregator.compute() fabric.log_dict(metrics_dict, policy_step) - fabric.log("Time/step_per_second", int(policy_step / (time.perf_counter() - start_time)), policy_step) aggregator.reset() + # Sync distributed timers + timer_metrics = timer.compute() + fabric.log( + "Time/sps_train", + (train_step - last_train) / timer_metrics["Time/train_time"], + policy_step, + ) + fabric.log( + "Time/sps_env_interaction", + ((policy_step - last_log) / world_size * cfg.env.action_repeat) + / timer_metrics["Time/env_interaction_time"], + policy_step, + ) + timer.reset() + + # Reset counters + last_log = policy_step + last_train = train_step + # Checkpoint model if ( (cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every) diff --git a/sheeprl/algos/ppo/ppo_decoupled.py b/sheeprl/algos/ppo/ppo_decoupled.py index 4b66a7d5..862ce9dc 100644 --- a/sheeprl/algos/ppo/ppo_decoupled.py +++ b/sheeprl/algos/ppo/ppo_decoupled.py @@ -19,7 +19,7 @@ from tensordict.tensordict import TensorDictBase, make_tensordict from torch.distributed.algorithms.join import Join from torch.utils.data import BatchSampler, RandomSampler -from torchmetrics import MeanMetric +from torchmetrics import MeanMetric, SumMetric from sheeprl.algos.ppo.agent import PPOAgent from sheeprl.algos.ppo.loss import entropy_loss, policy_loss, value_loss @@ -29,11 +29,15 @@ from sheeprl.utils.env import make_dict_env from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm -from sheeprl.utils.utils import gae, normalize_tensor, polynomial_decay +from sheeprl.utils.timer import timer +from sheeprl.utils.utils import gae, normalize_tensor, polynomial_decay, print_config @torch.no_grad() def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_collective: TorchCollective): + print_config(cfg) + + # Initialize logger root_dir = ( os.path.join("logs", "runs", cfg.root_dir) if cfg.root_dir is not None @@ -120,14 +124,9 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co torch.nn.utils.convert_parameters.vector_to_parameters(flattened_parameters, list(agent.parameters())) # Metrics - with device: - aggregator = MetricAggregator( - { - "Rewards/rew_avg": MeanMetric(sync_on_compute=False), - "Game/ep_len_avg": MeanMetric(sync_on_compute=False), - "Time/step_per_second": MeanMetric(sync_on_compute=False), - } - ) + aggregator = MetricAggregator( + {"Rewards/rew_avg": MeanMetric(sync_on_compute=False), "Game/ep_len_avg": MeanMetric(sync_on_compute=False)} + ).to(device) # Local data rb = ReplayBuffer( @@ -140,23 +139,23 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co step_data = TensorDict({}, batch_size=[cfg.env.num_envs], device=device) # Global variables - policy_step = 0 last_log = 0 + policy_step = 0 last_checkpoint = 0 - start_time = time.perf_counter() policy_steps_per_update = int(cfg.env.num_envs * cfg.algo.rollout_steps) num_updates = cfg.total_steps // policy_steps_per_update if not cfg.dry_run else 1 + # Warning for log and checkpoint every if cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( - f"The log every parameter ({cfg.metric.log_every}) is not a multiple of the " + f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the metrics will be logged at the nearest greater multiple of the " "policy_steps_per_update value." ) if cfg.checkpoint.every % policy_steps_per_update != 0: warnings.warn( - f"The checkpoint every parameter ({cfg.checkpoint.every}) is not a multiple of the " + f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the checkpoint will be saved at the nearest greater multiple of the " "policy_steps_per_update value." @@ -195,21 +194,24 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co for _ in range(0, cfg.algo.rollout_steps): policy_step += cfg.env.num_envs - with torch.no_grad(): - # Sample an action given the observation received by the environment - normalized_obs = { - k: next_obs[k] / 255 - 0.5 if k in cfg.cnn_keys.encoder else next_obs[k] for k in obs_keys - } - actions, logprobs, _, value = agent(normalized_obs) - if is_continuous: - real_actions = torch.cat(actions, -1).cpu().numpy() - else: - real_actions = np.concatenate([act.argmax(dim=-1).cpu().numpy() for act in actions], axis=-1) - actions = torch.cat(actions, -1) + # Measure environment interaction time: this considers both the model forward + # to get the action given the observation and the time taken into the environment + with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): + with torch.no_grad(): + # Sample an action given the observation received by the environment + normalized_obs = { + k: next_obs[k] / 255 - 0.5 if k in cfg.cnn_keys.encoder else next_obs[k] for k in obs_keys + } + actions, logprobs, _, value = agent(normalized_obs) + if is_continuous: + real_actions = torch.cat(actions, -1).cpu().numpy() + else: + real_actions = np.concatenate([act.argmax(dim=-1).cpu().numpy() for act in actions], axis=-1) + actions = torch.cat(actions, -1) - # Single environment step - o, reward, done, truncated, info = envs.step(real_actions) - done = np.logical_or(done, truncated) + # Single environment step + o, reward, done, truncated, info = envs.step(real_actions) + done = np.logical_or(done, truncated) with device: rewards = torch.tensor(reward, dtype=torch.float32).view(cfg.env.num_envs, -1) # [N_envs, 1] @@ -242,13 +244,13 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co next_done = done if "final_info" in info: - for i, agent_final_info in enumerate(info["final_info"]): - if agent_final_info is not None and "episode" in agent_final_info: - fabric.print( - f"Rank-0: policy_step={policy_step}, reward_env_{i}={agent_final_info['episode']['r'][0]}" - ) - aggregator.update("Rewards/rew_avg", agent_final_info["episode"]["r"][0]) - aggregator.update("Game/ep_len_avg", agent_final_info["episode"]["l"][0]) + for i, agent_ep_info in enumerate(info["final_info"]): + if agent_ep_info is not None: + ep_rew = agent_ep_info["episode"]["r"] + ep_len = agent_ep_info["episode"]["l"] + aggregator.update("Rewards/rew_avg", ep_rew) + aggregator.update("Game/ep_len_avg", ep_len) + fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") # Estimate returns with GAE (https://arxiv.org/abs/1506.02438) normalized_obs = {k: next_obs[k] / 255 - 0.5 if k in cfg.cnn_keys.encoder else next_obs[k] for k in obs_keys} @@ -277,25 +279,35 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co chunks = local_data[perm].split(chunks_sizes) world_collective.scatter_object_list([None], [None] + chunks, src=0) - # Gather metrics from the trainers to be plotted - if policy_step - last_log >= cfg.metric.log_every or cfg.dry_run: - metrics = [None] - player_trainer_collective.broadcast_object_list(metrics, src=1) - # Wait the trainers to finish player_trainer_collective.broadcast(flattened_parameters, src=1) # Convert back the parameters torch.nn.utils.convert_parameters.vector_to_parameters(flattened_parameters, list(agent.parameters())) - # Log metrics - aggregator.update("Time/step_per_second", int(policy_step / (time.perf_counter() - start_time))) if policy_step - last_log >= cfg.metric.log_every or cfg.dry_run: - last_log = policy_step - fabric.log_dict(metrics[0], policy_step) + # Gather metrics from the trainers + metrics = [None] + player_trainer_collective.broadcast_object_list(metrics, src=1) + metrics = metrics[0] + + # Log metrics + fabric.log_dict(metrics, policy_step) fabric.log_dict(aggregator.compute(), policy_step) aggregator.reset() + # Sync timers + timer_metrics = timer.compute() + fabric.log( + "Time/sps_env_interaction", + ((policy_step - last_log) * cfg.env.action_repeat) / timer_metrics["Time/env_interaction_time"], + policy_step, + ) + timer.reset() + + # Reset counters + last_log = policy_step + # Checkpoint model if (cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every) or cfg.dry_run: last_checkpoint = policy_step @@ -337,6 +349,7 @@ def trainer( optimization_pg: CollectibleGroup, ): global_rank = world_collective.rank + group_world_size = world_collective.world_size - 1 # Receive (possibly updated, by the make_dict_env method for example) cfg from the player data = [None] @@ -382,28 +395,23 @@ def trainer( scheduler = PolynomialLR(optimizer=optimizer, total_iters=num_updates, power=1.0) # Metrics - with fabric.device: - aggregator = MetricAggregator( - { - "Loss/value_loss": MeanMetric( - sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg - ), - "Loss/policy_loss": MeanMetric( - sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg - ), - "Loss/entropy_loss": MeanMetric( - sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg - ), - } - ) + aggregator = MetricAggregator( + { + "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg), + "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg), + "Loss/entropy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg), + } + ).to(device) # Start training update = 0 - initial_ent_coef = copy.deepcopy(cfg.algo.ent_coef) - initial_clip_coef = copy.deepcopy(cfg.algo.clip_coef) - policy_step = 0 last_log = 0 + last_train = 0 + train_step = 0 + policy_step = 0 last_checkpoint = 0 + initial_ent_coef = copy.deepcopy(cfg.algo.ent_coef) + initial_clip_coef = copy.deepcopy(cfg.algo.clip_coef) while True: # Wait for data data = [None] @@ -422,70 +430,87 @@ def trainer( return data = make_tensordict(data, device=device) update += 1 + train_step += group_world_size policy_step += cfg.env.num_envs * cfg.algo.rollout_steps # Prepare sampler indexes = list(range(data.shape[0])) sampler = BatchSampler(RandomSampler(indexes), batch_size=cfg.per_rank_batch_size, drop_last=False) - # The Join context is needed because there can be the possibility - # that some ranks receive less data - with Join([agent._forward_module]): - for _ in range(cfg.algo.update_epochs): - for batch_idxes in sampler: - batch = data[batch_idxes] - normalized_obs = { - k: batch[k] / 255 - 0.5 if k in agent.feature_extractor.cnn_keys else batch[k] - for k in cfg.cnn_keys.encoder + cfg.mlp_keys.encoder - } - _, logprobs, entropy, new_values = agent( - normalized_obs, torch.split(batch["actions"], agent.actions_dim, dim=-1) - ) - - if cfg.algo.normalize_advantages: - batch["advantages"] = normalize_tensor(batch["advantages"]) - - # Policy loss - pg_loss = policy_loss( - logprobs, - batch["logprobs"], - batch["advantages"], - cfg.algo.clip_coef, - cfg.algo.loss_reduction, - ) - - # Value loss - v_loss = value_loss( - new_values, - batch["values"], - batch["returns"], - cfg.algo.clip_coef, - cfg.algo.clip_vloss, - cfg.algo.loss_reduction, - ) - - # Entropy loss - ent_loss = entropy_loss(entropy, cfg.algo.loss_reduction) - - # Equation (9) in the paper - loss = pg_loss + cfg.algo.vf_coef * v_loss + cfg.algo.ent_coef * ent_loss - - optimizer.zero_grad(set_to_none=True) - fabric.backward(loss) - if cfg.algo.max_grad_norm > 0.0: - fabric.clip_gradients(agent, optimizer, max_norm=cfg.algo.max_grad_norm) - optimizer.step() - - # Update metrics - aggregator.update("Loss/policy_loss", pg_loss.detach()) - aggregator.update("Loss/value_loss", v_loss.detach()) - aggregator.update("Loss/entropy_loss", ent_loss.detach()) - - # Send updated weights to the player + # Start training + with timer( + "Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg) + ): + # The Join context is needed because there can be the possibility + # that some ranks receive less data + with Join([agent._forward_module]): + for _ in range(cfg.algo.update_epochs): + for batch_idxes in sampler: + batch = data[batch_idxes] + normalized_obs = { + k: batch[k] / 255 - 0.5 if k in agent.feature_extractor.cnn_keys else batch[k] + for k in cfg.cnn_keys.encoder + cfg.mlp_keys.encoder + } + _, logprobs, entropy, new_values = agent( + normalized_obs, torch.split(batch["actions"], agent.actions_dim, dim=-1) + ) + + if cfg.algo.normalize_advantages: + batch["advantages"] = normalize_tensor(batch["advantages"]) + + # Policy loss + pg_loss = policy_loss( + logprobs, + batch["logprobs"], + batch["advantages"], + cfg.algo.clip_coef, + cfg.algo.loss_reduction, + ) + + # Value loss + v_loss = value_loss( + new_values, + batch["values"], + batch["returns"], + cfg.algo.clip_coef, + cfg.algo.clip_vloss, + cfg.algo.loss_reduction, + ) + + # Entropy loss + ent_loss = entropy_loss(entropy, cfg.algo.loss_reduction) + + # Equation (9) in the paper + loss = pg_loss + cfg.algo.vf_coef * v_loss + cfg.algo.ent_coef * ent_loss + + optimizer.zero_grad(set_to_none=True) + fabric.backward(loss) + if cfg.algo.max_grad_norm > 0.0: + fabric.clip_gradients(agent, optimizer, max_norm=cfg.algo.max_grad_norm) + optimizer.step() + + # Update metrics + aggregator.update("Loss/policy_loss", pg_loss.detach()) + aggregator.update("Loss/value_loss", v_loss.detach()) + aggregator.update("Loss/entropy_loss", ent_loss.detach()) + + if global_rank == 1: + player_trainer_collective.broadcast( + torch.nn.utils.convert_parameters.parameters_to_vector(agent.parameters()), + src=1, + ) + if policy_step - last_log >= cfg.metric.log_every or cfg.dry_run: - last_log = policy_step + # Sync distributed metrics metrics = aggregator.compute() aggregator.reset() + + # Sync distributed timers + timers = timer.compute() + metrics.update({"Time/sps_train": (train_step - last_train) / timers["Time/train_time"]}) + timer.reset() + + # Send metrics to the player if global_rank == 1: if cfg.algo.anneal_lr: metrics["Info/learning_rate"] = scheduler.get_last_lr()[0] @@ -496,11 +521,10 @@ def trainer( player_trainer_collective.broadcast_object_list( [metrics], src=1 ) # Broadcast metrics: fake send with object list between rank-0 and rank-1 - if global_rank == 1: - player_trainer_collective.broadcast( - torch.nn.utils.convert_parameters.parameters_to_vector(agent.parameters()), - src=1, - ) + + # Reset counters + last_log = policy_step + last_train = train_step if cfg.algo.anneal_lr: scheduler.step() diff --git a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py index fab5f318..a37fd05b 100644 --- a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py +++ b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py @@ -1,7 +1,6 @@ import copy import itertools import os -import time import warnings from contextlib import nullcontext from math import prod @@ -19,7 +18,7 @@ from torch.distributed.algorithms.join import Join from torch.distributions import Categorical from torch.utils.data.sampler import BatchSampler, RandomSampler -from torchmetrics import MeanMetric +from torchmetrics import MeanMetric, SumMetric from sheeprl.algos.ppo.loss import entropy_loss, policy_loss, value_loss from sheeprl.algos.ppo_recurrent.agent import RecurrentPPOAgent @@ -30,7 +29,8 @@ from sheeprl.utils.logger import create_tensorboard_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm -from sheeprl.utils.utils import gae, normalize_tensor, polynomial_decay +from sheeprl.utils.timer import timer +from sheeprl.utils.utils import gae, normalize_tensor, polynomial_decay, print_config def train( @@ -109,6 +109,7 @@ def train( @register_algorithm(decoupled=True) @hydra.main(version_base=None, config_path="../../configs", config_name="config") def main(cfg: DictConfig): + print_config(cfg) initial_ent_coef = copy.deepcopy(cfg.algo.ent_coef) initial_clip_coef = copy.deepcopy(cfg.algo.clip_coef) @@ -173,18 +174,16 @@ def main(cfg: DictConfig): ) optimizer = fabric.setup_optimizers(hydra.utils.instantiate(cfg.algo.optimizer, params=agent.parameters())) - # Metrics - with device: - aggregator = MetricAggregator( - { - "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Time/step_per_second": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/entropy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - } - ) + # Create a metric aggregator to log the metrics + aggregator = MetricAggregator( + { + "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/entropy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + } + ).to(device) # Local data rb = ReplayBuffer( @@ -197,25 +196,25 @@ def main(cfg: DictConfig): step_data = TensorDict({}, batch_size=[1, cfg.env.num_envs], device=device) # Global variables - policy_step = 0 last_log = 0 + last_train = 0 + train_step = 0 + policy_step = 0 last_checkpoint = 0 - start_time = time.perf_counter() policy_steps_per_update = int(cfg.env.num_envs * cfg.algo.rollout_steps * world_size) num_updates = cfg.total_steps // policy_steps_per_update if not cfg.dry_run else 1 - last_log = 0 # Warning for log and checkpoint every if cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( - f"The log every parameter ({cfg.metric.log_every}) is not a multiple of the " + f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the metrics will be logged at the nearest greater multiple of the " "policy_steps_per_update value." ) if cfg.checkpoint.every % policy_steps_per_update != 0: warnings.warn( - f"The checkpoint every parameter ({cfg.checkpoint.every}) is not a multiple of the " + f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the checkpoint will be saved at the nearest greater multiple of the " "policy_steps_per_update value." @@ -229,24 +228,27 @@ def main(cfg: DictConfig): with device: # Get the first environment observation and start the optimization - next_obs = torch.tensor(envs.reset(seed=cfg.seed)[0], dtype=torch.float32).unsqueeze(0) # [1, N_envs, N_obs] - next_done = torch.zeros(1, cfg.env.num_envs, 1, dtype=torch.float32) # [1, N_envs, 1] next_state = agent.initial_states + next_done = torch.zeros(1, cfg.env.num_envs, 1, dtype=torch.float32) # [1, N_envs, 1] + next_obs = torch.tensor(envs.reset(seed=cfg.seed)[0], dtype=torch.float32).unsqueeze(0) # [1, N_envs, N_obs] for update in range(1, num_updates + 1): for _ in range(0, cfg.algo.rollout_steps): policy_step += cfg.env.num_envs * world_size - with torch.no_grad(): - # Sample an action given the observation received by the environment - action_logits, values, state = agent.module(next_obs, state=next_state) - dist = Categorical(logits=action_logits.unsqueeze(-2)) - action = dist.sample() - logprob = dist.log_prob(action) + # Measure environment interaction time: this considers both the model forward + # to get the action given the observation and the time taken into the environment + with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): + with torch.no_grad(): + # Sample an action given the observation received by the environment + action_logits, values, state = agent.module(next_obs, state=next_state) + dist = Categorical(logits=action_logits.unsqueeze(-2)) + action = dist.sample() + logprob = dist.log_prob(action) - # Single environment step - obs, reward, done, truncated, info = envs.step(action.cpu().numpy().reshape(envs.action_space.shape)) - done = np.logical_or(done, truncated) + # Single environment step + obs, reward, done, truncated, info = envs.step(action.cpu().numpy().reshape(envs.action_space.shape)) + done = np.logical_or(done, truncated) with device: obs = torch.tensor(obs, dtype=torch.float32).unsqueeze(0) # [1, N_envs, N_obs] @@ -279,13 +281,13 @@ def main(cfg: DictConfig): next_state = state if "final_info" in info: - for i, agent_final_info in enumerate(info["final_info"]): - if agent_final_info is not None and "episode" in agent_final_info: - fabric.print( - f"Rank-0: policy_step={policy_step}, reward_env_{i}={agent_final_info['episode']['r'][0]}" - ) - aggregator.update("Rewards/rew_avg", agent_final_info["episode"]["r"][0]) - aggregator.update("Game/ep_len_avg", agent_final_info["episode"]["l"][0]) + for i, agent_ep_info in enumerate(info["final_info"]): + if agent_ep_info is not None: + ep_rew = agent_ep_info["episode"]["r"] + ep_len = agent_ep_info["episode"]["l"] + aggregator.update("Rewards/rew_avg", ep_rew) + aggregator.update("Game/ep_len_avg", ep_len) + fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") # Estimate returns with GAE (https://arxiv.org/abs/1506.02438) with torch.no_grad(): @@ -309,8 +311,6 @@ def main(cfg: DictConfig): local_data = rb.buffer # Train the agent - - # Prepare data # 1. Split data into episodes (for every environment) episodes: List[TensorDictBase] = [] for env_id in range(cfg.env.num_envs): @@ -333,7 +333,10 @@ def main(cfg: DictConfig): else: sequences = episodes padded_sequences = pad_sequence(sequences, batch_first=False, return_mask=True) # [Seq_len, Num_seq, *] - train(fabric, agent, optimizer, padded_sequences, aggregator, cfg) + + with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): + train(fabric, agent, optimizer, padded_sequences, aggregator, cfg) + train_step += world_size if cfg.algo.anneal_lr: fabric.log("Info/learning_rate", scheduler.get_last_lr()[0], policy_step) @@ -354,13 +357,31 @@ def main(cfg: DictConfig): ) # Log metrics - if policy_step - last_log >= cfg.metric.log_every or cfg.dry_run: - last_log = policy_step + if policy_step - last_log >= cfg.metric.log_every or update == num_updates or cfg.dry_run: + # Sync distributed metrics metrics_dict = aggregator.compute() - fabric.log("Time/step_per_second", int(policy_step / (time.perf_counter() - start_time)), policy_step) fabric.log_dict(metrics_dict, policy_step) aggregator.reset() + # Sync distributed timers + timer_metrics = timer.compute() + fabric.log( + "Time/sps_train", + (train_step - last_train) / timer_metrics["Time/train_time"], + policy_step, + ) + fabric.log( + "Time/sps_env_interaction", + ((policy_step - last_log) / world_size * cfg.env.action_repeat) + / timer_metrics["Time/env_interaction_time"], + policy_step, + ) + timer.reset() + + # Reset counters + last_log = policy_step + last_train = train_step + # Checkpoint model if ( (cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every) diff --git a/sheeprl/algos/sac/sac.py b/sheeprl/algos/sac/sac.py index 04bc171b..c7f10aea 100644 --- a/sheeprl/algos/sac/sac.py +++ b/sheeprl/algos/sac/sac.py @@ -1,5 +1,4 @@ import os -import time import warnings from math import prod from typing import Optional @@ -17,7 +16,7 @@ from torch.optim import Optimizer from torch.utils.data.distributed import DistributedSampler from torch.utils.data.sampler import BatchSampler -from torchmetrics import MeanMetric +from torchmetrics import MeanMetric, SumMetric from sheeprl.algos.sac.agent import SACActor, SACAgent, SACCritic from sheeprl.algos.sac.loss import critic_loss, entropy_loss, policy_loss @@ -28,6 +27,8 @@ from sheeprl.utils.logger import create_tensorboard_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm +from sheeprl.utils.timer import timer +from sheeprl.utils.utils import print_config def train( @@ -83,12 +84,15 @@ def train( @register_algorithm() @hydra.main(version_base=None, config_path="../../configs", config_name="config") def main(cfg: DictConfig): + print_config(cfg) + # Initialize Fabric fabric = Fabric(callbacks=[CheckpointCallback()]) if not _is_using_cli(): fabric.launch() - rank = fabric.global_rank device = fabric.device + rank = fabric.global_rank + world_size = fabric.world_size fabric.seed_everything(cfg.seed) torch.backends.cudnn.deterministic = cfg.torch_deterministic @@ -152,21 +156,19 @@ def main(cfg: DictConfig): hydra.utils.instantiate(cfg.algo.alpha.optimizer, params=[agent.log_alpha]), ) - # Metrics - with device: - aggregator = MetricAggregator( - { - "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Time/step_per_second": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/alpha_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - } - ) + # Create a metric aggregator to log the metrics + aggregator = MetricAggregator( + { + "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/alpha_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + } + ).to(device) # Local data - buffer_size = cfg.buffer.size // int(cfg.env.num_envs * fabric.world_size) if not cfg.dry_run else 1 + buffer_size = cfg.buffer.size // int(cfg.env.num_envs * world_size) if not cfg.dry_run else 1 rb = ReplayBuffer( buffer_size, cfg.env.num_envs, @@ -177,25 +179,26 @@ def main(cfg: DictConfig): step_data = TensorDict({}, batch_size=[cfg.env.num_envs], device=device) # Global variables - policy_step = 0 last_log = 0 + last_train = 0 + train_step = 0 + policy_step = 0 last_checkpoint = 0 - start_time = time.perf_counter() - policy_steps_per_update = int(cfg.env.num_envs * fabric.world_size) + policy_steps_per_update = int(cfg.env.num_envs * world_size) num_updates = int(cfg.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 # Warning for log and checkpoint every if cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( - f"The log every parameter ({cfg.metric.log_every}) is not a multiple of the " + f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the metrics will be logged at the nearest greater multiple of the " "policy_steps_per_update value." ) if cfg.checkpoint.every % policy_steps_per_update != 0: warnings.warn( - f"The checkpoint every parameter ({cfg.checkpoint.every}) is not a multiple of the " + f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the checkpoint will be saved at the nearest greater multiple of the " "policy_steps_per_update value." @@ -205,26 +208,29 @@ def main(cfg: DictConfig): obs = torch.tensor(envs.reset(seed=cfg.seed)[0], dtype=torch.float32, device=device) # [N_envs, N_obs] for update in range(1, num_updates + 1): - if update < learning_starts: - actions = envs.action_space.sample() - else: - # Sample an action given the observation received by the environment - with torch.no_grad(): - actions, _ = actor.module(obs) - actions = actions.cpu().numpy() - next_obs, rewards, dones, truncated, infos = envs.step(actions) - dones = np.logical_or(dones, truncated) - - policy_step += cfg.env.num_envs * fabric.world_size + policy_step += cfg.env.num_envs * world_size + + # Measure environment interaction time: this considers both the model forward + # to get the action given the observation and the time taken into the environment + with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): + if update <= learning_starts: + actions = envs.action_space.sample() + else: + # Sample an action given the observation received by the environment + with torch.no_grad(): + actions, _ = actor.module(obs) + actions = actions.cpu().numpy() + next_obs, rewards, dones, truncated, infos = envs.step(actions) + dones = np.logical_or(dones, truncated) if "final_info" in infos: - for i, agent_final_info in enumerate(infos["final_info"]): - if agent_final_info is not None and "episode" in agent_final_info: - fabric.print( - f"Rank-0: policy_step={policy_step}, reward_env_{i}={agent_final_info['episode']['r'][0]}" - ) - aggregator.update("Rewards/rew_avg", agent_final_info["episode"]["r"][0]) - aggregator.update("Game/ep_len_avg", agent_final_info["episode"]["l"][0]) + for i, agent_ep_info in enumerate(infos["final_info"]): + if agent_ep_info is not None: + ep_rew = agent_ep_info["episode"]["r"] + ep_len = agent_ep_info["episode"]["l"] + aggregator.update("Rewards/rew_avg", ep_rew) + aggregator.update("Game/ep_len_avg", ep_len) + fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") # Save the real next observation real_next_obs = next_obs.copy() @@ -234,6 +240,7 @@ def main(cfg: DictConfig): real_next_obs[idx] = final_obs with device: + next_obs = torch.tensor(next_obs, dtype=torch.float32) real_next_obs = torch.tensor(real_next_obs, dtype=torch.float32) actions = torch.tensor(actions, dtype=torch.float32).view(cfg.env.num_envs, -1) rewards = torch.tensor(rewards, dtype=torch.float32).view(cfg.env.num_envs, -1) @@ -248,35 +255,38 @@ def main(cfg: DictConfig): rb.add(step_data.unsqueeze(0)) # next_obs becomes the new obs - obs = torch.tensor(next_obs, device=device) + obs = next_obs # Train the agent - if update >= learning_starts - 1: - training_steps = learning_starts if update == learning_starts - 1 else 1 - for _ in range(training_steps): - # We sample one time to reduce the communications between processes - sample = rb.sample( - cfg.algo.per_rank_gradient_steps * cfg.per_rank_batch_size, - sample_next_obs=cfg.buffer.sample_next_obs, - ) # [G*B, 1] - gathered_data = fabric.all_gather(sample.to_dict()) # [G*B, World, 1] - gathered_data = make_tensordict(gathered_data).view(-1) # [G*B*World] - if fabric.world_size > 1: - dist_sampler: DistributedSampler = DistributedSampler( - range(len(gathered_data)), - num_replicas=fabric.world_size, - rank=fabric.global_rank, - shuffle=True, - seed=cfg.seed, - drop_last=False, - ) - sampler: BatchSampler = BatchSampler( - sampler=dist_sampler, batch_size=cfg.per_rank_batch_size, drop_last=False - ) - else: - sampler = BatchSampler( - sampler=range(len(gathered_data)), batch_size=cfg.per_rank_batch_size, drop_last=False - ) + if update >= learning_starts: + training_steps = learning_starts if update == learning_starts else 1 + + # We sample one time to reduce the communications between processes + sample = rb.sample( + training_steps * cfg.algo.per_rank_gradient_steps * cfg.per_rank_batch_size, + sample_next_obs=cfg.buffer.sample_next_obs, + ) # [G*B, 1] + gathered_data = fabric.all_gather(sample.to_dict()) # [G*B, World, 1] + gathered_data = make_tensordict(gathered_data).view(-1) # [G*B*World] + if world_size > 1: + dist_sampler: DistributedSampler = DistributedSampler( + range(len(gathered_data)), + num_replicas=world_size, + rank=fabric.global_rank, + shuffle=True, + seed=cfg.seed, + drop_last=False, + ) + sampler: BatchSampler = BatchSampler( + sampler=dist_sampler, batch_size=cfg.per_rank_batch_size, drop_last=False + ) + else: + sampler = BatchSampler( + sampler=range(len(gathered_data)), batch_size=cfg.per_rank_batch_size, drop_last=False + ) + + # Start training + with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): for batch_idxes in sampler: train( fabric, @@ -290,12 +300,34 @@ def main(cfg: DictConfig): cfg, policy_steps_per_update, ) - aggregator.update("Time/step_per_second", int(policy_step / (time.perf_counter() - start_time))) - if policy_step - last_log >= cfg.metric.log_every or cfg.dry_run: - last_log = policy_step - fabric.log_dict(aggregator.compute(), policy_step) + train_step += world_size + + # Log metrics + if policy_step - last_log >= cfg.metric.log_every or update == num_updates or cfg.dry_run: + # Sync distributed metrics + metrics_dict = aggregator.compute() + fabric.log_dict(metrics_dict, policy_step) aggregator.reset() + # Sync distributed timers + timer_metrics = timer.compute() + fabric.log( + "Time/sps_train", + (train_step - last_train) / timer_metrics["Time/train_time"], + policy_step, + ) + fabric.log( + "Time/sps_env_interaction", + ((policy_step - last_log) / world_size * cfg.env.action_repeat) + / timer_metrics["Time/env_interaction_time"], + policy_step, + ) + timer.reset() + + # Reset counters + last_log = policy_step + last_train = train_step + # Checkpoint model if ( (cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every) diff --git a/sheeprl/algos/sac/sac_decoupled.py b/sheeprl/algos/sac/sac_decoupled.py index b0016b77..b4e21d29 100644 --- a/sheeprl/algos/sac/sac_decoupled.py +++ b/sheeprl/algos/sac/sac_decoupled.py @@ -18,7 +18,7 @@ from tensordict import TensorDict, make_tensordict from tensordict.tensordict import TensorDictBase from torch.utils.data.sampler import BatchSampler -from torchmetrics import MeanMetric +from torchmetrics import MeanMetric, SumMetric from sheeprl.algos.sac.agent import SACActor, SACAgent, SACCritic from sheeprl.algos.sac.sac import train @@ -28,10 +28,15 @@ from sheeprl.utils.env import make_env from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm +from sheeprl.utils.timer import timer +from sheeprl.utils.utils import print_config @torch.no_grad() def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_collective: TorchCollective): + print_config(cfg) + + # Initialize logger root_dir = ( os.path.join("logs", "runs", cfg.root_dir) if cfg.root_dir is not None @@ -100,14 +105,9 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co torch.nn.utils.convert_parameters.vector_to_parameters(flattened_parameters, actor.parameters()) # Metrics - with device: - aggregator = MetricAggregator( - { - "Rewards/rew_avg": MeanMetric(sync_on_compute=False), - "Game/ep_len_avg": MeanMetric(sync_on_compute=False), - "Time/step_per_second": MeanMetric(sync_on_compute=False), - } - ) + aggregator = MetricAggregator( + {"Rewards/rew_avg": MeanMetric(sync_on_compute=False), "Game/ep_len_avg": MeanMetric(sync_on_compute=False)} + ).to(device) # Local data buffer_size = cfg.buffer.size // cfg.env.num_envs if not cfg.dry_run else 1 @@ -121,10 +121,10 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co step_data = TensorDict({}, batch_size=[cfg.env.num_envs], device=device) # Global variables - policy_step = 0 last_log = 0 + policy_step = 0 last_checkpoint = 0 - start_time = time.perf_counter() + first_info_sent = False policy_steps_per_update = int(cfg.env.num_envs) num_updates = int(cfg.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 @@ -132,14 +132,14 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co # Warning for log and checkpoint every if cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( - f"The log every parameter ({cfg.metric.log_every}) is not a multiple of the " + f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the metrics will be logged at the nearest greater multiple of the " "policy_steps_per_update value." ) if cfg.checkpoint.every % policy_steps_per_update != 0: warnings.warn( - f"The checkpoint every parameter ({cfg.checkpoint.every}) is not a multiple of the " + f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the checkpoint will be saved at the nearest greater multiple of the " "policy_steps_per_update value." @@ -150,26 +150,29 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co obs = torch.tensor(envs.reset(seed=cfg.seed)[0], dtype=torch.float32) # [N_envs, N_obs] for update in range(1, num_updates + 1): - if update < learning_starts: - actions = envs.action_space.sample() - else: - # Sample an action given the observation received by the environment - with torch.no_grad(): - actions, _ = actor(obs) - actions = actions.cpu().numpy() - next_obs, rewards, dones, truncated, infos = envs.step(actions) - dones = np.logical_or(dones, truncated) - policy_step += cfg.env.num_envs + # Measure environment interaction time: this considers both the model forward + # to get the action given the observation and the time taken into the environment + with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): + if update <= learning_starts: + actions = envs.action_space.sample() + else: + # Sample an action given the observation received by the environment + with torch.no_grad(): + actions, _ = actor(obs) + actions = actions.cpu().numpy() + next_obs, rewards, dones, truncated, infos = envs.step(actions) + dones = np.logical_or(dones, truncated) + if "final_info" in infos: - for i, agent_final_info in enumerate(infos["final_info"]): - if agent_final_info is not None and "episode" in agent_final_info: - fabric.print( - f"Rank-0: policy_step={policy_step}, reward_env_{i}={agent_final_info['episode']['r'][0]}" - ) - aggregator.update("Rewards/rew_avg", agent_final_info["episode"]["r"][0]) - aggregator.update("Game/ep_len_avg", agent_final_info["episode"]["l"][0]) + for i, agent_ep_info in enumerate(infos["final_info"]): + if agent_ep_info is not None: + ep_rew = agent_ep_info["episode"]["r"] + ep_len = agent_ep_info["episode"]["l"] + aggregator.update("Rewards/rew_avg", ep_rew) + aggregator.update("Game/ep_len_avg", ep_len) + fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") # Save the real next observation real_next_obs = next_obs.copy() @@ -197,6 +200,14 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co # Send data to the training agents if update >= learning_starts: + # Send local info to the trainers + if not first_info_sent: + world_collective.broadcast_object_list( + [{"update": update, "last_log": last_log, "last_checkpoint": last_checkpoint}], src=0 + ) + first_info_sent = True + + # Sample data to be sent to the trainers training_steps = learning_starts if update == learning_starts else 1 chunks = rb.sample( training_steps * cfg.algo.per_rank_gradient_steps * cfg.per_rank_batch_size * (fabric.world_size - 1), @@ -204,26 +215,38 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co ).split(training_steps * cfg.algo.per_rank_gradient_steps * cfg.per_rank_batch_size) world_collective.scatter_object_list([None], [None] + chunks, src=0) - # Gather metrics from the trainers to be plotted - if policy_step - last_log >= cfg.metric.log_every or cfg.dry_run: - metrics = [None] - player_trainer_collective.broadcast_object_list(metrics, src=1) - # Wait the trainers to finish player_trainer_collective.broadcast(flattened_parameters, src=1) # Convert back the parameters torch.nn.utils.convert_parameters.vector_to_parameters(flattened_parameters, actor.parameters()) + # Logs trainers-only metrics if policy_step - last_log >= cfg.metric.log_every or cfg.dry_run: + # Gather metrics from the trainers + metrics = [None] + player_trainer_collective.broadcast_object_list(metrics, src=1) + + # Log metrics fabric.log_dict(metrics[0], policy_step) - aggregator.update("Time/step_per_second", int(policy_step / (time.perf_counter() - start_time))) + # Logs player-only metrics if policy_step - last_log >= cfg.metric.log_every or cfg.dry_run: - last_log = policy_step fabric.log_dict(aggregator.compute(), policy_step) aggregator.reset() + # Sync timers + timer_metrics = timer.compute() + fabric.log( + "Time/sps_env_interaction", + ((policy_step - last_log) * cfg.env.action_repeat) / timer_metrics["Time/env_interaction_time"], + policy_step, + ) + timer.reset() + + # Reset counters + last_log = policy_step + # Checkpoint model if ( update >= learning_starts # otherwise the processes end up deadlocked @@ -273,7 +296,7 @@ def trainer( optimization_pg: CollectibleGroup, ): global_rank = world_collective.rank - global_rank - 1 + group_world_size = world_collective.world_size - 1 # Receive (possibly updated, by the make_dict_env method for example) cfg from the player data = [None] @@ -328,28 +351,29 @@ def trainer( ) # Metrics - with fabric.device: - aggregator = MetricAggregator( - { - "Loss/value_loss": MeanMetric( - sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg - ), - "Loss/policy_loss": MeanMetric( - sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg - ), - "Loss/alpha_loss": MeanMetric( - sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg - ), - } - ) + aggregator = MetricAggregator( + { + "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg), + "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg), + "Loss/alpha_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg), + } + ).to(device) + + # Receive data from player reagrding the: + # * update + # * last_log + # * last_checkpoint + data = [None] + world_collective.broadcast_object_list(data, src=0) + update = data[0]["update"] + last_log = data[0]["last_log"] + last_checkpoint = data[0]["last_checkpoint"] # Start training + train_step = 0 + last_train = 0 policy_steps_per_update = cfg.env.num_envs - learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 - update = learning_starts policy_step = update * policy_steps_per_update - last_log = (policy_step // cfg.metric.log_every - 1) * cfg.metric.log_every - last_checkpoint = 0 while True: # Wait for data data = [None] @@ -369,35 +393,50 @@ def trainer( return data = make_tensordict(data, device=device) sampler = BatchSampler(range(len(data)), batch_size=cfg.per_rank_batch_size, drop_last=False) - for batch_idxes in sampler: - train( - fabric, - agent, - actor_optimizer, - qf_optimizer, - alpha_optimizer, - data[batch_idxes], - aggregator, - update, - cfg, - policy_steps_per_update, - group=optimization_pg, + + # Start training + with timer( + "Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg) + ): + for batch_idxes in sampler: + train( + fabric, + agent, + actor_optimizer, + qf_optimizer, + alpha_optimizer, + data[batch_idxes], + aggregator, + update, + cfg, + policy_steps_per_update, + group=optimization_pg, + ) + train_step += group_world_size + + if global_rank == 1: + player_trainer_collective.broadcast( + torch.nn.utils.convert_parameters.parameters_to_vector(actor.parameters()), src=1 ) - # Send updated weights to the player if policy_step - last_log >= cfg.metric.log_every or cfg.dry_run: - last_log = policy_step + # Sync distributed metrics metrics = aggregator.compute() aggregator.reset() + + # Sync distributed timers + timers = timer.compute() + metrics.update({"Time/sps_train": (train_step - last_train) / timers["Time/train_time"]}) + timer.reset() + if global_rank == 1: player_trainer_collective.broadcast_object_list( [metrics], src=1 ) # Broadcast metrics: fake send with object list between rank-0 and rank-1 - if global_rank == 1: - player_trainer_collective.broadcast( - torch.nn.utils.convert_parameters.parameters_to_vector(actor.parameters()), src=1 - ) + # Reset counters + last_log = policy_step + last_train = train_step # Checkpoint model on rank-0: send it everything if (cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every) or cfg.dry_run: @@ -411,6 +450,8 @@ def trainer( "update": update, } fabric.call("on_checkpoint_trainer", player_trainer_collective=player_trainer_collective, state=state) + + # Update counters update += 1 policy_step += policy_steps_per_update diff --git a/sheeprl/algos/sac_ae/sac_ae.py b/sheeprl/algos/sac_ae/sac_ae.py index df66d59d..88a6cba2 100644 --- a/sheeprl/algos/sac_ae/sac_ae.py +++ b/sheeprl/algos/sac_ae/sac_ae.py @@ -1,6 +1,5 @@ import copy import os -import time import warnings from math import prod from typing import Optional, Union @@ -22,7 +21,7 @@ from torch.optim import Optimizer from torch.utils.data.distributed import DistributedSampler from torch.utils.data.sampler import BatchSampler -from torchmetrics import MeanMetric +from torchmetrics import MeanMetric, SumMetric from sheeprl.algos.sac.loss import critic_loss, entropy_loss, policy_loss from sheeprl.algos.sac_ae.agent import ( @@ -43,6 +42,8 @@ from sheeprl.utils.logger import create_tensorboard_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm +from sheeprl.utils.timer import timer +from sheeprl.utils.utils import print_config def train( @@ -133,6 +134,8 @@ def train( @register_algorithm() @hydra.main(version_base=None, config_path="../../configs", config_name="config") def main(cfg: DictConfig): + print_config(cfg) + if "minedojo" in cfg.env.env._target_.lower(): raise ValueError( "MineDojo is not currently supported by SAC-AE agent, since it does not take " @@ -164,8 +167,9 @@ def main(cfg: DictConfig): fabric = Fabric(strategy=strategy, callbacks=[CheckpointCallback()]) if not _is_using_cli(): fabric.launch() - rank = fabric.global_rank device = fabric.device + rank = fabric.global_rank + world_size = fabric.world_size fabric.seed_everything(cfg.seed) torch.backends.cudnn.deterministic = cfg.torch_deterministic @@ -323,18 +327,16 @@ def main(cfg: DictConfig): ) # Metrics - with device: - aggregator = MetricAggregator( - { - "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Time/step_per_second": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/alpha_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/reconstruction_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - } - ) + aggregator = MetricAggregator( + { + "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/alpha_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/reconstruction_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + } + ).to(device) # Local data buffer_size = cfg.buffer.size // int(cfg.env.num_envs * fabric.world_size) if not cfg.dry_run else 1 @@ -349,10 +351,11 @@ def main(cfg: DictConfig): step_data = TensorDict({}, batch_size=[cfg.env.num_envs], device=fabric.device if cfg.buffer.memmap else "cpu") # Global variables - policy_step = 0 last_log = 0 + last_train = 0 + train_step = 0 + policy_step = 0 last_checkpoint = 0 - start_time = time.time() policy_steps_per_update = int(cfg.env.num_envs * fabric.world_size) num_updates = int(cfg.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 @@ -360,14 +363,14 @@ def main(cfg: DictConfig): # Warning for log and checkpoint every if cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( - f"The log every parameter ({cfg.metric.log_every}) is not a multiple of the " + f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the metrics will be logged at the nearest greater multiple of the " "policy_steps_per_update value." ) if cfg.checkpoint.every % policy_steps_per_update != 0: warnings.warn( - f"The checkpoint every parameter ({cfg.checkpoint.every}) is not a multiple of the " + f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the checkpoint will be saved at the nearest greater multiple of the " "policy_steps_per_update value." @@ -386,26 +389,29 @@ def main(cfg: DictConfig): obs[k] = torch_obs for update in range(1, num_updates + 1): - if update < learning_starts: - actions = envs.action_space.sample() - else: - with torch.no_grad(): - normalized_obs = {k: v / 255 if k in cfg.cnn_keys.encoder else v for k, v in obs.items()} - actions, _ = actor.module(normalized_obs) - actions = actions.cpu().numpy() - o, rewards, dones, truncated, infos = envs.step(actions) - dones = np.logical_or(dones, truncated) - policy_step += cfg.env.num_envs * fabric.world_size + # Measure environment interaction time: this considers both the model forward + # to get the action given the observation and the time taken into the environment + with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): + if update < learning_starts: + actions = envs.action_space.sample() + else: + with torch.no_grad(): + normalized_obs = {k: v / 255 if k in cfg.cnn_keys.encoder else v for k, v in obs.items()} + actions, _ = actor.module(normalized_obs) + actions = actions.cpu().numpy() + o, rewards, dones, truncated, infos = envs.step(actions) + dones = np.logical_or(dones, truncated) + if "final_info" in infos: - for i, agent_final_info in enumerate(infos["final_info"]): - if agent_final_info is not None and "episode" in agent_final_info: - fabric.print( - f"Rank-0: policy_step={policy_step}, reward_env_{i}={agent_final_info['episode']['r'][0]}" - ) - aggregator.update("Rewards/rew_avg", agent_final_info["episode"]["r"][0]) - aggregator.update("Game/ep_len_avg", agent_final_info["episode"]["l"][0]) + for i, agent_ep_info in enumerate(infos["final_info"]): + if agent_ep_info is not None: + ep_rew = agent_ep_info["episode"]["r"] + ep_len = agent_ep_info["episode"]["l"] + aggregator.update("Rewards/rew_avg", ep_rew) + aggregator.update("Game/ep_len_avg", ep_len) + fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") # Save the real next observation real_next_obs = copy.deepcopy(o) @@ -447,30 +453,33 @@ def main(cfg: DictConfig): # Train the agent if update >= learning_starts - 1: training_steps = learning_starts if update == learning_starts - 1 else 1 - for _ in range(training_steps): - # We sample one time to reduce the communications between processes - sample = rb.sample( - cfg.algo.per_rank_gradient_steps * cfg.per_rank_batch_size, - sample_next_obs=cfg.buffer.sample_next_obs, - ) # [G*B, 1] - gathered_data = fabric.all_gather(sample.to_dict()) # [G*B, World, 1] - gathered_data = make_tensordict(gathered_data).view(-1) # [G*B*World] - if fabric.world_size > 1: - dist_sampler: DistributedSampler = DistributedSampler( - range(len(gathered_data)), - num_replicas=fabric.world_size, - rank=fabric.global_rank, - shuffle=True, - seed=cfg.seed, - drop_last=False, - ) - sampler: BatchSampler = BatchSampler( - sampler=dist_sampler, batch_size=cfg.per_rank_batch_size, drop_last=False - ) - else: - sampler = BatchSampler( - sampler=range(len(gathered_data)), batch_size=cfg.per_rank_batch_size, drop_last=False - ) + + # We sample one time to reduce the communications between processes + sample = rb.sample( + training_steps * cfg.algo.per_rank_gradient_steps * cfg.per_rank_batch_size, + sample_next_obs=cfg.buffer.sample_next_obs, + ) # [G*B, 1] + gathered_data = fabric.all_gather(sample.to_dict()) # [G*B, World, 1] + gathered_data = make_tensordict(gathered_data).view(-1) # [G*B*World] + if fabric.world_size > 1: + dist_sampler: DistributedSampler = DistributedSampler( + range(len(gathered_data)), + num_replicas=fabric.world_size, + rank=fabric.global_rank, + shuffle=True, + seed=cfg.seed, + drop_last=False, + ) + sampler: BatchSampler = BatchSampler( + sampler=dist_sampler, batch_size=cfg.per_rank_batch_size, drop_last=False + ) + else: + sampler = BatchSampler( + sampler=range(len(gathered_data)), batch_size=cfg.per_rank_batch_size, drop_last=False + ) + + # Start training + with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): for batch_idxes in sampler: train( fabric, @@ -488,12 +497,34 @@ def main(cfg: DictConfig): cfg, policy_steps_per_update, ) - aggregator.update("Time/step_per_second", int(policy_step / (time.time() - start_time))) - if policy_step - last_log >= cfg.metric.log_every or cfg.dry_run: - last_log = policy_step - fabric.log_dict(aggregator.compute(), policy_step) + train_step += world_size + + # Log metrics + if policy_step - last_log >= cfg.metric.log_every or update == num_updates or cfg.dry_run: + # Sync distributed metrics + metrics_dict = aggregator.compute() + fabric.log_dict(metrics_dict, policy_step) aggregator.reset() + # Sync distributed timers + timer_metrics = timer.compute() + fabric.log( + "Time/sps_train", + (train_step - last_train) / timer_metrics["Time/train_time"], + policy_step, + ) + fabric.log( + "Time/sps_env_interaction", + ((policy_step - last_log) / world_size * cfg.env.action_repeat) + / timer_metrics["Time/env_interaction_time"], + policy_step, + ) + timer.reset() + + # Reset counters + last_log = policy_step + last_train = train_step + # Checkpoint model if (cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every) or cfg.dry_run: last_checkpoint >= policy_step diff --git a/sheeprl/configs/hydra/default.yaml b/sheeprl/configs/hydra/default.yaml index b0627fbe..a7eed38a 100644 --- a/sheeprl/configs/hydra/default.yaml +++ b/sheeprl/configs/hydra/default.yaml @@ -1,2 +1,4 @@ run: dir: logs/runs/${root_dir}/${run_name} +job: + chdir: False diff --git a/sheeprl/utils/metric.py b/sheeprl/utils/metric.py index dbf98e00..c1b0b6f2 100644 --- a/sheeprl/utils/metric.py +++ b/sheeprl/utils/metric.py @@ -1,7 +1,11 @@ -from collections import deque -from typing import Any, Dict, Optional, Union +import warnings +from typing import Any, Dict, List, Optional, Union import torch +import torch.distributed as dist +from lightning.fabric.utilities.distributed import _distributed_available +from torch import Tensor +from torch.distributed.distributed_c10d import ProcessGroup from torchmetrics import Metric @@ -64,7 +68,7 @@ def reset(self): for metric in self.metrics.values(): metric.reset() - def to(self, device: Union[str, torch.device] = "cpu") -> None: + def to(self, device: Union[str, torch.device] = "cpu") -> "MetricAggregator": """Move all metrics to the given device Args: device (Union[str, torch.device], optional): Device to move the metrics to. Defaults to "cpu". @@ -72,9 +76,10 @@ def to(self, device: Union[str, torch.device] = "cpu") -> None: if self.metrics: for k, v in self.metrics.items(): self.metrics[k] = v.to(device) + return self @torch.no_grad() - def compute(self) -> Dict[str, torch.Tensor]: + def compute(self) -> Dict[str, List]: """Reduce the metrics to a single value Returns: Reduced metrics @@ -83,54 +88,79 @@ def compute(self) -> Dict[str, torch.Tensor]: if self.metrics: for k, v in self.metrics.items(): reduced = v.compute() - if v._update_called: - reduced_metrics[k] = reduced.tolist() + is_tensor = torch.is_tensor(reduced) + if is_tensor and reduced.numel() == 1: + reduced_metrics[k] = reduced.item() + else: + if not is_tensor: + warnings.warn( + f"The reduced metric {k} is not a scalar tensor: type={type(reduced)}. " + "This may create problems during the logging phase.", + category=RuntimeWarning, + ) + else: + warnings.warn( + f"The reduced metric {k} is not a scalar: size={v.size()}. " + "This may create problems during the logging phase.", + category=RuntimeWarning, + ) + reduced_metrics[k] = reduced return reduced_metrics -class MovingAverageMetric(Metric): - """Metric for tracking moving average of a value. - - Args: - name (str): Name of the metric - window_size (int): Window size for computing moving average - device (str): Device to store the metric - """ - - def __init__(self, name: str, window_size: int = 100, device: str = "cpu") -> None: - super().__init__(sync_on_compute=False) - self.window_size = window_size - self._values = deque(maxlen=window_size) - self._sum = torch.tensor(0.0, device=self._device) - - def update(self, value: Union[torch.Tensor, float]) -> None: - """Update the moving average with a new value. +class RankIndependentMetricAggregator: + def __init__( + self, + metrics: Union[Dict[str, Metric], MetricAggregator], + process_group: Optional[ProcessGroup] = None, + ) -> None: + """Rank-independent MetricAggregator. + This metric is useful when one wants to maintain per-rank-independent metrics of some quantities, + while still being able to broadcast them to all the processes in a `torch.distributed` group. Args: - value (Union[torch.Tensor, float]): New value to update the moving average + metrics (Sequence[str]): the metrics. + process_group (Optional[ProcessGroup], optional): the distributed process group. + Defaults to None. """ - if isinstance(value, torch.Tensor): - value = value.item() - if len(self._values) == self.window_size: - self._sum -= self._values.popleft() - self._sum += value - self._values.append(value) + super().__init__() + self._aggregator = metrics + if isinstance(metrics, dict): + self._aggregator = MetricAggregator(metrics) + for m in self._aggregator.metrics.values(): + m._to_sync = False + m.sync_on_compute = False + self._process_group = process_group if process_group is not None else torch.distributed.group.WORLD + self._distributed_available = _distributed_available() + self._world_size = dist.get_world_size(self._process_group) if self._distributed_available else 1 + + def update(self, name: str, value: Union[float, Tensor]) -> None: + self._aggregator.update(name, value) - def compute(self) -> Dict: - """Computes the moving average. + @torch.no_grad() + def compute(self) -> List[Dict[str, Tensor]]: + """Compute the means, one for every metric. The metrics are first broadcasted Returns: - Dict: Dictionary with the moving average + List[Dict[str, List]]: the computed metrics, broadcasted from and to every processes. + The list of the data returned is equal to the number of processes in the process group. + """ + computed_metrics = self._aggregator.compute() + if not self._distributed_available: + return [computed_metrics] + gathered_data = [None for _ in range(self._world_size)] + dist.all_gather_object(gathered_data, computed_metrics, group=self._process_group) + return gathered_data + + def to(self, device: Union[str, torch.device] = "cpu") -> "RankIndependentMetricAggregator": + """Move all metrics to the given device + + Args: + device (Union[str, torch.device], optional): Device to move the metrics to. Defaults to "cpu". """ - if len(self._values) == 0: - return None - average = self._sum / len(self._values) - std = torch.std(torch.tensor(self._values, device=self._device)) - torch.max(torch.tensor(self._values, device=self._device)) - torch.min(torch.tensor(self._values, device=self._device)) - return average, std.item() + self._aggregator.to(device) + return self def reset(self) -> None: - """Resets the moving average.""" - self._values.clear() - self._sum = torch.tensor(0.0, device=self._device) + """Reset the internal state of the metrics""" + self._aggregator.reset() diff --git a/sheeprl/utils/timer.py b/sheeprl/utils/timer.py new file mode 100644 index 00000000..dcbd3d6a --- /dev/null +++ b/sheeprl/utils/timer.py @@ -0,0 +1,82 @@ +# timer.py + +import time +from contextlib import ContextDecorator +from typing import Dict, Optional, Union + +import torch +from torchmetrics import Metric, SumMetric + + +class TimerError(Exception): + """A custom exception used to report errors in use of timer class""" + + +class timer(ContextDecorator): + """A timer class to measure the time of a code block.""" + + disabled: bool = False + timers: Dict[str, Metric] = {} + _start_time: Optional[float] = None + + def __init__(self, name: str, metric: Optional[Metric] = None) -> None: + """Add timer to dict of timers after initialization""" + self.name = name + if not timer.disabled and self.name is not None and self.name not in self.timers.keys(): + self.timers.setdefault(self.name, metric if metric is not None else SumMetric()) + + def start(self) -> None: + """Start a new timer""" + if self._start_time is not None: + raise TimerError("timer is running. Use .stop() to stop it") + + self._start_time = time.perf_counter() + + def stop(self) -> float: + """Stop the timer, and report the elapsed time""" + if self._start_time is None: + raise TimerError("timer is not running. Use .start() to start it") + + # Calculate elapsed time + elapsed_time = time.perf_counter() - self._start_time + self._start_time = None + + # Report elapsed time + if self.name: + self.timers[self.name].update(elapsed_time) + + return elapsed_time + + @classmethod + def to(cls, device: Union[str, torch.device] = "cpu") -> None: + """Create a new timer on a different device""" + if cls.timers: + for k, v in cls.timers.items(): + cls.timers[k] = v.to(device) + + @classmethod + def reset(cls) -> None: + """Reset all timers""" + for timer in cls.timers.values(): + timer.reset() + cls._start_time = None + + @classmethod + def compute(cls) -> Dict[str, torch.Tensor]: + """Reduce the timers to a single value""" + reduced_timers = {} + if cls.timers: + for k, v in cls.timers.items(): + reduced_timers[k] = v.compute().item() + return reduced_timers + + def __enter__(self): + """Start a new timer as a context manager""" + if not timer.disabled: + self.start() + return self + + def __exit__(self, *exc_info): + """Stop the context manager timer""" + if not timer.disabled: + self.stop() diff --git a/sheeprl/utils/utils.py b/sheeprl/utils/utils.py index 24bb1033..aeb72d04 100644 --- a/sheeprl/utils/utils.py +++ b/sheeprl/utils/utils.py @@ -1,7 +1,12 @@ -from typing import Optional, Tuple +import os +from typing import Optional, Sequence, Tuple, Union +import rich.syntax +import rich.tree import torch import torch.nn as nn +from omegaconf import DictConfig, OmegaConf +from pytorch_lightning.utilities import rank_zero_only from torch import Tensor @@ -131,3 +136,35 @@ def symlog(x: Tensor) -> Tensor: def symexp(x: Tensor) -> Tensor: return torch.sign(x) * (torch.exp(torch.abs(x)) - 1) + + +@rank_zero_only +def print_config( + config: DictConfig, + fields: Sequence[str] = ("algo", "buffer", "checkpoint", "env", "exp", "hydra", "metric", "optim"), + resolve: bool = True, + cfg_save_path: Optional[Union[str, os.PathLike]] = None, +) -> None: + """Prints content of DictConfig using Rich library and its tree structure. + + Args: + config: Configuration composed by Hydra. + fields: Determines which main fields from config will + be printed and in what order. + resolve: Whether to resolve reference fields of DictConfig. + """ + style = "dim" + tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) + + for field in fields: + branch = tree.add(field, style=style, guide_style=style) + config_section = config.get(field) + branch_content = str(config_section) + if isinstance(config_section, DictConfig): + branch_content = OmegaConf.to_yaml(config_section, resolve=resolve) + branch.add(rich.syntax.Syntax(branch_content, "yaml")) + + rich.print(tree) + if cfg_save_path is not None: + with open(os.path.join(os.getcwd(), "config_tree.txt"), "w") as fp: + rich.print(tree, file=fp)