Skip to content

Commit

Permalink
fix: resume from checkpoint log dir
Browse files · Browse the repository at this point in the history
  • Loading branch information
michele-milesi committed Sep 15, 2023
1 parent b8dd0df commit dbc1345
Show file tree
Hide file tree
Showing 12 changed files with 12 additions and 12 deletions.
2 changes: 1 addition & 1 deletion sheeprl/algos/dreamer_v1/dreamer_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def main(cfg: DictConfig):
cfg.checkpoint.resume_from = str(ckpt_path)
cfg.per_rank_batch_size = state["batch_size"] // world_size
cfg.root_dir = root_dir
cfg.run_name = f"resume_from_checkpoint_{run_name}"
cfg.run_name = run_name

# Create TensorBoardLogger. This will create the logger only on the
# rank-0 process
Expand Down
2 changes: 1 addition & 1 deletion sheeprl/algos/dreamer_v2/dreamer_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ def main(cfg: DictConfig):
cfg.checkpoint.resume_from = str(ckpt_path)
cfg.per_rank_batch_size = state["batch_size"] // world_size
cfg.root_dir = root_dir
cfg.run_name = f"resume_from_checkpoint_{run_name}"
cfg.run_name = run_name

# Create TensorBoardLogger. This will create the logger only on the
# rank-0 process
Expand Down
2 changes: 1 addition & 1 deletion sheeprl/algos/dreamer_v3/dreamer_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def main(cfg: DictConfig):
cfg.checkpoint.resume_from = str(ckpt_path)
cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size
cfg.root_dir = root_dir
cfg.run_name = f"resume_from_checkpoint_{run_name}"
cfg.run_name = run_name

# Create TensorBoardLogger. This will create the logger only on the
# rank-0 process
Expand Down
2 changes: 1 addition & 1 deletion sheeprl/algos/droq/droq.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def main(cfg: DictConfig):
cfg.checkpoint.resume_from = str(ckpt_path)
cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size
cfg.root_dir = root_dir
cfg.run_name = f"resume_from_checkpoint_{run_name}"
cfg.run_name = run_name

# Create TensorBoardLogger. This will create the logger only on the
# rank-0 process
Expand Down
2 changes: 1 addition & 1 deletion sheeprl/algos/p2e_dv1/p2e_dv1.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ def main(cfg: DictConfig):
cfg.checkpoint.resume_from = str(ckpt_path)
cfg.per_rank_batch_size = state["batch_size"] // world_size
cfg.root_dir = root_dir
cfg.run_name = f"resume_from_checkpoint_{run_name}"
cfg.run_name = run_name

# Create TensorBoardLogger. This will create the logger only on the
# rank-0 process
Expand Down
2 changes: 1 addition & 1 deletion sheeprl/algos/p2e_dv2/p2e_dv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ def main(cfg: DictConfig):
cfg.checkpoint.resume_from = str(ckpt_path)
cfg.per_rank_batch_size = state["batch_size"] // world_size
cfg.root_dir = root_dir
cfg.run_name = f"resume_from_checkpoint_{run_name}"
cfg.run_name = run_name

# Create TensorBoardLogger. This will create the logger only on the
# rank-0 process
Expand Down
2 changes: 1 addition & 1 deletion sheeprl/algos/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def main(cfg: DictConfig):
cfg.checkpoint.resume_from = str(ckpt_path)
cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size
cfg.root_dir = root_dir
cfg.run_name = f"resume_from_checkpoint_{run_name}"
cfg.run_name = run_name

# Create TensorBoardLogger. This will create the logger only on the
# rank-0 process
Expand Down
2 changes: 1 addition & 1 deletion sheeprl/algos/ppo/ppo_decoupled.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co
cfg.checkpoint.resume_from = str(ckpt_path)
cfg.per_rank_batch_size = state["batch_size"] // (world_collective.world_size - 1)
cfg.root_dir = root_dir
cfg.run_name = f"resume_from_checkpoint_{run_name}"
cfg.run_name = run_name

# Initialize logger
root_dir = (
Expand Down
2 changes: 1 addition & 1 deletion sheeprl/algos/ppo_recurrent/ppo_recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def main(cfg: DictConfig):
cfg.checkpoint.resume_from = str(ckpt_path)
cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size
cfg.root_dir = root_dir
cfg.run_name = f"resume_from_checkpoint_{run_name}"
cfg.run_name = run_name

# Create TensorBoardLogger. This will create the logger only on the
# rank-0 process
Expand Down
2 changes: 1 addition & 1 deletion sheeprl/algos/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def main(cfg: DictConfig):
cfg.checkpoint.resume_from = str(ckpt_path)
cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size
cfg.root_dir = root_dir
cfg.run_name = f"resume_from_checkpoint_{run_name}"
cfg.run_name = run_name

# Create TensorBoardLogger. This will create the logger only on the
# rank-0 process
Expand Down
2 changes: 1 addition & 1 deletion sheeprl/algos/sac/sac_decoupled.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co
cfg.checkpoint.resume_from = str(ckpt_path)
cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size
cfg.root_dir = root_dir
cfg.run_name = f"resume_from_checkpoint_{run_name}"
cfg.run_name = run_name

# Initialize logger
root_dir = (
Expand Down
2 changes: 1 addition & 1 deletion sheeprl/algos/sac_ae/sac_ae.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def main(cfg: DictConfig):
cfg.checkpoint.resume_from = str(ckpt_path)
cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size
cfg.root_dir = root_dir
cfg.run_name = f"resume_from_checkpoint_{run_name}"
cfg.run_name = run_name

# Create TensorBoardLogger. This will create the logger only on the
# rank-0 process
Expand Down

0 comments on commit dbc1345

Please sign in to comment.