diff --git a/sheeprl/algos/dreamer_v1/dreamer_v1.py b/sheeprl/algos/dreamer_v1/dreamer_v1.py
index 1938354c..1eea1737 100644
--- a/sheeprl/algos/dreamer_v1/dreamer_v1.py
+++ b/sheeprl/algos/dreamer_v1/dreamer_v1.py
@@ -581,7 +581,7 @@ def main(cfg: DictConfig):
 
         # Measure environment interaction time: this considers both the model forward
         # to get the action given the observation and the time taken into the environment
-        with timer("Time/env_interaction_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
+        with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
             # Sample an action given the observation received by the environment
             if update <= learning_starts and cfg.checkpoint.resume_from is None and "minedojo" not in cfg.env.id:
                 real_actions = actions = np.array(envs.action_space.sample())
@@ -692,7 +692,7 @@ def main(cfg: DictConfig):
                     aggregator,
                     cfg,
                 )
-            train_step += 1
+            train_step += world_size
             updates_before_training = cfg.algo.train_every // policy_steps_per_update
             if cfg.algo.player.expl_decay:
                 expl_decay_steps += 1
diff --git a/sheeprl/algos/dreamer_v2/dreamer_v2.py b/sheeprl/algos/dreamer_v2/dreamer_v2.py
index 026307f6..0d7217eb 100644
--- a/sheeprl/algos/dreamer_v2/dreamer_v2.py
+++ b/sheeprl/algos/dreamer_v2/dreamer_v2.py
@@ -623,7 +623,7 @@ def main(cfg: DictConfig):
 
         # Measure environment interaction time: this considers both the model forward
         # to get the action given the observation and the time taken into the environment
-        with timer("Time/env_interaction_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
+        with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
             # Sample an action given the observation received by the environment
             if (
                 update <= learning_starts
@@ -770,7 +770,7 @@ def main(cfg: DictConfig):
                     actions_dim,
                 )
                 per_rank_gradient_steps += 1
-            train_step += 1
+            train_step += world_size
             updates_before_training = cfg.algo.train_every // policy_steps_per_update
             if cfg.algo.player.expl_decay:
                 expl_decay_steps += 1
diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py
index f3c85e3f..e1abe5c4 100644
--- a/sheeprl/algos/dreamer_v3/dreamer_v3.py
+++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py
@@ -566,7 +566,7 @@ def main(cfg: DictConfig):
 
         # Measure environment interaction time: this considers both the model forward
         # to get the action given the observation and the time taken into the environment
-        with timer("Time/env_interaction_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
+        with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
             # Sample an action given the observation received by the environment
             if (
                 update <= learning_starts
@@ -708,7 +708,7 @@ def main(cfg: DictConfig):
                     moments,
                 )
                 per_rank_gradient_steps += 1
-            train_step += 1
+            train_step += world_size
             updates_before_training = cfg.algo.train_every // policy_steps_per_update
             if cfg.algo.player.expl_decay:
                 expl_decay_steps += 1
diff --git a/sheeprl/algos/droq/droq.py b/sheeprl/algos/droq/droq.py
index 3ad3e221..fdd263cd 100644
--- a/sheeprl/algos/droq/droq.py
+++ b/sheeprl/algos/droq/droq.py
@@ -264,7 +264,7 @@ def main(cfg: DictConfig):
 
         # Measure environment interaction time: this considers both the model forward
         # to get the action given the observation and the time taken into the environment
-        with timer("Time/env_interaction_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
+        with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
             with torch.no_grad():
                 actions, _ = actor.module(obs)
                 actions = actions.cpu().numpy()
@@ -308,7 +308,7 @@ def main(cfg: DictConfig):
         # Train the agent
         if update > learning_starts:
             train(fabric, agent, actor_optimizer, qf_optimizer, alpha_optimizer, rb, aggregator, cfg)
-            train_step += 1
+            train_step += world_size
 
         # Log metrics
         if policy_step - last_log >= cfg.metric.log_every or update == num_updates or cfg.dry_run:
diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1.py b/sheeprl/algos/p2e_dv1/p2e_dv1.py
index 2ab82df0..f987d558 100644
--- a/sheeprl/algos/p2e_dv1/p2e_dv1.py
+++ b/sheeprl/algos/p2e_dv1/p2e_dv1.py
@@ -650,7 +650,7 @@ def main(cfg: DictConfig):
 
         # Measure environment interaction time: this considers both the model forward
         # to get the action given the observation and the time taken into the environment
-        with timer("Time/env_interaction_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
+        with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
             # Sample an action given the observation received by the environment
             if update <= learning_starts and cfg.checkpoint.resume_from is None and "minedojo" not in cfg.env.id:
                 real_actions = actions = np.array(envs.action_space.sample())
@@ -771,7 +771,7 @@ def main(cfg: DictConfig):
                     critic_exploration_optimizer=critic_exploration_optimizer,
                     is_exploring=is_exploring,
                 )
-            train_step += 1
+            train_step += world_size
             updates_before_training = cfg.algo.train_every // policy_steps_per_update
             if cfg.algo.player.expl_decay:
                 expl_decay_steps += 1
diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2.py b/sheeprl/algos/p2e_dv2/p2e_dv2.py
index e8a76b61..d1d48a0b 100644
--- a/sheeprl/algos/p2e_dv2/p2e_dv2.py
+++ b/sheeprl/algos/p2e_dv2/p2e_dv2.py
@@ -802,7 +802,7 @@ def main(cfg: DictConfig):
 
         # Measure environment interaction time: this considers both the model forward
         # to get the action given the observation and the time taken into the environment
-        with timer("Time/env_interaction_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
+        with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
             # Sample an action given the observation received by the environment
             if update <= learning_starts and cfg.checkpoint.resume_from is None and "minedojo" not in cfg.env.id:
                 real_actions = actions = np.array(envs.action_space.sample())
@@ -958,7 +958,7 @@ def main(cfg: DictConfig):
                     actions_dim=actions_dim,
                     is_exploring=is_exploring,
                 )
-            train_step += 1
+            train_step += world_size
             updates_before_training = cfg.algo.train_every // policy_steps_per_update
             if cfg.algo.player.expl_decay:
                 expl_decay_steps += 1
diff --git a/sheeprl/algos/ppo/ppo.py b/sheeprl/algos/ppo/ppo.py
index 4559479e..632e8ba3 100644
--- a/sheeprl/algos/ppo/ppo.py
+++ b/sheeprl/algos/ppo/ppo.py
@@ -221,6 +221,7 @@ def main(cfg: DictConfig):
     # Global variables
     last_log = 0
     last_train = 0
+    train_step = 0
     policy_step = 0
     last_checkpoint = 0
     policy_steps_per_update = int(cfg.env.num_envs * cfg.algo.rollout_steps * world_size)
@@ -268,7 +269,7 @@ def main(cfg: DictConfig):
 
             # Measure environment interaction time: this considers both the model forward
             # to get the action given the observation and the time taken into the environment
-            with timer("Time/env_interaction_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
+            with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
                 with torch.no_grad():
                     # Sample an action given the observation received by the environment
                     normalized_obs = {
@@ -358,6 +359,7 @@ def main(cfg: DictConfig):
 
         with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
             train(fabric, agent, optimizer, gathered_data, aggregator, cfg)
+        train_step += world_size
 
         if cfg.algo.anneal_lr:
             fabric.log("Info/learning_rate", scheduler.get_last_lr()[0], policy_step)
@@ -388,7 +390,7 @@ def main(cfg: DictConfig):
             timer_metrics = timer.compute()
             fabric.log(
                 "Time/sps_train",
-                (update - last_train) / timer_metrics["Time/train_time"],
+                (train_step - last_train) / timer_metrics["Time/train_time"],
                 policy_step,
             )
             fabric.log(
@@ -400,8 +402,8 @@ def main(cfg: DictConfig):
             timer.reset()
 
             # Reset counters
-            last_train = update
             last_log = policy_step
+            last_train = train_step
 
         # Checkpoint model
         if (
diff --git a/sheeprl/algos/ppo/ppo_decoupled.py b/sheeprl/algos/ppo/ppo_decoupled.py
index f8418161..862ce9dc 100644
--- a/sheeprl/algos/ppo/ppo_decoupled.py
+++ b/sheeprl/algos/ppo/ppo_decoupled.py
@@ -349,6 +349,7 @@ def trainer(
     optimization_pg: CollectibleGroup,
 ):
     global_rank = world_collective.rank
+    group_world_size = world_collective.world_size - 1
 
     # Receive (possibly updated, by the make_dict_env method for example) cfg from the player
    data = [None]
@@ -406,6 +407,7 @@ def trainer(
     update = 0
     last_log = 0
     last_train = 0
+    train_step = 0
     policy_step = 0
     last_checkpoint = 0
     initial_ent_coef = copy.deepcopy(cfg.algo.ent_coef)
@@ -428,6 +430,7 @@ def trainer(
             return
         data = make_tensordict(data, device=device)
         update += 1
+        train_step += group_world_size
         policy_step += cfg.env.num_envs * cfg.algo.rollout_steps
 
         # Prepare sampler
@@ -504,7 +507,7 @@ def trainer(
 
             # Sync distributed timers
             timers = timer.compute()
-            metrics.update({"Time/sps_train": (update - last_train) / timers["Time/train_time"]})
+            metrics.update({"Time/sps_train": (train_step - last_train) / timers["Time/train_time"]})
             timer.reset()
 
             # Send metrics to the player
@@ -520,8 +523,8 @@ def trainer(
             )  # Broadcast metrics: fake send with object list between rank-0 and rank-1
 
             # Reset counters
-            last_train = update
             last_log = policy_step
+            last_train = train_step
 
         if cfg.algo.anneal_lr:
             scheduler.step()
diff --git a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py
index 352e3fa4..a37fd05b 100644
--- a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py
+++ b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py
@@ -198,6 +198,7 @@ def main(cfg: DictConfig):
     # Global variables
     last_log = 0
     last_train = 0
+    train_step = 0
     policy_step = 0
     last_checkpoint = 0
     policy_steps_per_update = int(cfg.env.num_envs * cfg.algo.rollout_steps * world_size)
@@ -237,7 +238,7 @@ def main(cfg: DictConfig):
 
             # Measure environment interaction time: this considers both the model forward
             # to get the action given the observation and the time taken into the environment
-            with timer("Time/env_interaction_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
+            with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
                 with torch.no_grad():
                     # Sample an action given the observation received by the environment
                     action_logits, values, state = agent.module(next_obs, state=next_state)
@@ -335,6 +336,7 @@ def main(cfg: DictConfig):
 
         with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
             train(fabric, agent, optimizer, padded_sequences, aggregator, cfg)
+        train_step += world_size
 
         if cfg.algo.anneal_lr:
             fabric.log("Info/learning_rate", scheduler.get_last_lr()[0], policy_step)
@@ -365,7 +367,7 @@ def main(cfg: DictConfig):
             timer_metrics = timer.compute()
             fabric.log(
                 "Time/sps_train",
-                (update - last_train) / timer_metrics["Time/train_time"],
+                (train_step - last_train) / timer_metrics["Time/train_time"],
                 policy_step,
             )
             fabric.log(
@@ -377,8 +379,8 @@ def main(cfg: DictConfig):
             timer.reset()
 
             # Reset counters
-            last_train = update
             last_log = policy_step
+            last_train = train_step
 
         # Checkpoint model
         if (
diff --git a/sheeprl/algos/sac/sac.py b/sheeprl/algos/sac/sac.py
index 98ace69c..c7f10aea 100644
--- a/sheeprl/algos/sac/sac.py
+++ b/sheeprl/algos/sac/sac.py
@@ -212,7 +212,7 @@ def main(cfg: DictConfig):
 
         # Measure environment interaction time: this considers both the model forward
         # to get the action given the observation and the time taken into the environment
-        with timer("Time/env_interaction_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
+        with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
             if update <= learning_starts:
                 actions = envs.action_space.sample()
             else:
@@ -300,7 +300,7 @@ def main(cfg: DictConfig):
                     cfg,
                     policy_steps_per_update,
                 )
-            train_step += 1
+            train_step += world_size
 
         # Log metrics
         if policy_step - last_log >= cfg.metric.log_every or update == num_updates or cfg.dry_run:
diff --git a/sheeprl/algos/sac/sac_decoupled.py b/sheeprl/algos/sac/sac_decoupled.py
index 2aeb4520..b4e21d29 100644
--- a/sheeprl/algos/sac/sac_decoupled.py
+++ b/sheeprl/algos/sac/sac_decoupled.py
@@ -296,7 +296,7 @@ def trainer(
     optimization_pg: CollectibleGroup,
 ):
     global_rank = world_collective.rank
-    global_rank - 1
+    group_world_size = world_collective.world_size - 1
 
     # Receive (possibly updated, by the make_dict_env method for example) cfg from the player
     data = [None]
@@ -412,7 +412,7 @@ def trainer(
                 policy_steps_per_update,
                 group=optimization_pg,
             )
-            train_step += 1
+            train_step += group_world_size
 
             if global_rank == 1:
                 player_trainer_collective.broadcast(
diff --git a/sheeprl/algos/sac_ae/sac_ae.py b/sheeprl/algos/sac_ae/sac_ae.py
index 1333dea7..88a6cba2 100644
--- a/sheeprl/algos/sac_ae/sac_ae.py
+++ b/sheeprl/algos/sac_ae/sac_ae.py
@@ -393,7 +393,7 @@ def main(cfg: DictConfig):
 
         # Measure environment interaction time: this considers both the model forward
         # to get the action given the observation and the time taken into the environment
-        with timer("Time/env_interaction_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
+        with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
             if update < learning_starts:
                 actions = envs.action_space.sample()
             else:
@@ -497,7 +497,7 @@ def main(cfg: DictConfig):
                     cfg,
                     policy_steps_per_update,
                 )
-            train_step += 1
+            train_step += world_size
 
         # Log metrics
         if policy_step - last_log >= cfg.metric.log_every or update == num_updates or cfg.dry_run:
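
Note on the recurring change above (explanatory, not part of the patch): train_step now advances by world_size (group_world_size in the decoupled trainers) instead of 1, and "Time/sps_train" is computed as (train_step - last_train) / Time/train_time, presumably so the ratio stays meaningful when the train-time SumMetric is summed across ranks. Below is a minimal, self-contained sketch of that bookkeeping in plain Python; the world_size value and the time.sleep stand-in for train(...) are hypothetical, not sheeprl code.

import time

world_size = 2       # hypothetical number of distributed ranks
train_step = 0       # incremented by world_size per update, as in the patch
last_train = 0
train_time = 0.0     # stand-in for the timer accumulated under "Time/train_time"

for update in range(1, 6):
    start = time.perf_counter()
    time.sleep(0.01)                                   # stand-in for train(...) on this rank
    # A metric synchronized on compute() would add up the elapsed time of every
    # rank; emulate that here by scaling the local time by world_size.
    train_time += (time.perf_counter() - start) * world_size
    train_step += world_size                           # instead of train_step += 1

# Aggregate training steps per second across ranks, i.e. what "Time/sps_train" reports.
sps_train = (train_step - last_train) / train_time
print(f"train_step={train_step}, sps_train={sps_train:.1f}")
last_train = train_step                                # mirrors `last_train = train_step`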