Commit

Fix compute prefill_steps as policy_steps (#287)
* Fix compute prefill_steps as policy_steps

* Adjust prefill_steps once before training
belerico authored May 14, 2024
1 parent 419c7ce commit cafb494
Showing 13 changed files with 38 additions and 25 deletions.
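The fix is identical across the touched algorithms: prefill_steps now counts prefill iterations rather than policy steps, so it is converted with a multiplication by policy_steps_per_iter when ratio_steps is computed, and it is shifted by start_iter once, before the training loop, when resuming from a checkpoint. The snippet below is a minimal, self-contained sketch of that arithmetic; the concrete numbers, the loop, and the policy_step bookkeeping are illustrative assumptions, not code taken from the repository.

# Minimal sketch of the corrected prefill/replay-ratio bookkeeping (illustrative values).
num_envs = 4
world_size = 2
policy_steps_per_iter = num_envs * world_size                   # policy steps collected per iteration

learning_starts = 128 // policy_steps_per_iter                  # prefill length, in iterations
# New formula: index of the last prefill iteration (0 when there is no prefill).
prefill_steps = learning_starts - int(learning_starts > 0)

start_iter = 1                                                  # assumed 1 on a fresh run, read from the checkpoint otherwise
resume_from_checkpoint = False
if resume_from_checkpoint:
    # Adjusted once, before the training loop, together with learning_starts.
    learning_starts += start_iter
    prefill_steps += start_iter

policy_step = (start_iter - 1) * policy_steps_per_iter
for iter_num in range(start_iter, start_iter + learning_starts + 3):
    policy_step += policy_steps_per_iter                        # environment steps gathered this iteration
    if iter_num >= learning_starts:
        # prefill_steps is in iterations, so it is scaled by policy_steps_per_iter
        # (multiplied, not added) before feeding the replay-ratio scheduler.
        ratio_steps = policy_step - prefill_steps * policy_steps_per_iter
        print(iter_num, policy_step, ratio_steps)               # ratio_steps / world_size would go to Ratio(...)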
sheeprl/algos/dreamer_v1/dreamer_v1.py (5 changes: 3 additions & 2 deletions)

@@ -510,10 +510,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     policy_steps_per_iter = int(cfg.env.num_envs * world_size)
     total_iters = int(cfg.algo.total_steps // policy_steps_per_iter) if not cfg.dry_run else 1
     learning_starts = (cfg.algo.learning_starts // policy_steps_per_iter) if not cfg.dry_run else 0
-    prefill_steps = learning_starts + start_iter
+    prefill_steps = learning_starts - int(learning_starts > 0)
     if cfg.checkpoint.resume_from:
         cfg.algo.per_rank_batch_size = state["batch_size"] // world_size
         learning_starts += start_iter
+        prefill_steps += start_iter

     # Create Ratio class
     ratio = Ratio(cfg.algo.replay_ratio, pretrain_steps=cfg.algo.per_rank_pretrain_steps)

@@ -644,7 +645,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

         # Train the agent
         if iter_num >= learning_starts:
-            ratio_steps = policy_step - prefill_steps + policy_steps_per_iter
+            ratio_steps = policy_step - prefill_steps * policy_steps_per_iter
             per_rank_gradient_steps = ratio(ratio_steps / world_size)
             if per_rank_gradient_steps > 0:
                 with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
sheeprl/algos/dreamer_v2/dreamer_v2.py (5 changes: 3 additions & 2 deletions)

@@ -540,10 +540,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     policy_steps_per_iter = int(cfg.env.num_envs * world_size)
     total_iters = cfg.algo.total_steps // policy_steps_per_iter if not cfg.dry_run else 1
     learning_starts = cfg.algo.learning_starts // policy_steps_per_iter if not cfg.dry_run else 0
-    prefill_steps = learning_starts + start_iter
+    prefill_steps = learning_starts - int(learning_starts > 0)
     if cfg.checkpoint.resume_from:
         cfg.algo.per_rank_batch_size = state["batch_size"] // world_size
         learning_starts += start_iter
+        prefill_steps += start_iter

     # Create Ratio class
     ratio = Ratio(cfg.algo.replay_ratio, pretrain_steps=cfg.algo.per_rank_pretrain_steps)

@@ -679,7 +680,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

         # Train the agent
         if iter_num >= learning_starts:
-            ratio_steps = policy_step - prefill_steps + policy_steps_per_iter
+            ratio_steps = policy_step - prefill_steps * policy_steps_per_iter
             per_rank_gradient_steps = ratio(ratio_steps / world_size)
             if per_rank_gradient_steps > 0:
                 local_data = rb.sample_tensors(
sheeprl/algos/dreamer_v3/dreamer_v3.py (5 changes: 3 additions & 2 deletions)

@@ -508,10 +508,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     policy_steps_per_iter = int(cfg.env.num_envs * fabric.world_size)
     total_iters = int(cfg.algo.total_steps // policy_steps_per_iter) if not cfg.dry_run else 1
     learning_starts = cfg.algo.learning_starts // policy_steps_per_iter if not cfg.dry_run else 0
-    prefill_steps = learning_starts + start_iter
+    prefill_steps = learning_starts - int(learning_starts > 0)
     if cfg.checkpoint.resume_from:
         cfg.algo.per_rank_batch_size = state["batch_size"] // fabric.world_size
         learning_starts += start_iter
+        prefill_steps += start_iter

     # Create Ratio class
     ratio = Ratio(cfg.algo.replay_ratio, pretrain_steps=cfg.algo.per_rank_pretrain_steps)

@@ -657,7 +658,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

         # Train the agent
         if iter_num >= learning_starts:
-            ratio_steps = policy_step - prefill_steps + policy_steps_per_iter
+            ratio_steps = policy_step - prefill_steps * policy_steps_per_iter
             per_rank_gradient_steps = ratio(ratio_steps / world_size)
             if per_rank_gradient_steps > 0:
                 local_data = rb.sample_tensors(
sheeprl/algos/droq/droq.py (5 changes: 3 additions & 2 deletions)

@@ -261,10 +261,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     policy_steps_per_iter = int(cfg.env.num_envs * fabric.world_size)
     total_iters = int(cfg.algo.total_steps // policy_steps_per_iter) if not cfg.dry_run else 1
     learning_starts = cfg.algo.learning_starts // policy_steps_per_iter if not cfg.dry_run else 0
-    prefill_steps = learning_starts + start_iter
+    prefill_steps = learning_starts - int(learning_starts > 0)
     if cfg.checkpoint.resume_from:
         cfg.algo.per_rank_batch_size = state["batch_size"] // fabric.world_size
         learning_starts += start_iter
+        prefill_steps += start_iter

     # Create Ratio class
     ratio = Ratio(cfg.algo.replay_ratio, pretrain_steps=cfg.algo.per_rank_pretrain_steps)

@@ -346,7 +347,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

         # Train the agent
         if iter_num >= learning_starts:
-            ratio_steps = policy_step - prefill_steps + policy_steps_per_iter
+            ratio_steps = policy_step - prefill_steps * policy_steps_per_iter
             per_rank_gradient_steps = ratio(ratio_steps / world_size)
             if per_rank_gradient_steps > 0:
                 train(
sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py (5 changes: 3 additions & 2 deletions)

@@ -534,10 +534,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     policy_steps_per_iter = int(cfg.env.num_envs * world_size)
     total_iters = int(cfg.algo.total_steps // policy_steps_per_iter) if not cfg.dry_run else 1
     learning_starts = (cfg.algo.learning_starts // policy_steps_per_iter) if not cfg.dry_run else 0
-    prefill_steps = learning_starts + start_iter
+    prefill_steps = learning_starts - int(learning_starts > 0)
     if cfg.checkpoint.resume_from:
         cfg.algo.per_rank_batch_size = state["batch_size"] // world_size
         learning_starts += start_iter
+        prefill_steps += start_iter

     # Create Ratio class
     ratio = Ratio(cfg.algo.replay_ratio, pretrain_steps=cfg.algo.per_rank_pretrain_steps)

@@ -668,7 +669,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

         # Train the agent
         if iter_num >= learning_starts:
-            ratio_steps = policy_step - prefill_steps + policy_steps_per_iter
+            ratio_steps = policy_step - prefill_steps * policy_steps_per_iter
             per_rank_gradient_steps = ratio(ratio_steps / world_size)
             if per_rank_gradient_steps > 0:
                 with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py (5 changes: 3 additions & 2 deletions)

@@ -205,10 +205,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
     policy_steps_per_iter = int(cfg.env.num_envs * world_size)
     total_iters = int(cfg.algo.total_steps // policy_steps_per_iter) if not cfg.dry_run else 1
     learning_starts = (cfg.algo.learning_starts // policy_steps_per_iter) if not cfg.dry_run else 0
-    prefill_steps = learning_starts + start_iter
+    prefill_steps = learning_starts - int(learning_starts > 0)
     if resume_from_checkpoint:
         cfg.algo.per_rank_batch_size = state["batch_size"] // world_size
         learning_starts += start_iter
+        prefill_steps += start_iter

     # Create Ratio class
     ratio = Ratio(cfg.algo.replay_ratio, pretrain_steps=cfg.algo.per_rank_pretrain_steps)

@@ -323,7 +324,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):

         # Train the agent
         if iter_num >= learning_starts:
-            ratio_steps = policy_step - prefill_steps + policy_steps_per_iter
+            ratio_steps = policy_step - prefill_steps * policy_steps_per_iter
             per_rank_gradient_steps = ratio(ratio_steps / world_size)
             if per_rank_gradient_steps > 0:
                 if player.actor_type != "task":
sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py (5 changes: 3 additions & 2 deletions)

@@ -673,10 +673,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     policy_steps_per_iter = int(cfg.env.num_envs * world_size)
     total_iters = cfg.algo.total_steps // policy_steps_per_iter if not cfg.dry_run else 1
     learning_starts = cfg.algo.learning_starts // policy_steps_per_iter if not cfg.dry_run else 0
-    prefill_steps = learning_starts + start_iter
+    prefill_steps = learning_starts - int(learning_starts > 0)
     if cfg.checkpoint.resume_from:
         cfg.algo.per_rank_batch_size = state["batch_size"] // world_size
         learning_starts += start_iter
+        prefill_steps += start_iter

     # Create Ratio class
     ratio = Ratio(cfg.algo.replay_ratio, pretrain_steps=cfg.algo.per_rank_pretrain_steps)

@@ -812,7 +813,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

         # Train the agent
         if iter_num >= learning_starts:
-            ratio_steps = policy_step - prefill_steps + policy_steps_per_iter
+            ratio_steps = policy_step - prefill_steps * policy_steps_per_iter
             per_rank_gradient_steps = ratio(ratio_steps / world_size)
             if per_rank_gradient_steps > 0:
                 local_data = rb.sample_tensors(
sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py (5 changes: 3 additions & 2 deletions)

@@ -223,10 +223,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
     policy_steps_per_iter = int(cfg.env.num_envs * world_size)
     total_iters = cfg.algo.total_steps // policy_steps_per_iter if not cfg.dry_run else 1
     learning_starts = cfg.algo.learning_starts // policy_steps_per_iter if not cfg.dry_run else 0
-    prefill_steps = learning_starts + start_iter
+    prefill_steps = learning_starts - int(learning_starts > 0)
     if resume_from_checkpoint:
         cfg.algo.per_rank_batch_size = state["batch_size"] // world_size
         learning_starts += start_iter
+        prefill_steps += start_iter

     # Create Ratio class
     ratio = Ratio(cfg.algo.replay_ratio, pretrain_steps=cfg.algo.per_rank_pretrain_steps)

@@ -346,7 +347,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):

         # Train the agent
         if iter_num >= learning_starts:
-            ratio_steps = policy_step - prefill_steps + policy_steps_per_iter
+            ratio_steps = policy_step - prefill_steps * policy_steps_per_iter
             per_rank_gradient_steps = ratio(ratio_steps / world_size)
             if per_rank_gradient_steps > 0:
                 if player.actor_type != "task":
sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py (5 changes: 3 additions & 2 deletions)

@@ -746,10 +746,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     policy_steps_per_iter = int(cfg.env.num_envs * fabric.world_size)
     total_iters = int(cfg.algo.total_steps // policy_steps_per_iter) if not cfg.dry_run else 1
     learning_starts = cfg.algo.learning_starts // policy_steps_per_iter if not cfg.dry_run else 0
-    prefill_steps = learning_starts + start_iter
+    prefill_steps = learning_starts - int(learning_starts > 0)
     if cfg.checkpoint.resume_from:
         cfg.algo.per_rank_batch_size = state["batch_size"] // world_size
         learning_starts += start_iter
+        prefill_steps += start_iter

     # Create Ratio class
     ratio = Ratio(cfg.algo.replay_ratio, pretrain_steps=cfg.algo.per_rank_pretrain_steps)

@@ -895,7 +896,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

         # Train the agent
         if iter_num >= learning_starts:
-            ratio_steps = policy_step - prefill_steps + policy_steps_per_iter
+            ratio_steps = policy_step - prefill_steps * policy_steps_per_iter
             per_rank_gradient_steps = ratio(ratio_steps / world_size)
             if per_rank_gradient_steps > 0:
                 local_data = rb.sample_tensors(
sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py (5 changes: 3 additions & 2 deletions)

@@ -210,10 +210,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
     policy_steps_per_iter = int(cfg.env.num_envs * fabric.world_size)
     total_iters = int(cfg.algo.total_steps // policy_steps_per_iter) if not cfg.dry_run else 1
     learning_starts = cfg.algo.learning_starts // policy_steps_per_iter if not cfg.dry_run else 0
-    prefill_steps = learning_starts + start_iter
+    prefill_steps = learning_starts - int(learning_starts > 0)
     if resume_from_checkpoint:
         cfg.algo.per_rank_batch_size = state["batch_size"] // world_size
         learning_starts += start_iter
+        prefill_steps += start_iter

     # Create Ratio class
     ratio = Ratio(cfg.algo.replay_ratio, pretrain_steps=cfg.algo.per_rank_pretrain_steps)

@@ -343,7 +344,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):

         # Train the agent
         if iter_num >= learning_starts:
-            ratio_steps = policy_step - prefill_steps + policy_steps_per_iter
+            ratio_steps = policy_step - prefill_steps * policy_steps_per_iter
             per_rank_gradient_steps = ratio(ratio_steps / world_size)
             if per_rank_gradient_steps > 0:
                 if player.actor_type != "task":
sheeprl/algos/sac/sac.py (3 changes: 2 additions & 1 deletion)

@@ -212,10 +212,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     policy_steps_per_iter = int(cfg.env.num_envs * fabric.world_size)
     total_iters = int(cfg.algo.total_steps // policy_steps_per_iter) if not cfg.dry_run else 1
     learning_starts = cfg.algo.learning_starts // policy_steps_per_iter if not cfg.dry_run else 0
-    prefill_steps = learning_starts + start_iter
+    prefill_steps = learning_starts - int(learning_starts > 0)
     if cfg.checkpoint.resume_from:
         cfg.algo.per_rank_batch_size = state["batch_size"] // fabric.world_size
         learning_starts += start_iter
+        prefill_steps += start_iter

     # Create Ratio class
     ratio = Ratio(cfg.algo.replay_ratio, pretrain_steps=cfg.algo.per_rank_pretrain_steps)
sheeprl/algos/sac/sac_decoupled.py (5 changes: 3 additions & 2 deletions)

@@ -145,9 +145,10 @@ def player(
     policy_steps_per_iter = int(cfg.env.num_envs)
     total_iters = int(cfg.algo.total_steps // policy_steps_per_iter) if not cfg.dry_run else 1
     learning_starts = cfg.algo.learning_starts // policy_steps_per_iter if not cfg.dry_run else 0
-    prefill_steps = learning_starts + start_iter
+    prefill_steps = learning_starts - int(learning_starts > 0)
     if cfg.checkpoint.resume_from and not cfg.buffer.checkpoint:
         learning_starts += start_iter
+        prefill_steps += start_iter

     # Create Ratio class
     ratio = Ratio(cfg.algo.replay_ratio, pretrain_steps=cfg.algo.per_rank_pretrain_steps)

@@ -227,7 +228,7 @@ def player(

         # Send data to the training agents
         if iter_num >= learning_starts:
-            ratio_steps = policy_step - prefill_steps + policy_steps_per_iter
+            ratio_steps = policy_step - prefill_steps * policy_steps_per_iter
             per_rank_gradient_steps = ratio(ratio_steps / (fabric.world_size - 1))
             cumulative_per_rank_gradient_steps += per_rank_gradient_steps
             if per_rank_gradient_steps > 0:
sheeprl/algos/sac_ae/sac_ae.py (5 changes: 3 additions & 2 deletions)

@@ -284,10 +284,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     policy_steps_per_iter = int(cfg.env.num_envs * fabric.world_size)
     total_iters = int(cfg.algo.total_steps // policy_steps_per_iter) if not cfg.dry_run else 1
     learning_starts = cfg.algo.learning_starts // policy_steps_per_iter if not cfg.dry_run else 0
-    prefill_steps = learning_starts + start_iter
+    prefill_steps = learning_starts - int(learning_starts > 0)
     if cfg.checkpoint.resume_from:
         cfg.algo.per_rank_batch_size = state["batch_size"] // fabric.world_size
         learning_starts += start_iter
+        prefill_steps += start_iter

     # Create Ratio class
     ratio = Ratio(cfg.algo.replay_ratio, pretrain_steps=cfg.algo.per_rank_pretrain_steps)

@@ -374,7 +375,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

         # Train the agent
         if iter_num >= learning_starts:
-            ratio_steps = policy_step - prefill_steps + policy_steps_per_iter
+            ratio_steps = policy_step - prefill_steps * policy_steps_per_iter
             per_rank_gradient_steps = ratio(ratio_steps / world_size)
             if per_rank_gradient_steps > 0:
                 # We sample one time to reduce the communications between processes
