Fix sps: sps_train is computed globally while sps_env is computed locally
belerico committed Sep 15, 2023
1 parent 5a50c5b commit 7ab2787
Showing 12 changed files with 33 additions and 26 deletions.
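Why the counters change: "Time/train_time" is a torchmetrics SumMetric that is synchronized across ranks when compute() is called, so it returns the training time summed over all processes. Dividing a per-rank update counter by that globally summed time makes "Time/sps_train" shrink as more processes are added. The fix therefore advances train_step by world_size (or by the trainer-group size in the decoupled scripts) so the numerator is global as well, and it creates "Time/env_interaction_time" with sync_on_compute=False so that "Time/sps_env" stays a per-rank figure. Below is a minimal, self-contained sketch of the arithmetic with made-up numbers; it is not code from this repository.

world_size = 4                    # number of training processes (hypothetical)
per_rank_updates = 100            # gradient updates performed by every rank
per_rank_train_time = 50.0        # seconds each rank spent inside train()

# "Time/train_time" is a SumMetric synced on compute(), i.e. summed over ranks.
global_train_time = per_rank_train_time * world_size           # 200.0 s

# Before the fix: per-rank counter over a globally summed time.
sps_train_old = per_rank_updates / global_train_time           # 0.5 -> drops as world_size grows

# After the fix: train_step advances by world_size per update, so the
# numerator is global and matches the globally summed denominator.
train_step = per_rank_updates * world_size
sps_train_new = train_step / global_train_time                 # 2.0 -> independent of world_size

# "Time/env_interaction_time" uses sync_on_compute=False, so it stays local
# and "Time/sps_env" is a per-rank throughput.
per_rank_env_steps = 1_000
per_rank_env_time = 25.0
sps_env = per_rank_env_steps / per_rank_env_time                # 40.0 steps/s on this rank

print(sps_train_old, sps_train_new, sps_env)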
4 changes: 2 additions & 2 deletions sheeprl/algos/dreamer_v1/dreamer_v1.py
@@ -581,7 +581,7 @@ def main(cfg: DictConfig):
 
 # Measure environment interaction time: this considers both the model forward
 # to get the action given the observation and the time taken into the environment
-with timer("Time/env_interaction_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
+with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
     # Sample an action given the observation received by the environment
     if update <= learning_starts and cfg.checkpoint.resume_from is None and "minedojo" not in cfg.env.id:
         real_actions = actions = np.array(envs.action_space.sample())
@@ -692,7 +692,7 @@ def main(cfg: DictConfig):
     aggregator,
     cfg,
 )
-train_step += 1
+train_step += world_size
 updates_before_training = cfg.algo.train_every // policy_steps_per_update
 if cfg.algo.player.expl_decay:
     expl_decay_steps += 1
4 changes: 2 additions & 2 deletions sheeprl/algos/dreamer_v2/dreamer_v2.py
@@ -623,7 +623,7 @@ def main(cfg: DictConfig):
 
 # Measure environment interaction time: this considers both the model forward
 # to get the action given the observation and the time taken into the environment
-with timer("Time/env_interaction_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
+with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
     # Sample an action given the observation received by the environment
     if (
         update <= learning_starts
@@ -770,7 +770,7 @@ def main(cfg: DictConfig):
     actions_dim,
 )
 per_rank_gradient_steps += 1
-train_step += 1
+train_step += world_size
 updates_before_training = cfg.algo.train_every // policy_steps_per_update
 if cfg.algo.player.expl_decay:
     expl_decay_steps += 1
4 changes: 2 additions & 2 deletions sheeprl/algos/dreamer_v3/dreamer_v3.py
@@ -566,7 +566,7 @@ def main(cfg: DictConfig):
 
 # Measure environment interaction time: this considers both the model forward
 # to get the action given the observation and the time taken into the environment
-with timer("Time/env_interaction_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
+with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
     # Sample an action given the observation received by the environment
     if (
         update <= learning_starts
@@ -708,7 +708,7 @@ def main(cfg: DictConfig):
     moments,
 )
 per_rank_gradient_steps += 1
-train_step += 1
+train_step += world_size
 updates_before_training = cfg.algo.train_every // policy_steps_per_update
 if cfg.algo.player.expl_decay:
     expl_decay_steps += 1
4 changes: 2 additions & 2 deletions sheeprl/algos/droq/droq.py
@@ -264,7 +264,7 @@ def main(cfg: DictConfig):
 
 # Measure environment interaction time: this considers both the model forward
 # to get the action given the observation and the time taken into the environment
-with timer("Time/env_interaction_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
+with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
     with torch.no_grad():
         actions, _ = actor.module(obs)
         actions = actions.cpu().numpy()
@@ -308,7 +308,7 @@ def main(cfg: DictConfig):
 # Train the agent
 if update > learning_starts:
     train(fabric, agent, actor_optimizer, qf_optimizer, alpha_optimizer, rb, aggregator, cfg)
-    train_step += 1
+    train_step += world_size
 
 # Log metrics
 if policy_step - last_log >= cfg.metric.log_every or update == num_updates or cfg.dry_run:
4 changes: 2 additions & 2 deletions sheeprl/algos/p2e_dv1/p2e_dv1.py
@@ -650,7 +650,7 @@ def main(cfg: DictConfig):
 
 # Measure environment interaction time: this considers both the model forward
 # to get the action given the observation and the time taken into the environment
-with timer("Time/env_interaction_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
+with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
     # Sample an action given the observation received by the environment
     if update <= learning_starts and cfg.checkpoint.resume_from is None and "minedojo" not in cfg.env.id:
         real_actions = actions = np.array(envs.action_space.sample())
@@ -771,7 +771,7 @@ def main(cfg: DictConfig):
     critic_exploration_optimizer=critic_exploration_optimizer,
     is_exploring=is_exploring,
 )
-train_step += 1
+train_step += world_size
 updates_before_training = cfg.algo.train_every // policy_steps_per_update
 if cfg.algo.player.expl_decay:
     expl_decay_steps += 1
4 changes: 2 additions & 2 deletions sheeprl/algos/p2e_dv2/p2e_dv2.py
@@ -802,7 +802,7 @@ def main(cfg: DictConfig):
 
 # Measure environment interaction time: this considers both the model forward
 # to get the action given the observation and the time taken into the environment
-with timer("Time/env_interaction_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
+with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
     # Sample an action given the observation received by the environment
     if update <= learning_starts and cfg.checkpoint.resume_from is None and "minedojo" not in cfg.env.id:
         real_actions = actions = np.array(envs.action_space.sample())
@@ -958,7 +958,7 @@ def main(cfg: DictConfig):
     actions_dim=actions_dim,
     is_exploring=is_exploring,
 )
-train_step += 1
+train_step += world_size
 updates_before_training = cfg.algo.train_every // policy_steps_per_update
 if cfg.algo.player.expl_decay:
     expl_decay_steps += 1
8 changes: 5 additions & 3 deletions sheeprl/algos/ppo/ppo.py
@@ -221,6 +221,7 @@ def main(cfg: DictConfig):
 # Global variables
 last_log = 0
 last_train = 0
+train_step = 0
 policy_step = 0
 last_checkpoint = 0
 policy_steps_per_update = int(cfg.env.num_envs * cfg.algo.rollout_steps * world_size)
@@ -268,7 +269,7 @@ def main(cfg: DictConfig):
 
 # Measure environment interaction time: this considers both the model forward
 # to get the action given the observation and the time taken into the environment
-with timer("Time/env_interaction_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
+with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
     with torch.no_grad():
         # Sample an action given the observation received by the environment
         normalized_obs = {
@@ -358,6 +359,7 @@ def main(cfg: DictConfig):
 
 with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
     train(fabric, agent, optimizer, gathered_data, aggregator, cfg)
+train_step += world_size
 
 if cfg.algo.anneal_lr:
     fabric.log("Info/learning_rate", scheduler.get_last_lr()[0], policy_step)
@@ -388,7 +390,7 @@ def main(cfg: DictConfig):
 timer_metrics = timer.compute()
 fabric.log(
     "Time/sps_train",
-    (update - last_train) / timer_metrics["Time/train_time"],
+    (train_step - last_train) / timer_metrics["Time/train_time"],
     policy_step,
 )
 fabric.log(
@@ -400,8 +402,8 @@ def main(cfg: DictConfig):
 timer.reset()
 
 # Reset counters
-last_train = update
 last_log = policy_step
+last_train = train_step
 
 # Checkpoint model
 if (
7 changes: 5 additions & 2 deletions sheeprl/algos/ppo/ppo_decoupled.py
@@ -349,6 +349,7 @@ def trainer(
     optimization_pg: CollectibleGroup,
 ):
     global_rank = world_collective.rank
+    group_world_size = world_collective.world_size - 1
 
     # Receive (possibly updated, by the make_dict_env method for example) cfg from the player
     data = [None]
@@ -406,6 +407,7 @@ def trainer(
 update = 0
 last_log = 0
 last_train = 0
+train_step = 0
 policy_step = 0
 last_checkpoint = 0
 initial_ent_coef = copy.deepcopy(cfg.algo.ent_coef)
@@ -428,6 +430,7 @@ def trainer(
     return
 data = make_tensordict(data, device=device)
 update += 1
+train_step += group_world_size
 policy_step += cfg.env.num_envs * cfg.algo.rollout_steps
 
 # Prepare sampler
@@ -504,7 +507,7 @@ def trainer(
 
 # Sync distributed timers
 timers = timer.compute()
-metrics.update({"Time/sps_train": (update - last_train) / timers["Time/train_time"]})
+metrics.update({"Time/sps_train": (train_step - last_train) / timers["Time/train_time"]})
 timer.reset()
 
 # Send metrics to the player
@@ -520,8 +523,8 @@ def trainer(
 )  # Broadcast metrics: fake send with object list between rank-0 and rank-1
 
 # Reset counters
-last_train = update
 last_log = policy_step
+last_train = train_step
 
 if cfg.algo.anneal_lr:
     scheduler.step()
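In the decoupled scripts rank 0 acts as the player and only interacts with the environment, so the counter advances by the size of the trainer group (world_collective.world_size - 1) rather than by the full world size. A hypothetical layout, only to illustrate that assumption:

total_ranks = 4                        # 1 player + 3 trainers (hypothetical)
group_world_size = total_ranks - 1     # mirrors world_collective.world_size - 1
per_trainer_updates = 100
group_train_time = 150.0               # seconds, already summed over the 3 trainers
train_step = per_trainer_updates * group_world_size
sps_train = train_step / group_train_time   # 2.0, unaffected by how many trainers there are
print(sps_train)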
8 changes: 5 additions & 3 deletions sheeprl/algos/ppo_recurrent/ppo_recurrent.py
@@ -198,6 +198,7 @@ def main(cfg: DictConfig):
 # Global variables
 last_log = 0
 last_train = 0
+train_step = 0
 policy_step = 0
 last_checkpoint = 0
 policy_steps_per_update = int(cfg.env.num_envs * cfg.algo.rollout_steps * world_size)
@@ -237,7 +238,7 @@ def main(cfg: DictConfig):
 
 # Measure environment interaction time: this considers both the model forward
 # to get the action given the observation and the time taken into the environment
-with timer("Time/env_interaction_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
+with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
     with torch.no_grad():
         # Sample an action given the observation received by the environment
         action_logits, values, state = agent.module(next_obs, state=next_state)
@@ -335,6 +336,7 @@ def main(cfg: DictConfig):
 
 with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
     train(fabric, agent, optimizer, padded_sequences, aggregator, cfg)
+train_step += world_size
 
 if cfg.algo.anneal_lr:
     fabric.log("Info/learning_rate", scheduler.get_last_lr()[0], policy_step)
@@ -365,7 +367,7 @@ def main(cfg: DictConfig):
 timer_metrics = timer.compute()
 fabric.log(
     "Time/sps_train",
-    (update - last_train) / timer_metrics["Time/train_time"],
+    (train_step - last_train) / timer_metrics["Time/train_time"],
     policy_step,
 )
 fabric.log(
@@ -377,8 +379,8 @@ def main(cfg: DictConfig):
 timer.reset()
 
 # Reset counters
-last_train = update
 last_log = policy_step
+last_train = train_step
 
 # Checkpoint model
 if (
4 changes: 2 additions & 2 deletions sheeprl/algos/sac/sac.py
@@ -212,7 +212,7 @@ def main(cfg: DictConfig):
 
 # Measure environment interaction time: this considers both the model forward
 # to get the action given the observation and the time taken into the environment
-with timer("Time/env_interaction_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
+with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
     if update <= learning_starts:
         actions = envs.action_space.sample()
     else:
@@ -300,7 +300,7 @@ def main(cfg: DictConfig):
     cfg,
     policy_steps_per_update,
 )
-train_step += 1
+train_step += world_size
 
 # Log metrics
 if policy_step - last_log >= cfg.metric.log_every or update == num_updates or cfg.dry_run:
4 changes: 2 additions & 2 deletions sheeprl/algos/sac/sac_decoupled.py
@@ -296,7 +296,7 @@ def trainer(
     optimization_pg: CollectibleGroup,
 ):
     global_rank = world_collective.rank
-    global_rank - 1
+    group_world_size = world_collective.world_size - 1
 
     # Receive (possibly updated, by the make_dict_env method for example) cfg from the player
     data = [None]
@@ -412,7 +412,7 @@ def trainer(
     policy_steps_per_update,
     group=optimization_pg,
 )
-train_step += 1
+train_step += group_world_size
 
 if global_rank == 1:
     player_trainer_collective.broadcast(
4 changes: 2 additions & 2 deletions sheeprl/algos/sac_ae/sac_ae.py
@@ -393,7 +393,7 @@ def main(cfg: DictConfig):
 
 # Measure environment interaction time: this considers both the model forward
 # to get the action given the observation and the time taken into the environment
-with timer("Time/env_interaction_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
+with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
     if update < learning_starts:
         actions = envs.action_space.sample()
     else:
@@ -497,7 +497,7 @@ def main(cfg: DictConfig):
     cfg,
     policy_steps_per_update,
 )
-train_step += 1
+train_step += world_size
 
 # Log metrics
 if policy_step - last_log >= cfg.metric.log_every or update == num_updates or cfg.dry_run:
