Adjust prefill_steps once before training
belerico committed May 13, 2024
1 parent 2c75fb8 commit 2a05099
Showing 13 changed files with 25 additions and 25 deletions.
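
The change is the same in every file: the off-by-one adjustment of prefill_steps is applied once at initialization (guarded so it never drops below zero when there is no prefill phase) instead of being re-applied on every training iteration inside the ratio_steps computation. A minimal sketch of the before/after arithmetic, with step counts invented for illustration and not taken from sheeprl defaults:

    # Minimal sketch of the before/after arithmetic; the numbers are made up.
    num_envs, world_size = 4, 1
    policy_steps_per_iter = num_envs * world_size            # 4 policy steps per iteration
    learning_starts = 1024 // policy_steps_per_iter          # 256 prefill iterations

    old_prefill_steps = learning_starts                              # before: adjusted inside the loop
    new_prefill_steps = learning_starts - int(learning_starts > 0)   # after: adjusted once, up front

    for policy_step in (1024, 2048, 4096):
        old_ratio_steps = policy_step - (old_prefill_steps - 1) * policy_steps_per_iter
        new_ratio_steps = policy_step - new_prefill_steps * policy_steps_per_iter
        assert old_ratio_steps == new_ratio_steps            # identical whenever learning_starts > 0

    # With learning_starts == 0 the old in-loop subtraction inflated ratio_steps
    # by one extra policy_steps_per_iter; the guarded initialization keeps it at 0.
    learning_starts = 0
    prefill_steps = learning_starts - int(learning_starts > 0)       # stays 0 instead of -1
    assert 512 - prefill_steps * policy_steps_per_iter == 512

For any positive learning_starts the two formulas give identical ratio_steps, so the per-file hunks below are behavior-preserving except for the learning_starts == 0 case.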
4 changes: 2 additions & 2 deletions sheeprl/algos/dreamer_v1/dreamer_v1.py
@@ -510,7 +510,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     policy_steps_per_iter = int(cfg.env.num_envs * world_size)
     total_iters = int(cfg.algo.total_steps // policy_steps_per_iter) if not cfg.dry_run else 1
     learning_starts = (cfg.algo.learning_starts // policy_steps_per_iter) if not cfg.dry_run else 0
-    prefill_steps = learning_starts
+    prefill_steps = learning_starts - int(learning_starts > 0)
     if cfg.checkpoint.resume_from:
         cfg.algo.per_rank_batch_size = state["batch_size"] // world_size
         learning_starts += start_iter
@@ -645,7 +645,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

         # Train the agent
         if iter_num >= learning_starts:
-            ratio_steps = policy_step - (prefill_steps - 1) * policy_steps_per_iter
+            ratio_steps = policy_step - prefill_steps * policy_steps_per_iter
             per_rank_gradient_steps = ratio(ratio_steps / world_size)
             if per_rank_gradient_steps > 0:
                 with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
4 changes: 2 additions & 2 deletions sheeprl/algos/dreamer_v2/dreamer_v2.py
@@ -533,7 +533,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     policy_steps_per_iter = int(cfg.env.num_envs * world_size)
     total_iters = cfg.algo.total_steps // policy_steps_per_iter if not cfg.dry_run else 1
     learning_starts = cfg.algo.learning_starts // policy_steps_per_iter if not cfg.dry_run else 0
-    prefill_steps = learning_starts
+    prefill_steps = learning_starts - int(learning_starts > 0)
     if cfg.checkpoint.resume_from:
         cfg.algo.per_rank_batch_size = state["batch_size"] // world_size
         learning_starts += start_iter
@@ -673,7 +673,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

         # Train the agent
         if iter_num >= learning_starts:
-            ratio_steps = policy_step - (prefill_steps - 1) * policy_steps_per_iter
+            ratio_steps = policy_step - prefill_steps * policy_steps_per_iter
             per_rank_gradient_steps = ratio(ratio_steps / world_size)
             if per_rank_gradient_steps > 0:
                 local_data = rb.sample_tensors(
4 changes: 2 additions & 2 deletions sheeprl/algos/dreamer_v3/dreamer_v3.py
@@ -505,7 +505,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     policy_steps_per_iter = int(cfg.env.num_envs * fabric.world_size)
     total_iters = int(cfg.algo.total_steps // policy_steps_per_iter) if not cfg.dry_run else 1
     learning_starts = cfg.algo.learning_starts // policy_steps_per_iter if not cfg.dry_run else 0
-    prefill_steps = learning_starts
+    prefill_steps = learning_starts - int(learning_starts > 0)
     if cfg.checkpoint.resume_from:
         cfg.algo.per_rank_batch_size = state["batch_size"] // fabric.world_size
         learning_starts += start_iter
@@ -655,7 +655,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

         # Train the agent
         if iter_num >= learning_starts:
-            ratio_steps = policy_step - (prefill_steps - 1) * policy_steps_per_iter
+            ratio_steps = policy_step - prefill_steps * policy_steps_per_iter
             per_rank_gradient_steps = ratio(ratio_steps / world_size)
             if per_rank_gradient_steps > 0:
                 local_data = rb.sample_tensors(
4 changes: 2 additions & 2 deletions sheeprl/algos/droq/droq.py
@@ -261,7 +261,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     policy_steps_per_iter = int(cfg.env.num_envs * fabric.world_size)
     total_iters = int(cfg.algo.total_steps // policy_steps_per_iter) if not cfg.dry_run else 1
     learning_starts = cfg.algo.learning_starts // policy_steps_per_iter if not cfg.dry_run else 0
-    prefill_steps = learning_starts
+    prefill_steps = learning_starts - int(learning_starts > 0)
     if cfg.checkpoint.resume_from:
         cfg.algo.per_rank_batch_size = state["batch_size"] // fabric.world_size
         learning_starts += start_iter
@@ -347,7 +347,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

         # Train the agent
         if iter_num >= learning_starts:
-            ratio_steps = policy_step - (prefill_steps - 1) * policy_steps_per_iter
+            ratio_steps = policy_step - prefill_steps * policy_steps_per_iter
             per_rank_gradient_steps = ratio(ratio_steps / world_size)
             if per_rank_gradient_steps > 0:
                 train(
4 changes: 2 additions & 2 deletions sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py
@@ -534,7 +534,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     policy_steps_per_iter = int(cfg.env.num_envs * world_size)
     total_iters = int(cfg.algo.total_steps // policy_steps_per_iter) if not cfg.dry_run else 1
     learning_starts = (cfg.algo.learning_starts // policy_steps_per_iter) if not cfg.dry_run else 0
-    prefill_steps = learning_starts
+    prefill_steps = learning_starts - int(learning_starts > 0)
     if cfg.checkpoint.resume_from:
         cfg.algo.per_rank_batch_size = state["batch_size"] // world_size
         learning_starts += start_iter
@@ -669,7 +669,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

         # Train the agent
         if iter_num >= learning_starts:
-            ratio_steps = policy_step - (prefill_steps - 1) * policy_steps_per_iter
+            ratio_steps = policy_step - prefill_steps * policy_steps_per_iter
             per_rank_gradient_steps = ratio(ratio_steps / world_size)
             if per_rank_gradient_steps > 0:
                 with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
4 changes: 2 additions & 2 deletions sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py
@@ -205,7 +205,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
     policy_steps_per_iter = int(cfg.env.num_envs * world_size)
     total_iters = int(cfg.algo.total_steps // policy_steps_per_iter) if not cfg.dry_run else 1
     learning_starts = (cfg.algo.learning_starts // policy_steps_per_iter) if not cfg.dry_run else 0
-    prefill_steps = learning_starts
+    prefill_steps = learning_starts - int(learning_starts > 0)
     if resume_from_checkpoint:
         cfg.algo.per_rank_batch_size = state["batch_size"] // world_size
         learning_starts += start_iter
@@ -324,7 +324,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):

         # Train the agent
         if iter_num >= learning_starts:
-            ratio_steps = policy_step - (prefill_steps - 1) * policy_steps_per_iter
+            ratio_steps = policy_step - prefill_steps * policy_steps_per_iter
             per_rank_gradient_steps = ratio(ratio_steps / world_size)
             if per_rank_gradient_steps > 0:
                 if player.actor_type != "task":
4 changes: 2 additions & 2 deletions sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py
@@ -669,7 +669,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     policy_steps_per_iter = int(cfg.env.num_envs * world_size)
     total_iters = cfg.algo.total_steps // policy_steps_per_iter if not cfg.dry_run else 1
     learning_starts = cfg.algo.learning_starts // policy_steps_per_iter if not cfg.dry_run else 0
-    prefill_steps = learning_starts
+    prefill_steps = learning_starts - int(learning_starts > 0)
     if cfg.checkpoint.resume_from:
         cfg.algo.per_rank_batch_size = state["batch_size"] // world_size
         learning_starts += start_iter
@@ -809,7 +809,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

         # Train the agent
         if iter_num >= learning_starts:
-            ratio_steps = policy_step - (prefill_steps - 1) * policy_steps_per_iter
+            ratio_steps = policy_step - prefill_steps * policy_steps_per_iter
             per_rank_gradient_steps = ratio(ratio_steps / world_size)
             if per_rank_gradient_steps > 0:
                 local_data = rb.sample_tensors(
4 changes: 2 additions & 2 deletions sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py
@@ -223,7 +223,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
     policy_steps_per_iter = int(cfg.env.num_envs * world_size)
     total_iters = cfg.algo.total_steps // policy_steps_per_iter if not cfg.dry_run else 1
     learning_starts = cfg.algo.learning_starts // policy_steps_per_iter if not cfg.dry_run else 0
-    prefill_steps = learning_starts
+    prefill_steps = learning_starts - int(learning_starts > 0)
     if resume_from_checkpoint:
         cfg.algo.per_rank_batch_size = state["batch_size"] // world_size
         learning_starts += start_iter
@@ -347,7 +347,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):

         # Train the agent
         if iter_num >= learning_starts:
-            ratio_steps = policy_step - (prefill_steps - 1) * policy_steps_per_iter
+            ratio_steps = policy_step - prefill_steps * policy_steps_per_iter
             per_rank_gradient_steps = ratio(ratio_steps / world_size)
             if per_rank_gradient_steps > 0:
                 if player.actor_type != "task":
4 changes: 2 additions & 2 deletions sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py
@@ -746,7 +746,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     policy_steps_per_iter = int(cfg.env.num_envs * fabric.world_size)
     total_iters = int(cfg.algo.total_steps // policy_steps_per_iter) if not cfg.dry_run else 1
     learning_starts = cfg.algo.learning_starts // policy_steps_per_iter if not cfg.dry_run else 0
-    prefill_steps = learning_starts
+    prefill_steps = learning_starts - int(learning_starts > 0)
     if cfg.checkpoint.resume_from:
         cfg.algo.per_rank_batch_size = state["batch_size"] // world_size
         learning_starts += start_iter
@@ -896,7 +896,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

         # Train the agent
         if iter_num >= learning_starts:
-            ratio_steps = policy_step - (prefill_steps - 1) * policy_steps_per_iter
+            ratio_steps = policy_step - prefill_steps * policy_steps_per_iter
             per_rank_gradient_steps = ratio(ratio_steps / world_size)
             if per_rank_gradient_steps > 0:
                 local_data = rb.sample_tensors(
4 changes: 2 additions & 2 deletions sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py
@@ -210,7 +210,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
     policy_steps_per_iter = int(cfg.env.num_envs * fabric.world_size)
     total_iters = int(cfg.algo.total_steps // policy_steps_per_iter) if not cfg.dry_run else 1
     learning_starts = cfg.algo.learning_starts // policy_steps_per_iter if not cfg.dry_run else 0
-    prefill_steps = learning_starts
+    prefill_steps = learning_starts - int(learning_starts > 0)
     if resume_from_checkpoint:
         cfg.algo.per_rank_batch_size = state["batch_size"] // world_size
         learning_starts += start_iter
@@ -344,7 +344,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):

         # Train the agent
         if iter_num >= learning_starts:
-            ratio_steps = policy_step - (prefill_steps - 1) * policy_steps_per_iter
+            ratio_steps = policy_step - prefill_steps * policy_steps_per_iter
             per_rank_gradient_steps = ratio(ratio_steps / world_size)
             if per_rank_gradient_steps > 0:
                 if player.actor_type != "task":
2 changes: 1 addition & 1 deletion sheeprl/algos/sac/sac.py
@@ -212,7 +212,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     policy_steps_per_iter = int(cfg.env.num_envs * fabric.world_size)
     total_iters = int(cfg.algo.total_steps // policy_steps_per_iter) if not cfg.dry_run else 1
     learning_starts = cfg.algo.learning_starts // policy_steps_per_iter if not cfg.dry_run else 0
-    prefill_steps = learning_starts
+    prefill_steps = learning_starts - int(learning_starts > 0)
     if cfg.checkpoint.resume_from:
         cfg.algo.per_rank_batch_size = state["batch_size"] // fabric.world_size
         learning_starts += start_iter
4 changes: 2 additions & 2 deletions sheeprl/algos/sac/sac_decoupled.py
@@ -145,7 +145,7 @@ def player(
     policy_steps_per_iter = int(cfg.env.num_envs)
     total_iters = int(cfg.algo.total_steps // policy_steps_per_iter) if not cfg.dry_run else 1
     learning_starts = cfg.algo.learning_starts // policy_steps_per_iter if not cfg.dry_run else 0
-    prefill_steps = learning_starts
+    prefill_steps = learning_starts - int(learning_starts > 0)
     if cfg.checkpoint.resume_from and not cfg.buffer.checkpoint:
         learning_starts += start_iter
         prefill_steps += start_iter
@@ -228,7 +228,7 @@ def player(

         # Send data to the training agents
         if iter_num >= learning_starts:
-            ratio_steps = policy_step - (prefill_steps - 1) * policy_steps_per_iter
+            ratio_steps = policy_step - prefill_steps * policy_steps_per_iter
             per_rank_gradient_steps = ratio(ratio_steps / (fabric.world_size - 1))
             cumulative_per_rank_gradient_steps += per_rank_gradient_steps
             if per_rank_gradient_steps > 0:
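
In the decoupled SAC above, the player process does this bookkeeping and spreads the gradient steps over the trainer ranks only (fabric.world_size - 1, since one rank runs the player), accumulating the per-rank steps across iterations. A rough, self-contained sketch of that loop; SimpleRatio is a made-up stand-in, not sheeprl's actual ratio scheduler:

    # Rough sketch of the player-side bookkeeping; SimpleRatio is invented for
    # illustration and only approximates a replay-ratio scheduler.
    class SimpleRatio:
        def __init__(self, grad_steps_per_policy_step: float = 1.0) -> None:
            self._ratio = grad_steps_per_policy_step
            self._consumed = 0.0  # policy steps already converted into gradient steps

        def __call__(self, steps: float) -> int:
            new_steps = int((steps - self._consumed) * self._ratio)
            self._consumed = steps
            return new_steps

    ratio = SimpleRatio()
    world_size = 3                       # 1 player rank + 2 trainer ranks (example values)
    policy_steps_per_iter = 4
    prefill_steps = 255                  # learning_starts - 1, as initialized above

    cumulative_per_rank_gradient_steps = 0
    for policy_step in range(1024, 1024 + 3 * policy_steps_per_iter, policy_steps_per_iter):
        ratio_steps = policy_step - prefill_steps * policy_steps_per_iter
        per_rank_gradient_steps = ratio(ratio_steps / (world_size - 1))
        cumulative_per_rank_gradient_steps += per_rank_gradient_steps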
4 changes: 2 additions & 2 deletions sheeprl/algos/sac_ae/sac_ae.py
@@ -284,7 +284,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     policy_steps_per_iter = int(cfg.env.num_envs * fabric.world_size)
     total_iters = int(cfg.algo.total_steps // policy_steps_per_iter) if not cfg.dry_run else 1
     learning_starts = cfg.algo.learning_starts // policy_steps_per_iter if not cfg.dry_run else 0
-    prefill_steps = learning_starts
+    prefill_steps = learning_starts - int(learning_starts > 0)
     if cfg.checkpoint.resume_from:
         cfg.algo.per_rank_batch_size = state["batch_size"] // fabric.world_size
         learning_starts += start_iter
@@ -375,7 +375,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

         # Train the agent
         if iter_num >= learning_starts:
-            ratio_steps = policy_step - (prefill_steps - 1) * policy_steps_per_iter
+            ratio_steps = policy_step - prefill_steps * policy_steps_per_iter
             per_rank_gradient_steps = ratio(ratio_steps / world_size)
             if per_rank_gradient_steps > 0:
                 # We sample one time to reduce the communications between processes
