Do not compute policy_steps_per_update every time
belerico committed May 5, 2024
1 parent 51d6771 commit f09519a
Showing 15 changed files with 15 additions and 15 deletions.
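Every hunk below makes the same one-line change: the per-step accumulation policy_step += cfg.env.num_envs * world_size becomes policy_step += policy_steps_per_update, so the product is no longer recomputed on every rollout step. A minimal, self-contained sketch of the pattern, assuming policy_steps_per_update is computed once before the training loop (that definition is not part of these hunks, and the values below are toy stand-ins):

    # Minimal sketch of the refactor (toy values; names mirror the hunks below).
    # Point of the commit: compute policy_steps_per_update once, instead of
    # re-evaluating cfg.env.num_envs * world_size on every rollout step.

    num_envs = 4          # stand-in for cfg.env.num_envs
    world_size = 2        # stand-in for fabric.world_size
    rollout_steps = 128   # stand-in for cfg.algo.rollout_steps
    num_updates = 10
    start_step = 1

    policy_steps_per_update = num_envs * world_size  # constant for the whole run

    policy_step = 0
    for update in range(start_step, num_updates + 1):
        for _ in range(0, rollout_steps):
            # before this commit: policy_step += num_envs * world_size (recomputed each step)
            policy_step += policy_steps_per_update

    print(policy_step)  # == num_updates * rollout_steps * num_envs * world_size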
2 changes: 1 addition & 1 deletion howto/register_external_algorithm.md
@@ -591,7 +591,7 @@ def ext_sota_main(fabric: Fabric, cfg: Dict[str, Any]):
 
     for update in range(start_step, num_updates + 1):
         for _ in range(0, cfg.algo.rollout_steps):
-            policy_step += cfg.env.num_envs * world_size
+            policy_step += policy_steps_per_update
 
             # Measure environment interaction time: this considers both the model forward
             # to get the action given the observation and the time taken into the environment
2 changes: 1 addition & 1 deletion howto/register_new_algorithm.md
@@ -590,7 +590,7 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]):
 
     for update in range(start_step, num_updates + 1):
         for _ in range(0, cfg.algo.rollout_steps):
-            policy_step += cfg.env.num_envs * world_size
+            policy_step += policy_steps_per_update
 
             # Measure environment interaction time: this considers both the model forward
             # to get the action given the observation and the time taken into the environment
2 changes: 1 addition & 1 deletion sheeprl/algos/a2c/a2c.py
@@ -228,7 +228,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     for update in range(1, num_updates + 1):
         with torch.inference_mode():
             for _ in range(0, cfg.algo.rollout_steps):
-                policy_step += cfg.env.num_envs * world_size
+                policy_step += policy_steps_per_update
 
                 # Measure environment interaction time: this considers both the model forward
                 # to get the action given the observation and the time taken into the environment
2 changes: 1 addition & 1 deletion sheeprl/algos/dreamer_v1/dreamer_v1.py
@@ -552,7 +552,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
 
     cumulative_per_rank_gradient_steps = 0
     for update in range(start_step, num_updates + 1):
-        policy_step += cfg.env.num_envs * world_size
+        policy_step += policy_steps_per_update
 
         with torch.inference_mode():
             # Measure environment interaction time: this considers both the model forward
2 changes: 1 addition & 1 deletion sheeprl/algos/dreamer_v2/dreamer_v2.py
@@ -577,7 +577,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
 
     cumulative_per_rank_gradient_steps = 0
    for update in range(start_step, num_updates + 1):
-        policy_step += cfg.env.num_envs * world_size
+        policy_step += policy_steps_per_update
 
         with torch.inference_mode():
             # Measure environment interaction time: this considers both the model forward
2 changes: 1 addition & 1 deletion sheeprl/algos/dreamer_v3/dreamer_v3.py
@@ -544,7 +544,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
 
     cumulative_per_rank_gradient_steps = 0
     for update in range(start_step, num_updates + 1):
-        policy_step += cfg.env.num_envs * world_size
+        policy_step += policy_steps_per_update
 
         with torch.inference_mode():
             # Measure environment interaction time: this considers both the model forward
2 changes: 1 addition & 1 deletion sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py
@@ -576,7 +576,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
 
     cumulative_per_rank_gradient_steps = 0
     for update in range(start_step, num_updates + 1):
-        policy_step += cfg.env.num_envs * world_size
+        policy_step += policy_steps_per_update
 
         with torch.inference_mode():
             # Measure environment interaction time: this considers both the model forward
2 changes: 1 addition & 1 deletion sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py
@@ -247,7 +247,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
 
     cumulative_per_rank_gradient_steps = 0
     for update in range(start_step, num_updates + 1):
-        policy_step += cfg.env.num_envs * world_size
+        policy_step += policy_steps_per_update
 
         with torch.inference_mode():
             # Measure environment interaction time: this considers both the model forward
2 changes: 1 addition & 1 deletion sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py
@@ -713,7 +713,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
 
     cumulative_per_rank_gradient_steps = 0
     for update in range(start_step, num_updates + 1):
-        policy_step += cfg.env.num_envs * world_size
+        policy_step += policy_steps_per_update
 
         with torch.inference_mode():
             # Measure environment interaction time: this considers both the model forward
2 changes: 1 addition & 1 deletion sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py
@@ -267,7 +267,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
 
     cumulative_per_rank_gradient_steps = 0
     for update in range(start_step, num_updates + 1):
-        policy_step += cfg.env.num_envs * world_size
+        policy_step += policy_steps_per_update
 
         with torch.inference_mode():
             # Measure environment interaction time: this considers both the model forward
2 changes: 1 addition & 1 deletion sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py
@@ -785,7 +785,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
 
     cumulative_per_rank_gradient_steps = 0
     for update in range(start_step, num_updates + 1):
-        policy_step += cfg.env.num_envs * world_size
+        policy_step += policy_steps_per_update
 
         with torch.inference_mode():
             # Measure environment interaction time: this considers both the model forward
2 changes: 1 addition & 1 deletion sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py
@@ -249,7 +249,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
 
     cumulative_per_rank_gradient_steps = 0
     for update in range(start_step, num_updates + 1):
-        policy_step += cfg.env.num_envs * world_size
+        policy_step += policy_steps_per_update
 
         with torch.inference_mode():
             # Measure environment interaction time: this considers both the model forward
2 changes: 1 addition & 1 deletion sheeprl/algos/ppo/ppo.py
@@ -265,7 +265,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     for update in range(start_step, num_updates + 1):
         with torch.inference_mode():
             for _ in range(0, cfg.algo.rollout_steps):
-                policy_step += cfg.env.num_envs * world_size
+                policy_step += policy_steps_per_update
 
                 # Measure environment interaction time: this considers both the model forward
                 # to get the action given the observation and the time taken into the environment
2 changes: 1 addition & 1 deletion sheeprl/algos/ppo_recurrent/ppo_recurrent.py
@@ -287,7 +287,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     for update in range(start_step, num_updates + 1):
         with torch.inference_mode():
             for _ in range(0, cfg.algo.rollout_steps):
-                policy_step += cfg.env.num_envs * world_size
+                policy_step += policy_steps_per_update
 
                 # Measure environment interaction time: this considers both the model forward
                 # to get the action given the observation and the time taken into the environment
2 changes: 1 addition & 1 deletion sheeprl/algos/sac/sac.py
@@ -245,7 +245,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     per_rank_gradient_steps = 0
     cumulative_per_rank_gradient_steps = 0
     for update in range(start_step, num_updates + 1):
-        policy_step += cfg.env.num_envs * world_size
+        policy_step += policy_steps_per_update
 
         # Measure environment interaction time: this considers both the model forward
         # to get the action given the observation and the time taken into the environment
