From dbc13458ee71e231ede687af20f267900af69927 Mon Sep 17 00:00:00 2001
From: Michele Milesi
Date: Fri, 15 Sep 2023 15:00:01 +0200
Subject: [PATCH] fix: resume from checkpoint log dir

---
 sheeprl/algos/dreamer_v1/dreamer_v1.py       | 2 +-
 sheeprl/algos/dreamer_v2/dreamer_v2.py       | 2 +-
 sheeprl/algos/dreamer_v3/dreamer_v3.py       | 2 +-
 sheeprl/algos/droq/droq.py                   | 2 +-
 sheeprl/algos/p2e_dv1/p2e_dv1.py             | 2 +-
 sheeprl/algos/p2e_dv2/p2e_dv2.py             | 2 +-
 sheeprl/algos/ppo/ppo.py                     | 2 +-
 sheeprl/algos/ppo/ppo_decoupled.py           | 2 +-
 sheeprl/algos/ppo_recurrent/ppo_recurrent.py | 2 +-
 sheeprl/algos/sac/sac.py                     | 2 +-
 sheeprl/algos/sac/sac_decoupled.py           | 2 +-
 sheeprl/algos/sac_ae/sac_ae.py               | 2 +-
 12 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/sheeprl/algos/dreamer_v1/dreamer_v1.py b/sheeprl/algos/dreamer_v1/dreamer_v1.py
index 1eea1737..7ea3e16f 100644
--- a/sheeprl/algos/dreamer_v1/dreamer_v1.py
+++ b/sheeprl/algos/dreamer_v1/dreamer_v1.py
@@ -384,7 +384,7 @@ def main(cfg: DictConfig):
         cfg.checkpoint.resume_from = str(ckpt_path)
         cfg.per_rank_batch_size = state["batch_size"] // world_size
         cfg.root_dir = root_dir
-        cfg.run_name = f"resume_from_checkpoint_{run_name}"
+        cfg.run_name = run_name

     # Create TensorBoardLogger. This will create the logger only on the
     # rank-0 process
diff --git a/sheeprl/algos/dreamer_v2/dreamer_v2.py b/sheeprl/algos/dreamer_v2/dreamer_v2.py
index 0d7217eb..2ce31283 100644
--- a/sheeprl/algos/dreamer_v2/dreamer_v2.py
+++ b/sheeprl/algos/dreamer_v2/dreamer_v2.py
@@ -404,7 +404,7 @@ def main(cfg: DictConfig):
         cfg.checkpoint.resume_from = str(ckpt_path)
         cfg.per_rank_batch_size = state["batch_size"] // world_size
         cfg.root_dir = root_dir
-        cfg.run_name = f"resume_from_checkpoint_{run_name}"
+        cfg.run_name = run_name

     # Create TensorBoardLogger. This will create the logger only on the
     # rank-0 process
diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py
index e1abe5c4..75834ff1 100644
--- a/sheeprl/algos/dreamer_v3/dreamer_v3.py
+++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py
@@ -358,7 +358,7 @@ def main(cfg: DictConfig):
         cfg.checkpoint.resume_from = str(ckpt_path)
         cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size
         cfg.root_dir = root_dir
-        cfg.run_name = f"resume_from_checkpoint_{run_name}"
+        cfg.run_name = run_name

     # Create TensorBoardLogger. This will create the logger only on the
     # rank-0 process
diff --git a/sheeprl/algos/droq/droq.py b/sheeprl/algos/droq/droq.py
index 5cffdec4..b2917e25 100644
--- a/sheeprl/algos/droq/droq.py
+++ b/sheeprl/algos/droq/droq.py
@@ -150,7 +150,7 @@ def main(cfg: DictConfig):
         cfg.checkpoint.resume_from = str(ckpt_path)
         cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size
         cfg.root_dir = root_dir
-        cfg.run_name = f"resume_from_checkpoint_{run_name}"
+        cfg.run_name = run_name

     # Create TensorBoardLogger. This will create the logger only on the
     # rank-0 process
diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1.py b/sheeprl/algos/p2e_dv1/p2e_dv1.py
index f987d558..eb301f40 100644
--- a/sheeprl/algos/p2e_dv1/p2e_dv1.py
+++ b/sheeprl/algos/p2e_dv1/p2e_dv1.py
@@ -389,7 +389,7 @@ def main(cfg: DictConfig):
         cfg.checkpoint.resume_from = str(ckpt_path)
         cfg.per_rank_batch_size = state["batch_size"] // world_size
         cfg.root_dir = root_dir
-        cfg.run_name = f"resume_from_checkpoint_{run_name}"
+        cfg.run_name = run_name

     # Create TensorBoardLogger. This will create the logger only on the
     # rank-0 process
diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2.py b/sheeprl/algos/p2e_dv2/p2e_dv2.py
index d1d48a0b..c8e10803 100644
--- a/sheeprl/algos/p2e_dv2/p2e_dv2.py
+++ b/sheeprl/algos/p2e_dv2/p2e_dv2.py
@@ -491,7 +491,7 @@ def main(cfg: DictConfig):
         cfg.checkpoint.resume_from = str(ckpt_path)
         cfg.per_rank_batch_size = state["batch_size"] // world_size
         cfg.root_dir = root_dir
-        cfg.run_name = f"resume_from_checkpoint_{run_name}"
+        cfg.run_name = run_name

     # Create TensorBoardLogger. This will create the logger only on the
     # rank-0 process
diff --git a/sheeprl/algos/ppo/ppo.py b/sheeprl/algos/ppo/ppo.py
index 96b721d6..516b1f84 100644
--- a/sheeprl/algos/ppo/ppo.py
+++ b/sheeprl/algos/ppo/ppo.py
@@ -142,7 +142,7 @@ def main(cfg: DictConfig):
         cfg.checkpoint.resume_from = str(ckpt_path)
         cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size
         cfg.root_dir = root_dir
-        cfg.run_name = f"resume_from_checkpoint_{run_name}"
+        cfg.run_name = run_name

     # Create TensorBoardLogger. This will create the logger only on the
     # rank-0 process
diff --git a/sheeprl/algos/ppo/ppo_decoupled.py b/sheeprl/algos/ppo/ppo_decoupled.py
index 889da385..8d57cc22 100644
--- a/sheeprl/algos/ppo/ppo_decoupled.py
+++ b/sheeprl/algos/ppo/ppo_decoupled.py
@@ -56,7 +56,7 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co
         cfg.checkpoint.resume_from = str(ckpt_path)
         cfg.per_rank_batch_size = state["batch_size"] // (world_collective.world_size - 1)
         cfg.root_dir = root_dir
-        cfg.run_name = f"resume_from_checkpoint_{run_name}"
+        cfg.run_name = run_name

     # Initialize logger
     root_dir = (
diff --git a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py
index 338900af..0dda0853 100644
--- a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py
+++ b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py
@@ -137,7 +137,7 @@ def main(cfg: DictConfig):
         cfg.checkpoint.resume_from = str(ckpt_path)
         cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size
         cfg.root_dir = root_dir
-        cfg.run_name = f"resume_from_checkpoint_{run_name}"
+        cfg.run_name = run_name

     # Create TensorBoardLogger. This will create the logger only on the
     # rank-0 process
diff --git a/sheeprl/algos/sac/sac.py b/sheeprl/algos/sac/sac.py
index f17569bb..4f40a338 100644
--- a/sheeprl/algos/sac/sac.py
+++ b/sheeprl/algos/sac/sac.py
@@ -105,7 +105,7 @@ def main(cfg: DictConfig):
         cfg.checkpoint.resume_from = str(ckpt_path)
         cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size
         cfg.root_dir = root_dir
-        cfg.run_name = f"resume_from_checkpoint_{run_name}"
+        cfg.run_name = run_name

     # Create TensorBoardLogger. This will create the logger only on the
     # rank-0 process
diff --git a/sheeprl/algos/sac/sac_decoupled.py b/sheeprl/algos/sac/sac_decoupled.py
index 8ff1097d..7140a73c 100644
--- a/sheeprl/algos/sac/sac_decoupled.py
+++ b/sheeprl/algos/sac/sac_decoupled.py
@@ -56,7 +56,7 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co
         cfg.checkpoint.resume_from = str(ckpt_path)
         cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size
         cfg.root_dir = root_dir
-        cfg.run_name = f"resume_from_checkpoint_{run_name}"
+        cfg.run_name = run_name

     # Initialize logger
     root_dir = (
diff --git a/sheeprl/algos/sac_ae/sac_ae.py b/sheeprl/algos/sac_ae/sac_ae.py
index 46dc3784..4e7c5bcf 100644
--- a/sheeprl/algos/sac_ae/sac_ae.py
+++ b/sheeprl/algos/sac_ae/sac_ae.py
@@ -185,7 +185,7 @@ def main(cfg: DictConfig):
         cfg.checkpoint.resume_from = str(ckpt_path)
         cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size
         cfg.root_dir = root_dir
-        cfg.run_name = f"resume_from_checkpoint_{run_name}"
+        cfg.run_name = run_name

     # Create TensorBoardLogger. This will create the logger only on the
     # rank-0 process