
Feature/from update to iter (#284)
* From update to iter

* Fix missing iter_num key

* Change start_step into start_iter
belerico authored May 10, 2024
1 parent 75d752f commit 4003506
Showing 19 changed files with 409 additions and 408 deletions.
32 changes: 16 additions & 16 deletions howto/register_external_algorithm.md
@@ -550,36 +550,36 @@ def ext_sota_main(fabric: Fabric, cfg: Dict[str, Any]):
# Global variables
last_train = 0
train_step = 0
start_step = (
start_iter = (
# + 1 because the checkpoint is at the end of the update step
# (when resuming from a checkpoint, the update at the checkpoint
# is ended and you have to start with the next one)
(state["update"] // fabric.world_size) + 1
(state["iter_num"] // fabric.world_size) + 1
if cfg.checkpoint.resume_from
else 1
)
policy_step = state["update"] * cfg.env.num_envs * cfg.algo.rollout_steps if cfg.checkpoint.resume_from else 0
policy_step = state["iter_num"] * cfg.env.num_envs * cfg.algo.rollout_steps if cfg.checkpoint.resume_from else 0
last_log = state["last_log"] if cfg.checkpoint.resume_from else 0
last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0
policy_steps_per_update = int(cfg.env.num_envs * cfg.algo.rollout_steps * world_size)
num_updates = cfg.algo.total_steps // policy_steps_per_update if not cfg.dry_run else 1
policy_steps_per_iter = int(cfg.env.num_envs * cfg.algo.rollout_steps * world_size)
total_iters = cfg.algo.total_steps // policy_steps_per_iter if not cfg.dry_run else 1
if cfg.checkpoint.resume_from:
cfg.algo.per_rank_batch_size = state["batch_size"] // fabric.world_size

# Warning for log and checkpoint every
if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0:
if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_iter != 0:
warnings.warn(
f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the "
f"policy_steps_per_update value ({policy_steps_per_update}), so "
f"policy_steps_per_iter value ({policy_steps_per_iter}), so "
"the metrics will be logged at the nearest greater multiple of the "
"policy_steps_per_update value."
"policy_steps_per_iter value."
)
if cfg.checkpoint.every % policy_steps_per_update != 0:
if cfg.checkpoint.every % policy_steps_per_iter != 0:
warnings.warn(
f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the "
f"policy_steps_per_update value ({policy_steps_per_update}), so "
f"policy_steps_per_iter value ({policy_steps_per_iter}), so "
"the checkpoint will be saved at the nearest greater multiple of the "
"policy_steps_per_update value."
"policy_steps_per_iter value."
)

# Get the first environment observation and start the optimization
@@ -590,9 +590,9 @@ def ext_sota_main(fabric: Fabric, cfg: Dict[str, Any]):
next_obs[k] = next_obs[k].reshape(cfg.env.num_envs, -1, *next_obs[k].shape[-2:])
step_data[k] = next_obs[k][np.newaxis]

for update in range(start_step, num_updates + 1):
for iter_num in range(start_iter, total_iters + 1):
for _ in range(0, cfg.algo.rollout_steps):
policy_step += policy_steps_per_update
policy_step += policy_steps_per_iter

# Measure environment interaction time: this considers both the model forward
# to get the action given the observation and the time taken into the environment
@@ -653,7 +653,7 @@ def ext_sota_main(fabric: Fabric, cfg: Dict[str, Any]):
train(fabric, agent, optimizer, local_data, aggregator, cfg)

# Log metrics
if policy_step - last_log >= cfg.metric.log_every or update == num_updates or cfg.dry_run:
if policy_step - last_log >= cfg.metric.log_every or iter_num == total_iters or cfg.dry_run:
# Sync distributed metrics
if aggregator and not aggregator.disabled:
metrics_dict = aggregator.compute()
@@ -686,13 +686,13 @@ def ext_sota_main(fabric: Fabric, cfg: Dict[str, Any]):
if (
(cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every)
or cfg.dry_run
or update == num_updates
or iter_num == total_iters
):
last_checkpoint = policy_step
state = {
"agent": agent.state_dict(),
"optimizer": optimizer.state_dict(),
"update_step": update,
"iter_num": iter_num * world_size,
}
ckpt_path = os.path.join(log_dir, f"checkpoint/ckpt_{policy_step}_{fabric.global_rank}.ckpt")
fabric.call("on_checkpoint_coupled", fabric=fabric, ckpt_path=ckpt_path, state=state)
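
For reference, a minimal sketch of the renamed bookkeeping above (not part of this commit; num_envs=4, rollout_steps=128, world_size=2, and total_steps=1,024,000 are made-up values, not SheepRL defaults):

# Illustrative values only.
num_envs = 4              # stands in for cfg.env.num_envs
rollout_steps = 128       # stands in for cfg.algo.rollout_steps
world_size = 2            # number of ranks
total_steps = 1_024_000   # stands in for cfg.algo.total_steps

# One training iteration gathers num_envs * rollout_steps policy steps on each rank.
policy_steps_per_iter = int(num_envs * rollout_steps * world_size)   # 1024

# Renamed from num_updates: iterations needed to reach the requested policy steps.
total_iters = total_steps // policy_steps_per_iter                    # 1000

# A checkpoint taken at rank-local iteration 500 stores iter_num * world_size,
# so the counters a resumed run rebuilds stay consistent with the totals above.
saved_iter_num = 500 * world_size                                     # 1000
policy_step = saved_iter_num * num_envs * rollout_steps               # 512000
assert policy_step == 500 * policy_steps_per_iter
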
32 changes: 16 additions & 16 deletions howto/register_new_algorithm.md
@@ -548,36 +548,36 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]):
# Global variables
last_train = 0
train_step = 0
start_step = (
start_iter = (
# + 1 because the checkpoint is at the end of the update step
# (when resuming from a checkpoint, the update at the checkpoint
# is ended and you have to start with the next one)
(state["update"] // fabric.world_size) + 1
(state["iter_num"] // fabric.world_size) + 1
if cfg.checkpoint.resume_from
else 1
)
policy_step = state["update"] * cfg.env.num_envs * cfg.algo.rollout_steps if cfg.checkpoint.resume_from else 0
policy_step = state["iter_num"] * cfg.env.num_envs * cfg.algo.rollout_steps if cfg.checkpoint.resume_from else 0
last_log = state["last_log"] if cfg.checkpoint.resume_from else 0
last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0
policy_steps_per_update = int(cfg.env.num_envs * cfg.algo.rollout_steps * world_size)
num_updates = cfg.algo.total_steps // policy_steps_per_update if not cfg.dry_run else 1
policy_steps_per_iter = int(cfg.env.num_envs * cfg.algo.rollout_steps * world_size)
total_iters = cfg.algo.total_steps // policy_steps_per_iter if not cfg.dry_run else 1
if cfg.checkpoint.resume_from:
cfg.algo.per_rank_batch_size = state["batch_size"] // fabric.world_size

# Warning for log and checkpoint every
if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0:
if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_iter != 0:
warnings.warn(
f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the "
f"policy_steps_per_update value ({policy_steps_per_update}), so "
f"policy_steps_per_iter value ({policy_steps_per_iter}), so "
"the metrics will be logged at the nearest greater multiple of the "
"policy_steps_per_update value."
"policy_steps_per_iter value."
)
if cfg.checkpoint.every % policy_steps_per_update != 0:
if cfg.checkpoint.every % policy_steps_per_iter != 0:
warnings.warn(
f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the "
f"policy_steps_per_update value ({policy_steps_per_update}), so "
f"policy_steps_per_iter value ({policy_steps_per_iter}), so "
"the checkpoint will be saved at the nearest greater multiple of the "
"policy_steps_per_update value."
"policy_steps_per_iter value."
)

# Get the first environment observation and start the optimization
@@ -588,9 +588,9 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]):
next_obs[k] = next_obs[k].reshape(cfg.env.num_envs, -1, *next_obs[k].shape[-2:])
step_data[k] = next_obs[k][np.newaxis]

for update in range(start_step, num_updates + 1):
for iter_num in range(start_iter, total_iters + 1):
for _ in range(0, cfg.algo.rollout_steps):
policy_step += policy_steps_per_update
policy_step += policy_steps_per_iter

# Measure environment interaction time: this considers both the model forward
# to get the action given the observation and the time taken into the environment
@@ -651,7 +651,7 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]):
train(fabric, agent, optimizer, local_data, aggregator, cfg)

# Log metrics
if policy_step - last_log >= cfg.metric.log_every or update == num_updates or cfg.dry_run:
if policy_step - last_log >= cfg.metric.log_every or iter_num == total_iters or cfg.dry_run:
# Sync distributed metrics
if aggregator and not aggregator.disabled:
metrics_dict = aggregator.compute()
@@ -684,13 +684,13 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]):
if (
(cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every)
or cfg.dry_run
or update == num_updates
or iter_num == total_iters
):
last_checkpoint = policy_step
state = {
"agent": agent.state_dict(),
"optimizer": optimizer.state_dict(),
"update_step": update,
"iter_num": iter_num * world_size,
}
ckpt_path = os.path.join(log_dir, f"checkpoint/ckpt_{policy_step}_{fabric.global_rank}.ckpt")
fabric.call("on_checkpoint_coupled", fabric=fabric, ckpt_path=ckpt_path, state=state)
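
The warning rewritten above can be illustrated with a small sketch (not from the repository; 1024 and 1500 are arbitrary numbers): when metric.log_every is not a multiple of policy_steps_per_iter, logging lands on the nearest greater multiple, because policy_step only ever advances in increments of policy_steps_per_iter.

policy_steps_per_iter = 1024   # illustrative value
log_every = 1500               # deliberately not a multiple of 1024

policy_step, last_log = 0, 0
logged_at = []
for _ in range(5):
    policy_step += policy_steps_per_iter
    if policy_step - last_log >= log_every:
        logged_at.append(policy_step)
        last_log = policy_step

print(logged_at)  # [2048, 4096] -> logs fall on multiples of 1024, never exactly at 1500 or 3000
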
26 changes: 13 additions & 13 deletions sheeprl/algos/a2c/a2c.py
@@ -200,23 +200,23 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
train_step = 0
policy_step = 0
last_checkpoint = 0
policy_steps_per_update = int(cfg.env.num_envs * cfg.algo.rollout_steps * world_size)
num_updates = cfg.algo.total_steps // policy_steps_per_update if not cfg.dry_run else 1
policy_steps_per_iter = int(cfg.env.num_envs * cfg.algo.rollout_steps * world_size)
total_iters = cfg.algo.total_steps // policy_steps_per_iter if not cfg.dry_run else 1

# Warning for log and checkpoint every
if cfg.metric.log_every % policy_steps_per_update != 0:
if cfg.metric.log_every % policy_steps_per_iter != 0:
warnings.warn(
f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the "
f"policy_steps_per_update value ({policy_steps_per_update}), so "
f"policy_steps_per_iter value ({policy_steps_per_iter}), so "
"the metrics will be logged at the nearest greater multiple of the "
"policy_steps_per_update value."
"policy_steps_per_iter value."
)
if cfg.checkpoint.every % policy_steps_per_update != 0:
if cfg.checkpoint.every % policy_steps_per_iter != 0:
warnings.warn(
f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the "
f"policy_steps_per_update value ({policy_steps_per_update}), so "
f"policy_steps_per_iter value ({policy_steps_per_iter}), so "
"the checkpoint will be saved at the nearest greater multiple of the "
"policy_steps_per_update value."
"policy_steps_per_iter value."
)

# Get the first environment observation and start the optimization
@@ -225,10 +225,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
for k in obs_keys:
step_data[k] = next_obs[k][np.newaxis]

for update in range(1, num_updates + 1):
for iter_num in range(1, total_iters + 1):
with torch.inference_mode():
for _ in range(0, cfg.algo.rollout_steps):
policy_step += policy_steps_per_update
policy_step += policy_steps_per_iter

# Measure environment interaction time: this considers both the model forward
# to get the action given the observation and the time taken into the environment
@@ -325,7 +325,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
train(fabric, agent, optimizer, local_data, aggregator, cfg)

# Log metrics
if policy_step - last_log >= cfg.metric.log_every or update == num_updates or cfg.dry_run:
if policy_step - last_log >= cfg.metric.log_every or iter_num == total_iters or cfg.dry_run:
# Sync distributed metrics
if aggregator and not aggregator.disabled:
metrics_dict = aggregator.compute()
@@ -358,13 +358,13 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
if (
(cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every)
or cfg.dry_run
or (update == num_updates and cfg.checkpoint.save_last)
or (iter_num == total_iters and cfg.checkpoint.save_last)
):
last_checkpoint = policy_step
state = {
"agent": agent.state_dict(),
"optimizer": optimizer.state_dict(),
"update_step": update,
"iter_num": iter_num * world_size,
}
ckpt_path = os.path.join(log_dir, f"checkpoint/ckpt_{policy_step}_{fabric.global_rank}.ckpt")
fabric.call("on_checkpoint_coupled", fabric=fabric, ckpt_path=ckpt_path, state=state)
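
The checkpoint key rename above ("update_step" becomes "iter_num", now scaled by world_size) pairs with the resume logic shown in the howto diffs earlier. A short sketch of the round trip, with placeholder values standing in for the real agent and optimizer entries:

world_size = 2
iter_num = 350                                    # rank-local iteration when the checkpoint fires

state = {"iter_num": iter_num * world_size}       # 700 stored on disk (other entries omitted here)

# On resume, each rank recovers the next iteration to run.
start_iter = (state["iter_num"] // world_size) + 1   # 351
assert start_iter == iter_num + 1
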
44 changes: 22 additions & 22 deletions sheeprl/algos/dreamer_v1/dreamer_v1.py
@@ -496,44 +496,44 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
# Global variables
train_step = 0
last_train = 0
start_step = (
start_iter = (
# + 1 because the checkpoint is at the end of the update step
# (when resuming from a checkpoint, the update at the checkpoint
# is ended and you have to start with the next one)
(state["update"] // world_size) + 1
(state["iter_num"] // world_size) + 1
if cfg.checkpoint.resume_from
else 1
)
policy_step = state["update"] * cfg.env.num_envs if cfg.checkpoint.resume_from else 0
policy_step = state["iter_num"] * cfg.env.num_envs if cfg.checkpoint.resume_from else 0
last_log = state["last_log"] if cfg.checkpoint.resume_from else 0
last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0
policy_steps_per_update = int(cfg.env.num_envs * world_size)
num_updates = cfg.algo.total_steps // policy_steps_per_update if not cfg.dry_run else 1
learning_starts = (cfg.algo.learning_starts // policy_steps_per_update) if not cfg.dry_run else 0
prefill_steps = learning_starts + start_step
policy_steps_per_iter = int(cfg.env.num_envs * world_size)
total_iters = int(cfg.algo.total_steps // policy_steps_per_iter) if not cfg.dry_run else 1
learning_starts = (cfg.algo.learning_starts // policy_steps_per_iter) if not cfg.dry_run else 0
prefill_steps = learning_starts + start_iter
if cfg.checkpoint.resume_from:
cfg.algo.per_rank_batch_size = state["batch_size"] // world_size
learning_starts += start_step
learning_starts += start_iter

# Create Ratio class
ratio = Ratio(cfg.algo.replay_ratio, pretrain_steps=cfg.algo.per_rank_pretrain_steps)
if cfg.checkpoint.resume_from:
ratio.load_state_dict(state["ratio"])

# Warning for log and checkpoint every
if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0:
if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_iter != 0:
warnings.warn(
f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the "
f"policy_steps_per_update value ({policy_steps_per_update}), so "
f"policy_steps_per_iter value ({policy_steps_per_iter}), so "
"the metrics will be logged at the nearest greater multiple of the "
"policy_steps_per_update value."
"policy_steps_per_iter value."
)
if cfg.checkpoint.every % policy_steps_per_update != 0:
if cfg.checkpoint.every % policy_steps_per_iter != 0:
warnings.warn(
f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the "
f"policy_steps_per_update value ({policy_steps_per_update}), so "
f"policy_steps_per_iter value ({policy_steps_per_iter}), so "
"the checkpoint will be saved at the nearest greater multiple of the "
"policy_steps_per_update value."
"policy_steps_per_iter value."
)

# Get the first environment observation and start the optimization
@@ -551,16 +551,16 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
player.init_states()

cumulative_per_rank_gradient_steps = 0
for update in range(start_step, num_updates + 1):
policy_step += policy_steps_per_update
for iter_num in range(start_iter, total_iters + 1):
policy_step += policy_steps_per_iter

with torch.inference_mode():
# Measure environment interaction time: this considers both the model forward
# to get the action given the observation and the time taken into the environment
with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False):
# Sample an action given the observation received by the environment
if (
update <= learning_starts
iter_num <= learning_starts
and cfg.checkpoint.resume_from is None
and "minedojo" not in cfg.env.wrapper._target_.lower()
):
@@ -643,8 +643,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
player.init_states(reset_envs=dones_idxes)

# Train the agent
if update >= learning_starts:
ratio_steps = policy_step - prefill_steps + policy_steps_per_update
if iter_num >= learning_starts:
ratio_steps = policy_step - prefill_steps + policy_steps_per_iter
per_rank_gradient_steps = ratio(ratio_steps / world_size)
if per_rank_gradient_steps > 0:
with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
@@ -676,7 +676,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
aggregator.update("Params/exploration_amount", actor._get_expl_amount(policy_step))

# Log metrics
if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or update == num_updates):
if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or iter_num == total_iters):
# Sync distributed metrics
if aggregator and not aggregator.disabled:
metrics_dict = aggregator.compute()
@@ -712,7 +712,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

# Checkpoint Model
if (cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every) or (
update == num_updates and cfg.checkpoint.save_last
iter_num == total_iters and cfg.checkpoint.save_last
):
last_checkpoint = policy_step
state = {
@@ -723,7 +723,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
"actor_optimizer": actor_optimizer.state_dict(),
"critic_optimizer": critic_optimizer.state_dict(),
"ratio": ratio.state_dict(),
"update": update * world_size,
"iter_num": iter_num * world_size,
"batch_size": cfg.algo.per_rank_batch_size * world_size,
"last_log": last_log,
"last_checkpoint": last_checkpoint,
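
In the Dreamer-V1 hunks above, learning_starts, prefill_steps, and ratio_steps are likewise rebased on iterations. The sketch below uses assumed values (num_envs=1, world_size=2, a learning_starts config of 1024) and does not reproduce SheepRL's Ratio class, whose internals are not shown in this diff:

# Assumed, illustrative values only.
num_envs = 1
world_size = 2
policy_steps_per_iter = int(num_envs * world_size)           # 2
learning_starts = 1024 // policy_steps_per_iter               # 512 iterations acted with random actions
start_iter = 1                                                 # fresh run, no resume
prefill_steps = learning_starts + start_iter                   # 513

# First iteration that trains (iter_num >= learning_starts):
iter_num = learning_starts                                     # 512
policy_step = iter_num * policy_steps_per_iter                 # 1024
ratio_steps = policy_step - prefill_steps + policy_steps_per_iter   # 513
# In the diffed code, ratio_steps / world_size is handed to the Ratio object,
# which turns it into a number of per-rank gradient steps according to
# cfg.algo.replay_ratio.
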