From eed15390e2418baa0c5be615e12a859ac1d1bddf Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Wed, 13 Sep 2023 11:22:20 +0200 Subject: [PATCH 1/8] feat: added resume from checkpoint for sac, sac_decoupled, droq, ppo, ppo_decoupled, ppo_recurrent --- sheeprl/algos/droq/agent.py | 31 ++-- sheeprl/algos/droq/droq.py | 82 +++++++--- sheeprl/algos/ppo/ppo.py | 60 +++++-- sheeprl/algos/ppo/ppo_decoupled.py | 122 ++++++++------ sheeprl/algos/ppo_recurrent/ppo_recurrent.py | 84 ++++++---- sheeprl/algos/sac/agent.py | 29 +++- sheeprl/algos/sac/sac.py | 110 ++++++++----- sheeprl/algos/sac/sac_decoupled.py | 157 ++++++++++++------- sheeprl/configs/env_config.yaml | 4 + 9 files changed, 451 insertions(+), 228 deletions(-) diff --git a/sheeprl/algos/droq/agent.py b/sheeprl/algos/droq/agent.py index e7db31ca..cfb817a0 100644 --- a/sheeprl/algos/droq/agent.py +++ b/sheeprl/algos/droq/agent.py @@ -83,7 +83,22 @@ def __init__( # Actor and critics self._num_critics = len(critics) - self._actor = actor + self.actor = actor + self.critics = critics + + # Automatic entropy tuning + self._target_entropy = torch.tensor(target_entropy, device=device) + self._log_alpha = torch.nn.Parameter(torch.log(torch.tensor([alpha], device=device)), requires_grad=True) + + # EMA tau + self._tau = tau + + @property + def critics(self) -> nn.ModuleList: + return self.qfs + + @critics.setter + def critics(self, critics: Sequence[Union[DROQCritic, _FabricModule]]) -> None: self._qfs = nn.ModuleList(critics) # Create target critic unwrapping the DDP module from the critics to prevent @@ -93,7 +108,7 @@ def __init__( # This happens when we're using the decoupled version of SAC for example qfs_unwrapped_modules = [] for critic in critics: - if getattr(critic, "module"): + if hasattr(critic, "module"): critic_module = critic.module else: critic_module = critic @@ -103,13 +118,6 @@ def __init__( for p in self._qfs_target.parameters(): p.requires_grad = False - # Automatic entropy tuning - self._target_entropy = torch.tensor(target_entropy, device=device) - self._log_alpha = torch.nn.Parameter(torch.log(torch.tensor([alpha], device=device)), requires_grad=True) - - # EMA tau - self._tau = tau - @property def num_critics(self) -> int: return self._num_critics @@ -126,6 +134,11 @@ def qfs_unwrapped(self) -> nn.ModuleList: def actor(self) -> Union[SACActor, _FabricModule]: return self._actor + @actor.setter + def actor(self, actor: Union[SACActor, _FabricModule]) -> None: + self._actor = actor + return + @property def qfs_target(self) -> nn.ModuleList: return self._qfs_target diff --git a/sheeprl/algos/droq/droq.py b/sheeprl/algos/droq/droq.py index 2eb91434..b0b24e89 100644 --- a/sheeprl/algos/droq/droq.py +++ b/sheeprl/algos/droq/droq.py @@ -1,4 +1,5 @@ import os +import pathlib import time import warnings from math import prod @@ -127,6 +128,18 @@ def main(cfg: DictConfig): fabric.seed_everything(cfg.seed) torch.backends.cudnn.deterministic = cfg.torch_deterministic + # Resume from checkpoint + if cfg.checkpoint.resume_from: + root_dir = cfg.root_dir + run_name = cfg.run_name + state = fabric.load(cfg.checkpoint.resume_from) + ckpt_path = pathlib.Path(cfg.checkpoint.resume_from) + cfg = OmegaConf.load(ckpt_path.parent.parent.parent / ".hydra" / "config.yaml") + cfg.checkpoint.resume_from = str(ckpt_path) + cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size + cfg.root_dir = root_dir + cfg.run_name = f"resume_from_checkpoint_{run_name}" + # Create TensorBoardLogger. 
This will create the logger only on the # rank-0 process logger, log_dir = create_tensorboard_logger(fabric, cfg, "droq") @@ -163,23 +176,19 @@ def main(cfg: DictConfig): # Define the agent and the optimizer and setup them with Fabric act_dim = prod(envs.single_action_space.shape) obs_dim = prod(envs.single_observation_space.shape) - actor = fabric.setup_module( - SACActor( - observation_dim=obs_dim, - action_dim=act_dim, - hidden_size=cfg.algo.actor.hidden_size, - action_low=envs.single_action_space.low, - action_high=envs.single_action_space.high, - ) + actor = SACActor( + observation_dim=obs_dim, + action_dim=act_dim, + hidden_size=cfg.algo.actor.hidden_size, + action_low=envs.single_action_space.low, + action_high=envs.single_action_space.high, ) critics = [ - fabric.setup_module( - DROQCritic( - observation_dim=obs_dim + act_dim, - hidden_size=cfg.algo.critic.hidden_size, - num_critics=1, - dropout=cfg.algo.critic.dropout, - ) + DROQCritic( + observation_dim=obs_dim + act_dim, + hidden_size=cfg.algo.critic.hidden_size, + num_critics=1, + dropout=cfg.algo.critic.dropout, ) for _ in range(cfg.algo.critic.n) ] @@ -187,12 +196,21 @@ def main(cfg: DictConfig): agent = DROQAgent( actor, critics, target_entropy, alpha=cfg.algo.alpha.alpha, tau=cfg.algo.tau, device=fabric.device ) + if cfg.checkpoint.resume_from: + agent.load_state_dict(state["agent"]) + agent.actor = fabric.setup_module(agent.actor) + agent.critics = [fabric.setup_module(critic) for critic in agent.critics] # Optimizers + qf_optimizer = hydra.utils.instantiate(cfg.algo.critic.optimizer, params=agent.qfs.parameters()) + actor_optimizer = hydra.utils.instantiate(cfg.algo.actor.optimizer, params=agent.actor.parameters()) + alpha_optimizer = hydra.utils.instantiate(cfg.algo.alpha.optimizer, params=[agent.log_alpha]) + if cfg.checkpoint.resume_from: + qf_optimizer.load_state_dict(state["qf_optimizer"]) + actor_optimizer.load_state_dict(state["actor_optimizer"]) + alpha_optimizer.load_state_dict(state["alpha_optimizer"]) qf_optimizer, actor_optimizer, alpha_optimizer = fabric.setup_optimizers( - hydra.utils.instantiate(cfg.algo.critic.optimizer, params=agent.qfs.parameters()), - hydra.utils.instantiate(cfg.algo.actor.optimizer, params=agent.actor.parameters()), - hydra.utils.instantiate(cfg.algo.alpha.optimizer, params=[agent.log_alpha]), + qf_optimizer, actor_optimizer, alpha_optimizer ) # Metrics @@ -217,16 +235,26 @@ def main(cfg: DictConfig): memmap=cfg.buffer.memmap, memmap_dir=os.path.join(log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"), ) + if cfg.checkpoint.resume_from and cfg.buffer.checkpoint: + if isinstance(state["rb"], list) and fabric.world_size == len(state["rb"]): + rb = state["rb"][fabric.global_rank] + elif isinstance(state["rb"], ReplayBuffer): + rb = state["rb"] + else: + raise RuntimeError(f"Given {len(state['rb'])}, but {fabric.world_size} processes are instantiated") step_data = TensorDict({}, batch_size=[cfg.env.num_envs], device=device) # Global variables - policy_step = 0 - last_log = 0 - last_checkpoint = 0 + start_step = state["update"] // fabric.world_size if cfg.checkpoint.resume_from else 1 + policy_step = state["update"] * cfg.env.num_envs if cfg.checkpoint.resume_from else 0 + last_log = state["last_log"] if cfg.checkpoint.resume_from else 0 + last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 start_time = time.perf_counter() policy_steps_per_update = int(cfg.env.num_envs * fabric.world_size) num_updates = int(cfg.total_steps // policy_steps_per_update) if not 
cfg.dry_run else 1 learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 + if cfg.checkpoint.resume_from and not cfg.buffer.checkpoint: + learning_starts += start_step # Warning for log and checkpoint every if cfg.metric.log_every % policy_steps_per_update != 0: @@ -248,10 +276,10 @@ def main(cfg: DictConfig): # Get the first environment observation and start the optimization obs = torch.tensor(envs.reset(seed=cfg.seed)[0], dtype=torch.float32) # [N_envs, N_obs] - for update in range(1, num_updates + 1): + for update in range(start_step, num_updates + 1): # Sample an action given the observation received by the environment with torch.no_grad(): - actions, _ = actor.module(obs) + actions, _ = agent.actor.module(obs) actions = actions.cpu().numpy() next_obs, rewards, dones, truncated, infos = envs.step(actions) dones = np.logical_or(dones, truncated) @@ -307,12 +335,16 @@ def main(cfg: DictConfig): or cfg.dry_run or update == num_updates ): + last_checkpoint = policy_step state = { "agent": agent.state_dict(), "qf_optimizer": qf_optimizer.state_dict(), "actor_optimizer": actor_optimizer.state_dict(), "alpha_optimizer": alpha_optimizer.state_dict(), - "update": update, + "update": update * fabric.world_size, + "batch_size": cfg.per_rank_batch_size * fabric.world_size, + "last_log": last_log, + "last_checkpoint": last_checkpoint, } ckpt_path = os.path.join(log_dir, f"checkpoint/ckpt_{policy_step}_{fabric.global_rank}.ckpt") fabric.call( @@ -335,7 +367,7 @@ def main(cfg: DictConfig): mask_velocities=False, vector_env_idx=0, )() - test(actor.module, test_env, fabric, cfg) + test(agent.actor.module, test_env, fabric, cfg) if __name__ == "__main__": diff --git a/sheeprl/algos/ppo/ppo.py b/sheeprl/algos/ppo/ppo.py index fcf4bd0f..c65cd89e 100644 --- a/sheeprl/algos/ppo/ppo.py +++ b/sheeprl/algos/ppo/ppo.py @@ -1,5 +1,6 @@ import copy import os +import pathlib import time import warnings from typing import Union @@ -129,6 +130,18 @@ def main(cfg: DictConfig): fabric.seed_everything(cfg.seed) torch.backends.cudnn.deterministic = cfg.torch_deterministic + # Resume from checkpoint + if cfg.checkpoint.resume_from: + root_dir = cfg.root_dir + run_name = cfg.run_name + state = fabric.load(cfg.checkpoint.resume_from) + ckpt_path = pathlib.Path(cfg.checkpoint.resume_from) + cfg = OmegaConf.load(ckpt_path.parent.parent.parent / ".hydra" / "config.yaml") + cfg.checkpoint.resume_from = str(ckpt_path) + cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size + cfg.root_dir = root_dir + cfg.run_name = f"resume_from_checkpoint_{run_name}" + # Create TensorBoardLogger. 
This will create the logger only on the # rank-0 process logger, log_dir = create_tensorboard_logger(fabric, cfg, "ppo") @@ -184,23 +197,30 @@ def main(cfg: DictConfig): is_continuous=is_continuous, ) - # Define the agent and the optimizer and setup them with Fabric + # Define the optimizer optimizer = hydra.utils.instantiate(cfg.algo.optimizer, params=agent.parameters()) + + # Load the state from the checkpoint + if cfg.checkpoint.resume_from: + agent.load_state_dict(state["agent"]) + optimizer.load_state_dict(state["optimizer"]) + + # Setup agent and optimizer with Fabric agent = fabric.setup_module(agent) optimizer = fabric.setup_optimizers(optimizer) # Create a metric aggregator to log the metrics - with device: - aggregator = MetricAggregator( - { - "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Time/step_per_second": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/entropy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - } - ) + aggregator = MetricAggregator( + { + "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Time/step_per_second": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/entropy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + } + ) + aggregator.to(device) # Local data if cfg.buffer.size < cfg.algo.rollout_steps: @@ -219,9 +239,10 @@ def main(cfg: DictConfig): step_data = TensorDict({}, batch_size=[cfg.env.num_envs], device=device) # Global variables - last_log = 0 - policy_step = 0 - last_checkpoint = 0 + start_step = state["update"] // fabric.world_size if cfg.checkpoint.resume_from else 1 + policy_step = state["update"] * cfg.env.num_envs * cfg.algo.rollout_steps if cfg.checkpoint.resume_from else 0 + last_log = state["last_log"] if cfg.checkpoint.resume_from else 0 + last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 start_time = time.perf_counter() policy_steps_per_update = int(cfg.env.num_envs * cfg.algo.rollout_steps * world_size) num_updates = cfg.total_steps // policy_steps_per_update if not cfg.dry_run else 1 @@ -247,6 +268,8 @@ def main(cfg: DictConfig): from torch.optim.lr_scheduler import PolynomialLR scheduler = PolynomialLR(optimizer=optimizer, total_iters=num_updates, power=1.0) + if cfg.checkpoint.resume_from: + scheduler.load_state_dict(state["scheduler"]) # Get the first environment observation and start the optimization o = envs.reset(seed=cfg.seed)[0] # [N_envs, N_obs] @@ -262,7 +285,7 @@ def main(cfg: DictConfig): next_obs[k] = torch_obs next_done = torch.zeros(cfg.env.num_envs, 1, dtype=torch.float32).to(fabric.device) # [N_envs, 1] - for update in range(1, num_updates + 1): + for update in range(start_step, num_updates + 1): for _ in range(0, cfg.algo.rollout_steps): policy_step += cfg.env.num_envs * world_size @@ -391,8 +414,11 @@ def main(cfg: DictConfig): state = { "agent": agent.state_dict(), "optimizer": optimizer.state_dict(), - "update_step": update, "scheduler": scheduler.state_dict() if cfg.algo.anneal_lr else None, + "update": update 
* world_size, + "batch_size": cfg.per_rank_batch_size * fabric.world_size, + "last_log": last_log, + "last_checkpoint": last_checkpoint, } ckpt_path = os.path.join(log_dir, f"checkpoint/ckpt_{policy_step}_{fabric.global_rank}.ckpt") fabric.call("on_checkpoint_coupled", fabric=fabric, ckpt_path=ckpt_path, state=state) diff --git a/sheeprl/algos/ppo/ppo_decoupled.py b/sheeprl/algos/ppo/ppo_decoupled.py index 4b66a7d5..430a7202 100644 --- a/sheeprl/algos/ppo/ppo_decoupled.py +++ b/sheeprl/algos/ppo/ppo_decoupled.py @@ -1,5 +1,6 @@ import copy import os +import pathlib import time import warnings from datetime import datetime, timedelta @@ -34,6 +35,26 @@ @torch.no_grad() def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_collective: TorchCollective): + # Initialize Fabric object + fabric = Fabric(callbacks=[CheckpointCallback()]) + if not _is_using_cli(): + fabric.launch() + device = fabric.device + fabric.seed_everything(cfg.seed) + torch.backends.cudnn.deterministic = cfg.torch_deterministic + + # Resume from checkpoint + if cfg.checkpoint.resume_from: + root_dir = cfg.root_dir + run_name = cfg.run_name + state = fabric.load(cfg.checkpoint.resume_from) + ckpt_path = pathlib.Path(cfg.checkpoint.resume_from) + cfg = OmegaConf.load(ckpt_path.parent.parent.parent / ".hydra" / "config.yaml") + cfg.checkpoint.resume_from = str(ckpt_path) + cfg.per_rank_batch_size = state["batch_size"] // (world_collective.world_size - 1) + cfg.root_dir = root_dir + cfg.run_name = f"resume_from_checkpoint_{run_name}" + root_dir = ( os.path.join("logs", "runs", cfg.root_dir) if cfg.root_dir is not None @@ -43,16 +64,9 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co cfg.run_name if cfg.run_name is not None else f"{cfg.env.id}_{cfg.exp_name}_{cfg.seed}_{int(time.time())}" ) logger = TensorBoardLogger(root_dir=root_dir, name=run_name) + fabric._loggers = [logger] logger.log_hyperparams(OmegaConf.to_container(cfg, resolve=True)) - # Initialize Fabric object - fabric = Fabric(loggers=logger, callbacks=[CheckpointCallback()]) - if not _is_using_cli(): - fabric.launch() - device = fabric.device - fabric.seed_everything(cfg.seed) - torch.backends.cudnn.deterministic = cfg.torch_deterministic - # Environment setup vectorized_env = gym.vector.SyncVectorEnv if cfg.env.sync_env else gym.vector.AsyncVectorEnv envs = vectorized_env( @@ -120,14 +134,14 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co torch.nn.utils.convert_parameters.vector_to_parameters(flattened_parameters, list(agent.parameters())) # Metrics - with device: - aggregator = MetricAggregator( - { - "Rewards/rew_avg": MeanMetric(sync_on_compute=False), - "Game/ep_len_avg": MeanMetric(sync_on_compute=False), - "Time/step_per_second": MeanMetric(sync_on_compute=False), - } - ) + aggregator = MetricAggregator( + { + "Rewards/rew_avg": MeanMetric(sync_on_compute=False), + "Game/ep_len_avg": MeanMetric(sync_on_compute=False), + "Time/step_per_second": MeanMetric(sync_on_compute=False), + } + ) + aggregator.to(device) # Local data rb = ReplayBuffer( @@ -140,9 +154,10 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co step_data = TensorDict({}, batch_size=[cfg.env.num_envs], device=device) # Global variables - policy_step = 0 - last_log = 0 - last_checkpoint = 0 + start_step = state["update"] if cfg.checkpoint.resume_from else 1 + policy_step = (state["update"] - 1) * cfg.env.num_envs * cfg.algo.rollout_steps if cfg.checkpoint.resume_from else 
0 + last_log = state["last_log"] if cfg.checkpoint.resume_from else 0 + last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 start_time = time.perf_counter() policy_steps_per_update = int(cfg.env.num_envs * cfg.algo.rollout_steps) num_updates = cfg.total_steps // policy_steps_per_update if not cfg.dry_run else 1 @@ -191,7 +206,9 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co next_obs[k] = torch_obs next_done = torch.zeros(cfg.env.num_envs, 1, dtype=torch.float32) # [N_envs, 1] - for update in range(1, num_updates + 1): + params = {"update": start_step, "last_log": last_log, "last_checkpoint": last_checkpoint} + world_collective.scatter_object_list([None], [params] * world_collective.world_size, src=0) + for update in range(start_step, num_updates + 1): for _ in range(0, cfg.algo.rollout_steps): policy_step += cfg.env.num_envs @@ -351,15 +368,24 @@ def trainer( fabric.seed_everything(cfg.seed) torch.backends.cudnn.deterministic = cfg.torch_deterministic + # Resume from checkpoint + if cfg.checkpoint.resume_from: + state = fabric.load(cfg.checkpoint.resume_from) + # Environment setup agent_args = [None] world_collective.broadcast_object_list(agent_args, src=0) - # Create the actor and critic models + # Define the agent and the optimizer agent = PPOAgent(**agent_args[0]) - - # Define the agent and the optimizer and setup them with Fabric optimizer = hydra.utils.instantiate(cfg.algo.optimizer, params=agent.parameters()) + + # Load the state from the checkpoint + if cfg.checkpoint.resume_from: + agent.load_state_dict(state["agent"]) + optimizer.load_state_dict(state["optimizer"]) + + # Setup agent and optimizer with Fabric agent = fabric.setup_module(agent) optimizer = fabric.setup_optimizers(optimizer) @@ -380,30 +406,30 @@ def trainer( from torch.optim.lr_scheduler import PolynomialLR scheduler = PolynomialLR(optimizer=optimizer, total_iters=num_updates, power=1.0) + if cfg.checkpoint.resume_from: + scheduler.load_state_dict(state["scheduler"]) # Metrics - with fabric.device: - aggregator = MetricAggregator( - { - "Loss/value_loss": MeanMetric( - sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg - ), - "Loss/policy_loss": MeanMetric( - sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg - ), - "Loss/entropy_loss": MeanMetric( - sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg - ), - } - ) + aggregator = MetricAggregator( + { + "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg), + "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg), + "Loss/entropy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg), + } + ) + aggregator.to(device) # Start training - update = 0 + policy_steps_per_update = cfg.env.num_envs * cfg.algo.rollout_steps + params = [None] + world_collective.scatter_object_list(params, [None for _ in range(world_collective.world_size)], src=0) + params = params[0] + update = params["update"] initial_ent_coef = copy.deepcopy(cfg.algo.ent_coef) initial_clip_coef = copy.deepcopy(cfg.algo.clip_coef) - policy_step = 0 - last_log = 0 - last_checkpoint = 0 + policy_step = update * policy_steps_per_update + last_log = params["last_log"] + last_checkpoint = params["last_checkpoint"] while True: # Wait for data data = [None] @@ -415,14 +441,15 @@ def trainer( state = { "agent": agent.state_dict(), "optimizer": 
optimizer.state_dict(), - "update_step": update, "scheduler": scheduler.state_dict() if cfg.algo.anneal_lr else None, + "update": update, + "batch_size": cfg.per_rank_batch_size * (world_collective.world_size - 1), + "last_log": last_log, + "last_checkpoint": last_checkpoint, } fabric.call("on_checkpoint_trainer", player_trainer_collective=player_trainer_collective, state=state) return data = make_tensordict(data, device=device) - update += 1 - policy_step += cfg.env.num_envs * cfg.algo.rollout_steps # Prepare sampler indexes = list(range(data.shape[0])) @@ -522,10 +549,15 @@ def trainer( state = { "agent": agent.state_dict(), "optimizer": optimizer.state_dict(), - "update_step": update, "scheduler": scheduler.state_dict() if cfg.algo.anneal_lr else None, + "update": update, + "batch_size": cfg.per_rank_batch_size * (world_collective.world_size - 1), + "last_log": last_log, + "last_checkpoint": last_checkpoint, } fabric.call("on_checkpoint_trainer", player_trainer_collective=player_trainer_collective, state=state) + update += 1 + policy_step += cfg.env.num_envs * cfg.algo.rollout_steps @register_algorithm(decoupled=True) diff --git a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py index fab5f318..2207c6fb 100644 --- a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py +++ b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py @@ -1,6 +1,7 @@ import copy import itertools import os +import pathlib import time import warnings from contextlib import nullcontext @@ -125,6 +126,18 @@ def main(cfg: DictConfig): fabric.seed_everything(cfg.seed) torch.backends.cudnn.deterministic = cfg.torch_deterministic + # Resume from checkpoint + if cfg.checkpoint.resume_from: + root_dir = cfg.root_dir + run_name = cfg.run_name + state = fabric.load(cfg.checkpoint.resume_from) + ckpt_path = pathlib.Path(cfg.checkpoint.resume_from) + cfg = OmegaConf.load(ckpt_path.parent.parent.parent / ".hydra" / "config.yaml") + cfg.checkpoint.resume_from = str(ckpt_path) + cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size + cfg.root_dir = root_dir + cfg.run_name = f"resume_from_checkpoint_{run_name}" + # Create TensorBoardLogger. 
This will create the logger only on the # rank-0 process logger, log_dir = create_tensorboard_logger(fabric, cfg, "ppo_recurrent") @@ -157,34 +170,41 @@ def main(cfg: DictConfig): f"Provided environment: {cfg.env.id}" ) - # Define the agent and the optimizer and setup them with Fabric + # Define the agent and the optimizer obs_dim = prod(envs.single_observation_space.shape) - agent = fabric.setup_module( - RecurrentPPOAgent( - observation_dim=obs_dim, - action_dim=envs.single_action_space.n, - lstm_hidden_size=cfg.algo.lstm.hidden_size, - actor_hidden_size=cfg.algo.actor.dense_units, - actor_pre_lstm_hidden_size=cfg.algo.actor.pre_lstm_hidden_size, - critic_hidden_size=cfg.algo.critic.dense_units, - critic_pre_lstm_hidden_size=cfg.algo.critic.pre_lstm_hidden_size, - num_envs=cfg.env.num_envs, - ) + agent = RecurrentPPOAgent( + observation_dim=obs_dim, + action_dim=envs.single_action_space.n, + lstm_hidden_size=cfg.algo.lstm.hidden_size, + actor_hidden_size=cfg.algo.actor.dense_units, + actor_pre_lstm_hidden_size=cfg.algo.actor.pre_lstm_hidden_size, + critic_hidden_size=cfg.algo.critic.dense_units, + critic_pre_lstm_hidden_size=cfg.algo.critic.pre_lstm_hidden_size, + num_envs=cfg.env.num_envs, ) - optimizer = fabric.setup_optimizers(hydra.utils.instantiate(cfg.algo.optimizer, params=agent.parameters())) + optimizer = hydra.utils.instantiate(cfg.algo.optimizer, params=agent.parameters()) + + # Load the state from the checkpoint + if cfg.checkpoint.resume_from: + agent.load_state_dict(state["agent"]) + optimizer.load_state_dict(state["optimizer"]) + + # Setup agent and optimizer with Fabric + agent = fabric.setup_module(agent) + optimizer = fabric.setup_optimizers(optimizer) # Metrics - with device: - aggregator = MetricAggregator( - { - "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Time/step_per_second": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/entropy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - } - ) + aggregator = MetricAggregator( + { + "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Time/step_per_second": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/entropy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + } + ) + aggregator.to(device) # Local data rb = ReplayBuffer( @@ -197,9 +217,10 @@ def main(cfg: DictConfig): step_data = TensorDict({}, batch_size=[1, cfg.env.num_envs], device=device) # Global variables - policy_step = 0 - last_log = 0 - last_checkpoint = 0 + start_step = state["update"] // fabric.world_size if cfg.checkpoint.resume_from else 1 + policy_step = state["update"] * cfg.env.num_envs * cfg.algo.rollout_steps if cfg.checkpoint.resume_from else 0 + last_log = state["last_log"] if cfg.checkpoint.resume_from else 0 + last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 start_time = time.perf_counter() policy_steps_per_update = int(cfg.env.num_envs * cfg.algo.rollout_steps * world_size) num_updates = cfg.total_steps // policy_steps_per_update if not 
cfg.dry_run else 1 @@ -226,6 +247,8 @@ def main(cfg: DictConfig): from torch.optim.lr_scheduler import PolynomialLR scheduler = PolynomialLR(optimizer=optimizer, total_iters=num_updates, power=1.0) + if cfg.checkpoint.resume_from: + scheduler.load_state_dict(state["scheduler"]) with device: # Get the first environment observation and start the optimization @@ -233,7 +256,7 @@ def main(cfg: DictConfig): next_done = torch.zeros(1, cfg.env.num_envs, 1, dtype=torch.float32) # [1, N_envs, 1] next_state = agent.initial_states - for update in range(1, num_updates + 1): + for update in range(start_step, num_updates + 1): for _ in range(0, cfg.algo.rollout_steps): policy_step += cfg.env.num_envs * world_size @@ -371,8 +394,11 @@ def main(cfg: DictConfig): state = { "agent": agent.state_dict(), "optimizer": optimizer.state_dict(), - "update_step": update, "scheduler": scheduler.state_dict() if cfg.algo.anneal_lr else None, + "update": update * world_size, + "batch_size": cfg.per_rank_batch_size * fabric.world_size, + "last_log": last_log, + "last_checkpoint": last_checkpoint, } ckpt_path = os.path.join(log_dir, f"checkpoint/ckpt_{policy_step}_{fabric.global_rank}.ckpt") fabric.call("on_checkpoint_coupled", fabric=fabric, ckpt_path=ckpt_path, state=state) diff --git a/sheeprl/algos/sac/agent.py b/sheeprl/algos/sac/agent.py index 3cdf0e61..0cc1dfd5 100644 --- a/sheeprl/algos/sac/agent.py +++ b/sheeprl/algos/sac/agent.py @@ -162,7 +162,22 @@ def __init__( # Actor and critics self._num_critics = len(critics) - self._actor = actor + self.actor = actor + self.critics = critics + + # Automatic entropy tuning + self._target_entropy = torch.tensor(target_entropy, device=device) + self._log_alpha = torch.nn.Parameter(torch.log(torch.tensor([alpha], device=device)), requires_grad=True) + + # EMA tau + self._tau = tau + + @property + def critics(self) -> nn.ModuleList: + return self.qfs + + @critics.setter + def critics(self, critics: Sequence[Union[SACCritic, _FabricModule]]) -> None: self._qfs = nn.ModuleList(critics) # Create target critic unwrapping the DDP module from the critics to prevent @@ -181,13 +196,7 @@ def __init__( self._qfs_target = copy.deepcopy(self._qfs_unwrapped) for p in self._qfs_target.parameters(): p.requires_grad = False - - # Automatic entropy tuning - self._target_entropy = torch.tensor(target_entropy, device=device) - self._log_alpha = torch.nn.Parameter(torch.log(torch.tensor([alpha], device=device)), requires_grad=True) - - # EMA tau - self._tau = tau + return @property def num_critics(self) -> int: @@ -205,6 +214,10 @@ def qfs_unwrapped(self) -> nn.ModuleList: def actor(self) -> Union[SACActor, _FabricModule]: return self._actor + @actor.setter + def actor(self, actor: Union[SACActor, _FabricModule]) -> None: + self._actor = actor + @property def qfs_target(self) -> nn.ModuleList: return self._qfs_target diff --git a/sheeprl/algos/sac/sac.py b/sheeprl/algos/sac/sac.py index 04bc171b..966f71a0 100644 --- a/sheeprl/algos/sac/sac.py +++ b/sheeprl/algos/sac/sac.py @@ -1,4 +1,5 @@ import os +import pathlib import time import warnings from math import prod @@ -45,10 +46,7 @@ def train( ): # Update the soft-critic next_target_qf_value = agent.get_next_target_q_values( - data["next_observations"], - data["rewards"], - data["dones"], - cfg.algo.gamma, + data["next_observations"], data["rewards"], data["dones"], cfg.algo.gamma ) qf_values = agent.get_q_values(data["observations"], data["actions"]) qf_loss = critic_loss(qf_values, next_target_qf_value, agent.num_critics) @@ -92,6 +90,18 @@ 
def main(cfg: DictConfig): fabric.seed_everything(cfg.seed) torch.backends.cudnn.deterministic = cfg.torch_deterministic + # Resume from checkpoint + if cfg.checkpoint.resume_from: + root_dir = cfg.root_dir + run_name = cfg.run_name + state = fabric.load(cfg.checkpoint.resume_from) + ckpt_path = pathlib.Path(cfg.checkpoint.resume_from) + cfg = OmegaConf.load(ckpt_path.parent.parent.parent / ".hydra" / "config.yaml") + cfg.checkpoint.resume_from = str(ckpt_path) + cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size + cfg.root_dir = root_dir + cfg.run_name = f"resume_from_checkpoint_{run_name}" + # Create TensorBoardLogger. This will create the logger only on the # rank-0 process logger, log_dir = create_tensorboard_logger(fabric, cfg, "sac") @@ -124,46 +134,51 @@ def main(cfg: DictConfig): f"Provided environment: {cfg.env.id}" ) - # Define the agent and the optimizer and setup them with Fabric + # Define the agent and the optimizer and setup them with Fabric act_dim = prod(envs.single_action_space.shape) obs_dim = prod(envs.single_observation_space.shape) - actor = fabric.setup_module( - SACActor( - observation_dim=obs_dim, - action_dim=act_dim, - hidden_size=cfg.algo.actor.hidden_size, - action_low=envs.single_action_space.low, - action_high=envs.single_action_space.high, - ) + actor = SACActor( + observation_dim=obs_dim, + action_dim=act_dim, + hidden_size=cfg.algo.actor.hidden_size, + action_low=envs.single_action_space.low, + action_high=envs.single_action_space.high, ) critics = [ - fabric.setup_module( - SACCritic(observation_dim=obs_dim + act_dim, hidden_size=cfg.algo.critic.hidden_size, num_critics=1) - ) + SACCritic(observation_dim=obs_dim + act_dim, hidden_size=cfg.algo.critic.hidden_size, num_critics=1) for _ in range(cfg.algo.critic.n) ] target_entropy = -act_dim agent = SACAgent(actor, critics, target_entropy, alpha=cfg.algo.alpha.alpha, tau=cfg.algo.tau, device=fabric.device) + if cfg.checkpoint.resume_from: + agent.load_state_dict(state["agent"]) + agent.actor = fabric.setup_module(agent.actor) + agent.critics = [fabric.setup_module(critic) for critic in agent.critics] # Optimizers + qf_optimizer = hydra.utils.instantiate(cfg.algo.critic.optimizer, params=agent.qfs.parameters()) + actor_optimizer = hydra.utils.instantiate(cfg.algo.actor.optimizer, params=agent.actor.parameters()) + alpha_optimizer = hydra.utils.instantiate(cfg.algo.alpha.optimizer, params=[agent.log_alpha]) + if cfg.checkpoint.resume_from: + qf_optimizer.load_state_dict(state["qf_optimizer"]) + actor_optimizer.load_state_dict(state["actor_optimizer"]) + alpha_optimizer.load_state_dict(state["alpha_optimizer"]) qf_optimizer, actor_optimizer, alpha_optimizer = fabric.setup_optimizers( - hydra.utils.instantiate(cfg.algo.critic.optimizer, params=agent.qfs.parameters()), - hydra.utils.instantiate(cfg.algo.actor.optimizer, params=agent.actor.parameters()), - hydra.utils.instantiate(cfg.algo.alpha.optimizer, params=[agent.log_alpha]), + qf_optimizer, actor_optimizer, alpha_optimizer ) # Metrics - with device: - aggregator = MetricAggregator( - { - "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Time/step_per_second": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/alpha_loss":
MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - } - ) + aggregator = MetricAggregator( + { + "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Time/step_per_second": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/alpha_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + } + ) + aggregator.to(device) # Local data buffer_size = cfg.buffer.size // int(cfg.env.num_envs * fabric.world_size) if not cfg.dry_run else 1 @@ -174,16 +189,26 @@ def main(cfg: DictConfig): memmap=cfg.buffer.memmap, memmap_dir=os.path.join(log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"), ) + if cfg.checkpoint.resume_from and cfg.buffer.checkpoint: + if isinstance(state["rb"], list) and fabric.world_size == len(state["rb"]): + rb = state["rb"][fabric.global_rank] + elif isinstance(state["rb"], ReplayBuffer): + rb = state["rb"] + else: + raise RuntimeError(f"Given {len(state['rb'])}, but {fabric.world_size} processes are instantiated") step_data = TensorDict({}, batch_size=[cfg.env.num_envs], device=device) # Global variables - policy_step = 0 - last_log = 0 - last_checkpoint = 0 + start_step = state["update"] // fabric.world_size if cfg.checkpoint.resume_from else 1 + policy_step = state["update"] * cfg.env.num_envs if cfg.checkpoint.resume_from else 0 + last_log = state["last_log"] if cfg.checkpoint.resume_from else 0 + last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 start_time = time.perf_counter() policy_steps_per_update = int(cfg.env.num_envs * fabric.world_size) num_updates = int(cfg.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 + if cfg.checkpoint.resume_from and not cfg.buffer.checkpoint: + learning_starts += start_step # Warning for log and checkpoint every if cfg.metric.log_every % policy_steps_per_update != 0: @@ -202,20 +227,24 @@ def main(cfg: DictConfig): ) # Get the first environment observation and start the optimization - obs = torch.tensor(envs.reset(seed=cfg.seed)[0], dtype=torch.float32, device=device) # [N_envs, N_obs] + obs = torch.tensor( + envs.reset(seed=cfg.seed)[0], + dtype=torch.float32, + device=device, + ) # [N_envs, N_obs] - for update in range(1, num_updates + 1): + for update in range(start_step, num_updates + 1): if update < learning_starts: actions = envs.action_space.sample() else: # Sample an action given the observation received by the environment with torch.no_grad(): - actions, _ = actor.module(obs) + actions, _ = agent.actor.module(obs) actions = actions.cpu().numpy() next_obs, rewards, dones, truncated, infos = envs.step(actions) dones = np.logical_or(dones, truncated) - policy_step += cfg.env.num_envs * fabric.world_size + policy_step += policy_steps_per_update if "final_info" in infos: for i, agent_final_info in enumerate(infos["final_info"]): @@ -308,7 +337,10 @@ def main(cfg: DictConfig): "qf_optimizer": qf_optimizer.state_dict(), "actor_optimizer": actor_optimizer.state_dict(), "alpha_optimizer": alpha_optimizer.state_dict(), - "update": update, + "update": update * fabric.world_size, + "batch_size": cfg.per_rank_batch_size * fabric.world_size, + "last_log": last_log, + "last_checkpoint": last_checkpoint, } ckpt_path = 
os.path.join(log_dir, f"checkpoint/ckpt_{policy_step}_{fabric.global_rank}.ckpt") fabric.call( @@ -331,7 +363,7 @@ def main(cfg: DictConfig): mask_velocities=False, vector_env_idx=0, )() - test(actor.module, test_env, fabric, cfg) + test(agent.actor.module, test_env, fabric, cfg) if __name__ == "__main__": diff --git a/sheeprl/algos/sac/sac_decoupled.py b/sheeprl/algos/sac/sac_decoupled.py index b0016b77..4c7863e3 100644 --- a/sheeprl/algos/sac/sac_decoupled.py +++ b/sheeprl/algos/sac/sac_decoupled.py @@ -1,4 +1,5 @@ import os +import pathlib import time import warnings from datetime import datetime, timedelta @@ -32,6 +33,27 @@ @torch.no_grad() def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_collective: TorchCollective): + # Initialize Fabric + fabric = Fabric(callbacks=[CheckpointCallback()]) + if not _is_using_cli(): + fabric.launch() + rank = fabric.global_rank + device = fabric.device + fabric.seed_everything(cfg.seed) + torch.backends.cudnn.deterministic = cfg.torch_deterministic + + # Resume from checkpoint + if cfg.checkpoint.resume_from: + root_dir = cfg.root_dir + run_name = cfg.run_name + state = fabric.load(cfg.checkpoint.resume_from) + ckpt_path = pathlib.Path(cfg.checkpoint.resume_from) + cfg = OmegaConf.load(ckpt_path.parent.parent.parent / ".hydra" / "config.yaml") + cfg.checkpoint.resume_from = str(ckpt_path) + cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size + cfg.root_dir = root_dir + cfg.run_name = f"resume_from_checkpoint_{run_name}" + root_dir = ( os.path.join("logs", "runs", cfg.root_dir) if cfg.root_dir is not None @@ -41,17 +63,9 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co cfg.run_name if cfg.run_name is not None else f"{cfg.env.id}_{cfg.exp_name}_{cfg.seed}_{int(time.time())}" ) logger = TensorBoardLogger(root_dir=root_dir, name=run_name) + fabric._loggers = [logger] logger.log_hyperparams(OmegaConf.to_container(cfg, resolve=True)) - # Initialize Fabric - fabric = Fabric(loggers=logger, callbacks=[CheckpointCallback()]) - if not _is_using_cli(): - fabric.launch() - rank = fabric.global_rank - device = fabric.device - fabric.seed_everything(cfg.seed) - torch.backends.cudnn.deterministic = cfg.torch_deterministic - # Environment setup vectorized_env = gym.vector.SyncVectorEnv if cfg.env.sync_env else gym.vector.AsyncVectorEnv envs = vectorized_env( @@ -100,14 +114,14 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co torch.nn.utils.convert_parameters.vector_to_parameters(flattened_parameters, actor.parameters()) # Metrics - with device: - aggregator = MetricAggregator( - { - "Rewards/rew_avg": MeanMetric(sync_on_compute=False), - "Game/ep_len_avg": MeanMetric(sync_on_compute=False), - "Time/step_per_second": MeanMetric(sync_on_compute=False), - } - ) + aggregator = MetricAggregator( + { + "Rewards/rew_avg": MeanMetric(sync_on_compute=False), + "Game/ep_len_avg": MeanMetric(sync_on_compute=False), + "Time/step_per_second": MeanMetric(sync_on_compute=False), + } + ) + aggregator.to(device) # Local data buffer_size = cfg.buffer.size // cfg.env.num_envs if not cfg.dry_run else 1 @@ -118,16 +132,26 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co memmap=cfg.buffer.memmap, memmap_dir=os.path.join(logger.log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"), ) + if cfg.checkpoint.resume_from and cfg.buffer.checkpoint: + if isinstance(state["rb"], list) and fabric.world_size == len(state["rb"]): + rb = 
state["rb"][fabric.global_rank] + elif isinstance(state["rb"], ReplayBuffer): + rb = state["rb"] + else: + raise RuntimeError(f"Given {len(state['rb'])}, but {fabric.world_size} processes are instantiated") step_data = TensorDict({}, batch_size=[cfg.env.num_envs], device=device) # Global variables - policy_step = 0 - last_log = 0 - last_checkpoint = 0 + start_step = state["update"] if cfg.checkpoint.resume_from else 1 + policy_step = (state["update"] - 1) * cfg.env.num_envs if cfg.checkpoint.resume_from else 0 + last_log = state["last_log"] if cfg.checkpoint.resume_from else 0 + last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 start_time = time.perf_counter() policy_steps_per_update = int(cfg.env.num_envs) num_updates = int(cfg.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 + if cfg.checkpoint.resume_from and not cfg.buffer.checkpoint: + learning_starts += start_step # Warning for log and checkpoint every if cfg.metric.log_every % policy_steps_per_update != 0: @@ -149,7 +173,7 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co # Get the first environment observation and start the optimization obs = torch.tensor(envs.reset(seed=cfg.seed)[0], dtype=torch.float32) # [N_envs, N_obs] - for update in range(1, num_updates + 1): + for update in range(start_step, num_updates + 1): if update < learning_starts: actions = envs.action_space.sample() else: @@ -197,6 +221,9 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co # Send data to the training agents if update >= learning_starts: + if update == learning_starts: + params = {"update": update, "last_log": last_log, "last_checkpoint": last_checkpoint} + world_collective.scatter_object_list([None], [params] * world_collective.world_size, src=0) training_steps = learning_starts if update == learning_starts else 1 chunks = rb.sample( training_steps * cfg.algo.per_rank_gradient_steps * cfg.per_rank_batch_size * (fabric.world_size - 1), @@ -273,7 +300,6 @@ def trainer( optimization_pg: CollectibleGroup, ): global_rank = world_collective.rank - global_rank - 1 # Receive (possibly updated, by the make_dict_env method for example) cfg from the player data = [None] @@ -288,6 +314,14 @@ def trainer( fabric.seed_everything(cfg.seed) torch.backends.cudnn.deterministic = cfg.torch_deterministic + # Resume from checkpoint + if cfg.checkpoint.resume_from: + state = fabric.load(cfg.checkpoint.resume_from) + ckpt_path = pathlib.Path(cfg.checkpoint.resume_from) + cfg = OmegaConf.load(ckpt_path.parent.parent.parent / ".hydra" / "config.yaml") + cfg.checkpoint.resume_from = str(ckpt_path) + cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size + # Environment setup vectorized_env = gym.vector.SyncVectorEnv if cfg.env.sync_env else gym.vector.AsyncVectorEnv envs = vectorized_env([make_env(cfg.env.id, 0, 0, False, None, mask_velocities=False)]) @@ -296,60 +330,65 @@ def trainer( # Define the agent and the optimizer and setup them with Fabric act_dim = prod(envs.single_action_space.shape) obs_dim = prod(envs.single_observation_space.shape) - actor = fabric.setup_module( - SACActor( - observation_dim=obs_dim, - action_dim=act_dim, - hidden_size=cfg.algo.actor.hidden_size, - action_low=envs.single_action_space.low, - action_high=envs.single_action_space.high, - ) + + actor = SACActor( + observation_dim=obs_dim, + action_dim=act_dim, + 
hidden_size=cfg.algo.actor.hidden_size, + action_low=envs.single_action_space.low, + action_high=envs.single_action_space.high, ) critics = [ - fabric.setup_module( - SACCritic(observation_dim=obs_dim + act_dim, hidden_size=cfg.algo.critic.hidden_size, num_critics=1) - ) + SACCritic(observation_dim=obs_dim + act_dim, hidden_size=cfg.algo.critic.hidden_size, num_critics=1) for _ in range(cfg.algo.critic.n) ] target_entropy = -act_dim agent = SACAgent(actor, critics, target_entropy, alpha=cfg.algo.alpha.alpha, tau=cfg.algo.tau, device=fabric.device) + if cfg.checkpoint.resume_from: + agent.load_state_dict(state["agent"]) + agent.actor = fabric.setup_module(agent.actor) + agent.critics = [fabric.setup_module(critic) for critic in agent.critics] # Optimizers + qf_optimizer = hydra.utils.instantiate(cfg.algo.critic.optimizer, params=agent.qfs.parameters()) + actor_optimizer = hydra.utils.instantiate( + cfg.algo.actor.optimizer, + params=agent.actor.parameters(), + ) + alpha_optimizer = hydra.utils.instantiate(cfg.algo.alpha.optimizer, params=[agent.log_alpha]) + if cfg.checkpoint.resume_from: + qf_optimizer.load_state_dict(state["qf_optimizer"]) + actor_optimizer.load_state_dict(state["actor_optimizer"]) + alpha_optimizer.load_state_dict(state["alpha_optimizer"]) qf_optimizer, actor_optimizer, alpha_optimizer = fabric.setup_optimizers( - hydra.utils.instantiate(cfg.algo.critic.optimizer, params=agent.qfs.parameters()), - hydra.utils.instantiate(cfg.algo.actor.optimizer, params=agent.actor.parameters()), - hydra.utils.instantiate(cfg.algo.alpha.optimizer, params=[agent.log_alpha]), + qf_optimizer, actor_optimizer, alpha_optimizer ) # Send weights to rank-0, a.k.a. the player if global_rank == 1: player_trainer_collective.broadcast( - torch.nn.utils.convert_parameters.parameters_to_vector(actor.parameters()), src=1 + torch.nn.utils.convert_parameters.parameters_to_vector(agent.actor.parameters()), src=1 ) # Metrics - with fabric.device: - aggregator = MetricAggregator( - { - "Loss/value_loss": MeanMetric( - sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg - ), - "Loss/policy_loss": MeanMetric( - sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg - ), - "Loss/alpha_loss": MeanMetric( - sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg - ), - } - ) + aggregator = MetricAggregator( + { + "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg), + "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg), + "Loss/alpha_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg), + } + ) + aggregator.to(device) # Start training policy_steps_per_update = cfg.env.num_envs - learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 - update = learning_starts + params = [None] + world_collective.scatter_object_list(params, [None for _ in range(world_collective.world_size)], src=0) + params = params[0] + update = params["update"] policy_step = update * policy_steps_per_update - last_log = (policy_step // cfg.metric.log_every - 1) * cfg.metric.log_every - last_checkpoint = 0 + last_log = params["last_log"] + last_checkpoint = params["last_checkpoint"] while True: # Wait for data data = [None] @@ -364,6 +403,9 @@ def trainer( "actor_optimizer": actor_optimizer.state_dict(), "alpha_optimizer": alpha_optimizer.state_dict(), "update": update, + "batch_size": 
cfg.per_rank_batch_size * (world_collective.world_size - 1), + "last_log": last_log, + "last_checkpoint": last_checkpoint, } fabric.call("on_checkpoint_trainer", player_trainer_collective=player_trainer_collective, state=state) return @@ -396,7 +438,7 @@ def trainer( if global_rank == 1: player_trainer_collective.broadcast( - torch.nn.utils.convert_parameters.parameters_to_vector(actor.parameters()), src=1 + torch.nn.utils.convert_parameters.parameters_to_vector(agent.actor.parameters()), src=1 ) # Checkpoint model on rank-0: send it everything @@ -409,6 +451,9 @@ def trainer( "actor_optimizer": actor_optimizer.state_dict(), "alpha_optimizer": alpha_optimizer.state_dict(), "update": update, + "batch_size": cfg.per_rank_batch_size * (world_collective.world_size - 1), + "last_log": last_log, + "last_checkpoint": last_checkpoint, } fabric.call("on_checkpoint_trainer", player_trainer_collective=player_trainer_collective, state=state) update += 1 diff --git a/sheeprl/configs/env_config.yaml b/sheeprl/configs/env_config.yaml index d4994dff..5f7fdb73 100644 --- a/sheeprl/configs/env_config.yaml +++ b/sheeprl/configs/env_config.yaml @@ -5,6 +5,10 @@ defaults: - _self_ - env: default.yaml +hydra: + run: + dir: logs/envs/${env.id}/${agent} + seed: 42 exp_name: "default" root_dir: $env_logs From a9e940a49f7b611b7a9ce381139a56159001430a Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Wed, 13 Sep 2023 11:38:25 +0200 Subject: [PATCH 2/8] fix: sac_decoupled --- sheeprl/algos/sac/sac_decoupled.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sheeprl/algos/sac/sac_decoupled.py b/sheeprl/algos/sac/sac_decoupled.py index 4c7863e3..90132a1d 100644 --- a/sheeprl/algos/sac/sac_decoupled.py +++ b/sheeprl/algos/sac/sac_decoupled.py @@ -221,7 +221,7 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co # Send data to the training agents if update >= learning_starts: - if update == learning_starts: + if update == learning_starts or cfg.dry_run: params = {"update": update, "last_log": last_log, "last_checkpoint": last_checkpoint} world_collective.scatter_object_list([None], [params] * world_collective.world_size, src=0) training_steps = learning_starts if update == learning_starts else 1 From 9190b58077ef1225b173c847be8e1a8b98e249ba Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Wed, 13 Sep 2023 11:38:41 +0200 Subject: [PATCH 3/8] tests: update tests --- tests/test_algos/test_algos.py | 46 +++++++++++++++++++++++++++++----- 1 file changed, 40 insertions(+), 6 deletions(-) diff --git a/tests/test_algos/test_algos.py b/tests/test_algos/test_algos.py index b5cb20bd..414d9fa2 100644 --- a/tests/test_algos/test_algos.py +++ b/tests/test_algos/test_algos.py @@ -98,7 +98,16 @@ def test_droq(standard_args, checkpoint_buffer, start_time): if command == "main": task.__dict__[command]() - keys = {"agent", "qf_optimizer", "actor_optimizer", "alpha_optimizer", "update"} + keys = { + "agent", + "qf_optimizer", + "actor_optimizer", + "alpha_optimizer", + "update", + "last_log", + "last_checkpoint", + "batch_size", + } if checkpoint_buffer: keys.add("rb") check_checkpoint(Path(os.path.join("logs", "runs", ckpt_path)), keys, checkpoint_buffer) @@ -131,7 +140,16 @@ def test_sac(standard_args, checkpoint_buffer, start_time): if command == "main": task.__dict__[command]() - keys = {"agent", "qf_optimizer", "actor_optimizer", "alpha_optimizer", "update"} + keys = { + "agent", + "qf_optimizer", + "actor_optimizer", + "alpha_optimizer", + "update", + "last_log", + 
"last_checkpoint", + "batch_size", + } if checkpoint_buffer: keys.add("rb") check_checkpoint(Path(os.path.join("logs", "runs", ckpt_path)), keys, checkpoint_buffer) @@ -233,7 +251,16 @@ def test_sac_decoupled(standard_args, checkpoint_buffer, start_time): torchrun.main(torchrun_args) if os.environ["LT_DEVICES"] != "1": - keys = {"agent", "qf_optimizer", "actor_optimizer", "alpha_optimizer", "update"} + keys = { + "agent", + "qf_optimizer", + "actor_optimizer", + "alpha_optimizer", + "update", + "last_log", + "last_checkpoint", + "batch_size", + } if checkpoint_buffer: keys.add("rb") check_checkpoint(Path(os.path.join("logs", "runs", ckpt_path)), keys, checkpoint_buffer) @@ -264,7 +291,10 @@ def test_ppo(standard_args, start_time, env_id): if command == "main": task.__dict__[command]() - check_checkpoint(Path(os.path.join("logs", "runs", ckpt_path)), {"agent", "optimizer", "update_step", "scheduler"}) + check_checkpoint( + Path(os.path.join("logs", "runs", ckpt_path)), + {"agent", "optimizer", "update", "scheduler", "last_log", "last_checkpoint", "batch_size"}, + ) remove_test_dir(os.path.join("logs", "runs", f"pytest_{start_time}")) @@ -312,7 +342,8 @@ def test_ppo_decoupled(standard_args, start_time, env_id): if os.environ["LT_DEVICES"] != "1": check_checkpoint( - Path(os.path.join("logs", "runs", ckpt_path)), {"agent", "optimizer", "update_step", "scheduler"} + Path(os.path.join("logs", "runs", ckpt_path)), + {"agent", "optimizer", "update", "scheduler", "last_log", "last_checkpoint", "batch_size"}, ) remove_test_dir(os.path.join("logs", "runs", f"pytest_{start_time}")) @@ -338,7 +369,10 @@ def test_ppo_recurrent(standard_args, start_time): if command == "main": task.__dict__[command]() - check_checkpoint(Path(os.path.join("logs", "runs", ckpt_path)), {"agent", "optimizer", "update_step", "scheduler"}) + check_checkpoint( + Path(os.path.join("logs", "runs", ckpt_path)), + {"agent", "optimizer", "update", "scheduler", "last_log", "last_checkpoint", "batch_size"}, + ) remove_test_dir(os.path.join("logs", "runs", f"pytest_{start_time}")) From 8362e49233622f9c244a4e966feb21baaa599a55 Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Thu, 14 Sep 2023 16:45:25 +0200 Subject: [PATCH 4/8] feat: added resume from checkpoint for sac-ae --- sheeprl/algos/droq/agent.py | 6 +++ sheeprl/algos/sac/agent.py | 7 +++ sheeprl/algos/sac_ae/agent.py | 48 +++++++++++------ sheeprl/algos/sac_ae/sac_ae.py | 95 +++++++++++++++++++++++----------- tests/test_algos/test_algos.py | 2 + 5 files changed, 113 insertions(+), 45 deletions(-) diff --git a/sheeprl/algos/droq/agent.py b/sheeprl/algos/droq/agent.py index cfb817a0..537c3ba9 100644 --- a/sheeprl/algos/droq/agent.py +++ b/sheeprl/algos/droq/agent.py @@ -93,6 +93,12 @@ def __init__( # EMA tau self._tau = tau + def __setattr__(self, name, value): + if name in vars(type(self)) and isinstance(vars(type(self))[name], property): + object.__setattr__(self, name, value) + else: + super().__setattr__(name, value) + @property def critics(self) -> nn.ModuleList: return self.qfs diff --git a/sheeprl/algos/sac/agent.py b/sheeprl/algos/sac/agent.py index 0cc1dfd5..0c16e2e5 100644 --- a/sheeprl/algos/sac/agent.py +++ b/sheeprl/algos/sac/agent.py @@ -172,6 +172,12 @@ def __init__( # EMA tau self._tau = tau + def __setattr__(self, name, value): + if name in vars(type(self)) and isinstance(vars(type(self))[name], property): + object.__setattr__(self, name, value) + else: + super().__setattr__(name, value) + @property def critics(self) -> nn.ModuleList: return self.qfs @@ 
-217,6 +223,7 @@ def actor(self) -> Union[SACActor, _FabricModule]: @actor.setter def actor(self, actor: Union[SACActor, _FabricModule]) -> None: self._actor = actor + return @property def qfs_target(self) -> nn.ModuleList: diff --git a/sheeprl/algos/sac_ae/agent.py b/sheeprl/algos/sac_ae/agent.py index 0b910957..ae1f824e 100644 --- a/sheeprl/algos/sac_ae/agent.py +++ b/sheeprl/algos/sac_ae/agent.py @@ -337,22 +337,8 @@ def __init__( # Actor and critics self._num_critics = len(critic.qfs) - self._actor = actor - self._critic = critic - - # Create target critic unwrapping the DDP module from the critics to prevent - # `RuntimeError: DDP Pickling/Unpickling are only supported when using DDP with the default process group. - # That is, when you have called init_process_group and have not passed process_group - # argument to DDP constructor`. - # This happens when we're using the decoupled version of SACAE for example - if hasattr(critic, "module"): - critic_module = critic.module - else: - critic_module = critic - self._critic_unwrapped = critic_module - self._critic_target = copy.deepcopy(self._critic_unwrapped) - for p in self._critic_target.parameters(): - p.requires_grad = False + self.actor = actor + self.critic = critic # Automatic entropy tuning self._target_entropy = torch.tensor(target_entropy, device=device) @@ -370,6 +356,31 @@ def num_critics(self) -> int: def critic(self) -> Union[SACAECritic, _FabricModule]: return self._critic + def __setattr__(self, name, value): + if name in vars(type(self)) and isinstance(vars(type(self))[name], property): + object.__setattr__(self, name, value) + else: + super().__setattr__(name, value) + + @critic.setter + def critic(self, critic: Union[SACAECritic, _FabricModule]) -> None: + self._critic = critic + + # Create target critic unwrapping the DDP module from the critics to prevent + # `RuntimeError: DDP Pickling/Unpickling are only supported when using DDP with the default process group. + # That is, when you have called init_process_group and have not passed process_group + # argument to DDP constructor`. 
+ # This happens when we're using the decoupled version of SACAE for example + if hasattr(critic, "module"): + critic_module = critic.module + else: + critic_module = critic + self._critic_unwrapped = critic_module + self._critic_target = copy.deepcopy(self._critic_unwrapped) + for p in self._critic_target.parameters(): + p.requires_grad = False + return + @property def critic_unwrapped(self) -> SACAECritic: return self._critic_unwrapped @@ -378,6 +389,11 @@ def critic_unwrapped(self) -> SACAECritic: def actor(self) -> Union[SACAEContinuousActor, _FabricModule]: return self._actor + @actor.setter + def actor(self, actor: Union[SACAEContinuousActor, _FabricModule]) -> None: + self._actor = actor + return + @property def critic_target(self) -> SACAECritic: return self._critic_target diff --git a/sheeprl/algos/sac_ae/sac_ae.py b/sheeprl/algos/sac_ae/sac_ae.py index df66d59d..9e8f3856 100644 --- a/sheeprl/algos/sac_ae/sac_ae.py +++ b/sheeprl/algos/sac_ae/sac_ae.py @@ -1,5 +1,6 @@ import copy import os +import pathlib import time import warnings from math import prod @@ -169,6 +170,18 @@ def main(cfg: DictConfig): fabric.seed_everything(cfg.seed) torch.backends.cudnn.deterministic = cfg.torch_deterministic + # Resume from checkpoint + if cfg.checkpoint.resume_from: + root_dir = cfg.root_dir + run_name = cfg.run_name + state = fabric.load(cfg.checkpoint.resume_from) + ckpt_path = pathlib.Path(cfg.checkpoint.resume_from) + cfg = OmegaConf.load(ckpt_path.parent.parent.parent / ".hydra" / "config.yaml") + cfg.checkpoint.resume_from = str(ckpt_path) + cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size + cfg.root_dir = root_dir + cfg.run_name = f"resume_from_checkpoint_{run_name}" + # Create TensorBoardLogger. This will create the logger only on the # rank-0 process logger, log_dir = create_tensorboard_logger(fabric, cfg, "sac_ae") @@ -279,13 +292,11 @@ def main(cfg: DictConfig): else None ) decoder = MultiDecoder(cnn_decoder, mlp_decoder) - encoder = fabric.setup_module(encoder) - decoder = fabric.setup_module(decoder) # Setup actor and critic. 
Those will initialize with orthogonal weights # both the actor and critic actor = SACAEContinuousActor( - encoder=copy.deepcopy(encoder.module), + encoder=copy.deepcopy(encoder), action_dim=act_dim, hidden_size=cfg.algo.actor.hidden_size, action_low=envs.single_action_space.low, @@ -297,9 +308,7 @@ def main(cfg: DictConfig): ) for _ in range(cfg.algo.critic.n) ] - critic = SACAECritic(encoder=encoder.module, qfs=qfs) - actor = fabric.setup_module(actor) - critic = fabric.setup_module(critic) + critic = SACAECritic(encoder=encoder, qfs=qfs) # The agent will tied convolutional and linear weights between the encoder actor and critic agent = SACAEAgent( @@ -313,28 +322,44 @@ def main(cfg: DictConfig): ) # Optimizers + qf_optimizer = hydra.utils.instantiate(cfg.algo.critic.optimizer, params=agent.critic.parameters()) + actor_optimizer = hydra.utils.instantiate(cfg.algo.actor.optimizer, params=agent.actor.parameters()) + alpha_optimizer = hydra.utils.instantiate(cfg.algo.alpha.optimizer, params=[agent.log_alpha]) + encoder_optimizer = hydra.utils.instantiate(cfg.algo.encoder.optimizer, params=encoder.parameters()) + decoder_optimizer = hydra.utils.instantiate(cfg.algo.decoder.optimizer, params=decoder.parameters()) + + if cfg.checkpoint.resume_from: + agent.load_state_dict(state["agent"]) + encoder.load_state_dict(state["encoder"]) + decoder.load_state_dict(state["decoder"]) + qf_optimizer.load_state_dict(state["qf_optimizer"]) + actor_optimizer.load_state_dict(state["actor_optimizer"]) + alpha_optimizer.load_state_dict(state["alpha_optimizer"]) + encoder_optimizer.load_state_dict(state["encoder_optimizer"]) + decoder_optimizer.load_state_dict(state["decoder_optimizer"]) + + encoder = fabric.setup_module(encoder) + decoder = fabric.setup_module(decoder) + agent.actor = fabric.setup_module(agent.actor) + agent.critic = fabric.setup_module(agent.critic) qf_optimizer, actor_optimizer, alpha_optimizer, encoder_optimizer, decoder_optimizer = fabric.setup_optimizers( - hydra.utils.instantiate(cfg.algo.critic.optimizer, params=agent.critic.parameters()), - hydra.utils.instantiate(cfg.algo.actor.optimizer, params=agent.actor.parameters()), - hydra.utils.instantiate(cfg.algo.alpha.optimizer, params=[agent.log_alpha]), - hydra.utils.instantiate(cfg.algo.encoder.optimizer, params=encoder.parameters()), - hydra.utils.instantiate(cfg.algo.decoder.optimizer, params=decoder.parameters()), + qf_optimizer, actor_optimizer, alpha_optimizer, encoder_optimizer, decoder_optimizer ) # Metrics - with device: - aggregator = MetricAggregator( - { - "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Time/step_per_second": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/alpha_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/reconstruction_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - } - ) + aggregator = MetricAggregator( + { + "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Time/step_per_second": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + 
"Loss/alpha_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/reconstruction_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + } + ) + aggregator.to(device) # Local data buffer_size = cfg.buffer.size // int(cfg.env.num_envs * fabric.world_size) if not cfg.dry_run else 1 @@ -346,16 +371,26 @@ def main(cfg: DictConfig): memmap_dir=os.path.join(log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"), obs_keys=cfg.cnn_keys.encoder + cfg.mlp_keys.encoder, ) + if cfg.checkpoint.resume_from and cfg.buffer.checkpoint: + if isinstance(state["rb"], list) and fabric.world_size == len(state["rb"]): + rb = state["rb"][fabric.global_rank] + elif isinstance(state["rb"], ReplayBuffer): + rb = state["rb"] + else: + raise RuntimeError(f"Given {len(state['rb'])}, but {fabric.world_size} processes are instantiated") step_data = TensorDict({}, batch_size=[cfg.env.num_envs], device=fabric.device if cfg.buffer.memmap else "cpu") # Global variables - policy_step = 0 - last_log = 0 - last_checkpoint = 0 + start_step = state["update"] // fabric.world_size if cfg.checkpoint.resume_from else 1 + policy_step = state["update"] * cfg.env.num_envs if cfg.checkpoint.resume_from else 0 + last_log = state["last_log"] if cfg.checkpoint.resume_from else 0 + last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 start_time = time.time() policy_steps_per_update = int(cfg.env.num_envs * fabric.world_size) num_updates = int(cfg.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 + if cfg.checkpoint.resume_from and not cfg.buffer.checkpoint: + learning_starts += start_step # Warning for log and checkpoint every if cfg.metric.log_every % policy_steps_per_update != 0: @@ -385,13 +420,13 @@ def main(cfg: DictConfig): torch_obs = torch_obs.float() obs[k] = torch_obs - for update in range(1, num_updates + 1): + for update in range(start_step, num_updates + 1): if update < learning_starts: actions = envs.action_space.sample() else: with torch.no_grad(): normalized_obs = {k: v / 255 if k in cfg.cnn_keys.encoder else v for k, v in obs.items()} - actions, _ = actor.module(normalized_obs) + actions, _ = agent.actor.module(normalized_obs) actions = actions.cpu().numpy() o, rewards, dones, truncated, infos = envs.step(actions) dones = np.logical_or(dones, truncated) @@ -508,6 +543,8 @@ def main(cfg: DictConfig): "decoder_optimizer": decoder_optimizer.state_dict(), "update": update * fabric.world_size, "batch_size": cfg.per_rank_batch_size * fabric.world_size, + "last_log": last_log, + "last_checkpoint": last_checkpoint, } ckpt_path = os.path.join(log_dir, f"checkpoint/ckpt_{policy_step}_{fabric.global_rank}.ckpt") fabric.call( @@ -521,7 +558,7 @@ def main(cfg: DictConfig): envs.close() if fabric.is_global_zero: test_env = make_dict_env(cfg, cfg.seed, 0, fabric.logger.log_dir, "test", vector_env_idx=0)() - test_sac_ae(actor.module, test_env, fabric, cfg) + test_sac_ae(agent.actor.module, test_env, fabric, cfg) if __name__ == "__main__": diff --git a/tests/test_algos/test_algos.py b/tests/test_algos/test_algos.py index 414d9fa2..797341d3 100644 --- a/tests/test_algos/test_algos.py +++ b/tests/test_algos/test_algos.py @@ -201,6 +201,8 @@ def test_sac_ae(standard_args, checkpoint_buffer, start_time): "decoder_optimizer", "update", "batch_size", + "last_log", + "last_checkpoint", } if checkpoint_buffer: keys.add("rb") From d3c681014f0935dfb8c29e455dd99bcae7490d58 Mon Sep 17 
00:00:00 2001 From: Michele Milesi Date: Fri, 15 Sep 2023 11:08:51 +0200 Subject: [PATCH 5/8] fix: __setattr__ in sac, sac-ae and droq --- sheeprl/algos/droq/agent.py | 12 +++++++----- sheeprl/algos/sac/agent.py | 12 +++++++----- sheeprl/algos/sac_ae/agent.py | 12 +++++++----- sheeprl/algos/sac_ae/sac_ae.py | 2 +- 4 files changed, 22 insertions(+), 16 deletions(-) diff --git a/sheeprl/algos/droq/agent.py b/sheeprl/algos/droq/agent.py index 537c3ba9..a3c88b12 100644 --- a/sheeprl/algos/droq/agent.py +++ b/sheeprl/algos/droq/agent.py @@ -93,11 +93,13 @@ def __init__( # EMA tau self._tau = tau - def __setattr__(self, name, value): - if name in vars(type(self)) and isinstance(vars(type(self))[name], property): - object.__setattr__(self, name, value) - else: - super().__setattr__(name, value) + def __setattr__(self, name: str, value: Union[Tensor, nn.Module]) -> None: + # Taken from https://github.com/pytorch/pytorch/pull/92044 + # Check if a property setter exists. If it does, use it. + class_attr = getattr(self.__class__, name, None) + if isinstance(class_attr, property) and class_attr.fset is not None: + return class_attr.fset(self, value) + super().__setattr__(name, value) @property def critics(self) -> nn.ModuleList: diff --git a/sheeprl/algos/sac/agent.py b/sheeprl/algos/sac/agent.py index 0c16e2e5..2f37f068 100644 --- a/sheeprl/algos/sac/agent.py +++ b/sheeprl/algos/sac/agent.py @@ -172,11 +172,13 @@ def __init__( # EMA tau self._tau = tau - def __setattr__(self, name, value): - if name in vars(type(self)) and isinstance(vars(type(self))[name], property): - object.__setattr__(self, name, value) - else: - super().__setattr__(name, value) + def __setattr__(self, name: str, value: Union[Tensor, nn.Module]) -> None: + # Taken from https://github.com/pytorch/pytorch/pull/92044 + # Check if a property setter exists. If it does, use it. + class_attr = getattr(self.__class__, name, None) + if isinstance(class_attr, property) and class_attr.fset is not None: + return class_attr.fset(self, value) + super().__setattr__(name, value) @property def critics(self) -> nn.ModuleList: diff --git a/sheeprl/algos/sac_ae/agent.py b/sheeprl/algos/sac_ae/agent.py index ae1f824e..2e25e280 100644 --- a/sheeprl/algos/sac_ae/agent.py +++ b/sheeprl/algos/sac_ae/agent.py @@ -356,11 +356,13 @@ def num_critics(self) -> int: def critic(self) -> Union[SACAECritic, _FabricModule]: return self._critic - def __setattr__(self, name, value): - if name in vars(type(self)) and isinstance(vars(type(self))[name], property): - object.__setattr__(self, name, value) - else: - super().__setattr__(name, value) + def __setattr__(self, name: str, value: Union[Tensor, nn.Module]) -> None: + # Taken from https://github.com/pytorch/pytorch/pull/92044 + # Check if a property setter exists. If it does, use it. 
+ class_attr = getattr(self.__class__, name, None) + if isinstance(class_attr, property) and class_attr.fset is not None: + return class_attr.fset(self, value) + super().__setattr__(name, value) @critic.setter def critic(self, critic: Union[SACAECritic, _FabricModule]) -> None: diff --git a/sheeprl/algos/sac_ae/sac_ae.py b/sheeprl/algos/sac_ae/sac_ae.py index 9e8f3856..d89301a4 100644 --- a/sheeprl/algos/sac_ae/sac_ae.py +++ b/sheeprl/algos/sac_ae/sac_ae.py @@ -531,7 +531,7 @@ def main(cfg: DictConfig): # Checkpoint model if (cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every) or cfg.dry_run: - last_checkpoint >= policy_step + last_checkpoint = policy_step state = { "agent": agent.state_dict(), "encoder": encoder.state_dict(), From b8dd0df9ad53b903263f4f14f93fb764d6b7feec Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Fri, 15 Sep 2023 12:11:45 +0200 Subject: [PATCH 6/8] fix: sac and sac_ae --- sheeprl/algos/sac/sac.py | 2 +- sheeprl/algos/sac_ae/sac_ae.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sheeprl/algos/sac/sac.py b/sheeprl/algos/sac/sac.py index 19115ee8..f17569bb 100644 --- a/sheeprl/algos/sac/sac.py +++ b/sheeprl/algos/sac/sac.py @@ -250,7 +250,7 @@ def main(cfg: DictConfig): else: # Sample an action given the observation received by the environment with torch.no_grad(): - actions, _ = actor.module(obs) + actions, _ = agent.actor.module(obs) actions = actions.cpu().numpy() next_obs, rewards, dones, truncated, infos = envs.step(actions) dones = np.logical_or(dones, truncated) diff --git a/sheeprl/algos/sac_ae/sac_ae.py b/sheeprl/algos/sac_ae/sac_ae.py index d57d8ba2..46dc3784 100644 --- a/sheeprl/algos/sac_ae/sac_ae.py +++ b/sheeprl/algos/sac_ae/sac_ae.py @@ -436,7 +436,7 @@ def main(cfg: DictConfig): else: with torch.no_grad(): normalized_obs = {k: v / 255 if k in cfg.cnn_keys.encoder else v for k, v in obs.items()} - actions, _ = actor.module(normalized_obs) + actions, _ = agent.actor.module(normalized_obs) actions = actions.cpu().numpy() o, rewards, dones, truncated, infos = envs.step(actions) dones = np.logical_or(dones, truncated) From dbc13458ee71e231ede687af20f267900af69927 Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Fri, 15 Sep 2023 15:00:01 +0200 Subject: [PATCH 7/8] fix: resume from checkpoint log dir --- sheeprl/algos/dreamer_v1/dreamer_v1.py | 2 +- sheeprl/algos/dreamer_v2/dreamer_v2.py | 2 +- sheeprl/algos/dreamer_v3/dreamer_v3.py | 2 +- sheeprl/algos/droq/droq.py | 2 +- sheeprl/algos/p2e_dv1/p2e_dv1.py | 2 +- sheeprl/algos/p2e_dv2/p2e_dv2.py | 2 +- sheeprl/algos/ppo/ppo.py | 2 +- sheeprl/algos/ppo/ppo_decoupled.py | 2 +- sheeprl/algos/ppo_recurrent/ppo_recurrent.py | 2 +- sheeprl/algos/sac/sac.py | 2 +- sheeprl/algos/sac/sac_decoupled.py | 2 +- sheeprl/algos/sac_ae/sac_ae.py | 2 +- 12 files changed, 12 insertions(+), 12 deletions(-) diff --git a/sheeprl/algos/dreamer_v1/dreamer_v1.py b/sheeprl/algos/dreamer_v1/dreamer_v1.py index 1eea1737..7ea3e16f 100644 --- a/sheeprl/algos/dreamer_v1/dreamer_v1.py +++ b/sheeprl/algos/dreamer_v1/dreamer_v1.py @@ -384,7 +384,7 @@ def main(cfg: DictConfig): cfg.checkpoint.resume_from = str(ckpt_path) cfg.per_rank_batch_size = state["batch_size"] // world_size cfg.root_dir = root_dir - cfg.run_name = f"resume_from_checkpoint_{run_name}" + cfg.run_name = run_name # Create TensorBoardLogger. 
This will create the logger only on the # rank-0 process diff --git a/sheeprl/algos/dreamer_v2/dreamer_v2.py b/sheeprl/algos/dreamer_v2/dreamer_v2.py index 0d7217eb..2ce31283 100644 --- a/sheeprl/algos/dreamer_v2/dreamer_v2.py +++ b/sheeprl/algos/dreamer_v2/dreamer_v2.py @@ -404,7 +404,7 @@ def main(cfg: DictConfig): cfg.checkpoint.resume_from = str(ckpt_path) cfg.per_rank_batch_size = state["batch_size"] // world_size cfg.root_dir = root_dir - cfg.run_name = f"resume_from_checkpoint_{run_name}" + cfg.run_name = run_name # Create TensorBoardLogger. This will create the logger only on the # rank-0 process diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py index e1abe5c4..75834ff1 100644 --- a/sheeprl/algos/dreamer_v3/dreamer_v3.py +++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py @@ -358,7 +358,7 @@ def main(cfg: DictConfig): cfg.checkpoint.resume_from = str(ckpt_path) cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size cfg.root_dir = root_dir - cfg.run_name = f"resume_from_checkpoint_{run_name}" + cfg.run_name = run_name # Create TensorBoardLogger. This will create the logger only on the # rank-0 process diff --git a/sheeprl/algos/droq/droq.py b/sheeprl/algos/droq/droq.py index 5cffdec4..b2917e25 100644 --- a/sheeprl/algos/droq/droq.py +++ b/sheeprl/algos/droq/droq.py @@ -150,7 +150,7 @@ def main(cfg: DictConfig): cfg.checkpoint.resume_from = str(ckpt_path) cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size cfg.root_dir = root_dir - cfg.run_name = f"resume_from_checkpoint_{run_name}" + cfg.run_name = run_name # Create TensorBoardLogger. This will create the logger only on the # rank-0 process diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1.py b/sheeprl/algos/p2e_dv1/p2e_dv1.py index f987d558..eb301f40 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1.py @@ -389,7 +389,7 @@ def main(cfg: DictConfig): cfg.checkpoint.resume_from = str(ckpt_path) cfg.per_rank_batch_size = state["batch_size"] // world_size cfg.root_dir = root_dir - cfg.run_name = f"resume_from_checkpoint_{run_name}" + cfg.run_name = run_name # Create TensorBoardLogger. This will create the logger only on the # rank-0 process diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2.py b/sheeprl/algos/p2e_dv2/p2e_dv2.py index d1d48a0b..c8e10803 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2.py @@ -491,7 +491,7 @@ def main(cfg: DictConfig): cfg.checkpoint.resume_from = str(ckpt_path) cfg.per_rank_batch_size = state["batch_size"] // world_size cfg.root_dir = root_dir - cfg.run_name = f"resume_from_checkpoint_{run_name}" + cfg.run_name = run_name # Create TensorBoardLogger. This will create the logger only on the # rank-0 process diff --git a/sheeprl/algos/ppo/ppo.py b/sheeprl/algos/ppo/ppo.py index 96b721d6..516b1f84 100644 --- a/sheeprl/algos/ppo/ppo.py +++ b/sheeprl/algos/ppo/ppo.py @@ -142,7 +142,7 @@ def main(cfg: DictConfig): cfg.checkpoint.resume_from = str(ckpt_path) cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size cfg.root_dir = root_dir - cfg.run_name = f"resume_from_checkpoint_{run_name}" + cfg.run_name = run_name # Create TensorBoardLogger. 
This will create the logger only on the # rank-0 process diff --git a/sheeprl/algos/ppo/ppo_decoupled.py b/sheeprl/algos/ppo/ppo_decoupled.py index 889da385..8d57cc22 100644 --- a/sheeprl/algos/ppo/ppo_decoupled.py +++ b/sheeprl/algos/ppo/ppo_decoupled.py @@ -56,7 +56,7 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co cfg.checkpoint.resume_from = str(ckpt_path) cfg.per_rank_batch_size = state["batch_size"] // (world_collective.world_size - 1) cfg.root_dir = root_dir - cfg.run_name = f"resume_from_checkpoint_{run_name}" + cfg.run_name = run_name # Initialize logger root_dir = ( diff --git a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py index 338900af..0dda0853 100644 --- a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py +++ b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py @@ -137,7 +137,7 @@ def main(cfg: DictConfig): cfg.checkpoint.resume_from = str(ckpt_path) cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size cfg.root_dir = root_dir - cfg.run_name = f"resume_from_checkpoint_{run_name}" + cfg.run_name = run_name # Create TensorBoardLogger. This will create the logger only on the # rank-0 process diff --git a/sheeprl/algos/sac/sac.py b/sheeprl/algos/sac/sac.py index f17569bb..4f40a338 100644 --- a/sheeprl/algos/sac/sac.py +++ b/sheeprl/algos/sac/sac.py @@ -105,7 +105,7 @@ def main(cfg: DictConfig): cfg.checkpoint.resume_from = str(ckpt_path) cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size cfg.root_dir = root_dir - cfg.run_name = f"resume_from_checkpoint_{run_name}" + cfg.run_name = run_name # Create TensorBoardLogger. This will create the logger only on the # rank-0 process diff --git a/sheeprl/algos/sac/sac_decoupled.py b/sheeprl/algos/sac/sac_decoupled.py index 8ff1097d..7140a73c 100644 --- a/sheeprl/algos/sac/sac_decoupled.py +++ b/sheeprl/algos/sac/sac_decoupled.py @@ -56,7 +56,7 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co cfg.checkpoint.resume_from = str(ckpt_path) cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size cfg.root_dir = root_dir - cfg.run_name = f"resume_from_checkpoint_{run_name}" + cfg.run_name = run_name # Initialize logger root_dir = ( diff --git a/sheeprl/algos/sac_ae/sac_ae.py b/sheeprl/algos/sac_ae/sac_ae.py index 46dc3784..4e7c5bcf 100644 --- a/sheeprl/algos/sac_ae/sac_ae.py +++ b/sheeprl/algos/sac_ae/sac_ae.py @@ -185,7 +185,7 @@ def main(cfg: DictConfig): cfg.checkpoint.resume_from = str(ckpt_path) cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size cfg.root_dir = root_dir - cfg.run_name = f"resume_from_checkpoint_{run_name}" + cfg.run_name = run_name # Create TensorBoardLogger. 
This will create the logger only on the # rank-0 process From 96c09c4f088e8715b639111a90cc63cd14e09074 Mon Sep 17 00:00:00 2001 From: belerico_t Date: Mon, 18 Sep 2023 16:58:25 +0200 Subject: [PATCH 8/8] Resume commented pytest timeout --- tests/test_algos/test_algos.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_algos/test_algos.py b/tests/test_algos/test_algos.py index 35e8e50e..27ebe453 100644 --- a/tests/test_algos/test_algos.py +++ b/tests/test_algos/test_algos.py @@ -307,7 +307,7 @@ def test_ppo_decoupled(standard_args, start_time, env_id): remove_test_dir(os.path.join("logs", "runs", f"pytest_{start_time}")) -# @pytest.mark.timeout(60) +@pytest.mark.timeout(60) def test_ppo_recurrent(standard_args, start_time): root_dir = os.path.join(f"pytest_{start_time}", "ppo_recurrent", os.environ["LT_DEVICES"]) run_name = "test_ppo_recurrent"
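
The patterns these patches rely on are easy to miss inside the diffs, so a few self-contained sketches follow. First, the `__setattr__` override from PATCH 5/8: `nn.Module.__setattr__` registers any `nn.Module` value directly in `self._modules`, bypassing Python's descriptor protocol, so assigning to a property-backed name such as `agent.actor` would silently skip the property setter. The override (adapted from pytorch/pytorch#92044) dispatches to the setter when one exists. The sketch below uses a hypothetical `Agent` class, not the sheeprl agents themselves:

```python
from typing import Union

from torch import Tensor, nn


class Agent(nn.Module):
    """Minimal sketch of the property-aware __setattr__ used by the patched agents."""

    def __init__(self, actor: nn.Module) -> None:
        super().__init__()
        self.actor = actor  # goes through the property setter below

    def __setattr__(self, name: str, value: Union[Tensor, nn.Module]) -> None:
        # If the class defines a property with a setter for `name`, use it;
        # otherwise fall back to the usual nn.Module bookkeeping.
        class_attr = getattr(self.__class__, name, None)
        if isinstance(class_attr, property) and class_attr.fset is not None:
            return class_attr.fset(self, value)
        super().__setattr__(name, value)

    @property
    def actor(self) -> nn.Module:
        return self._actor

    @actor.setter
    def actor(self, actor: nn.Module) -> None:
        # `_actor` is not a property, so this assignment reaches nn.Module.__setattr__
        # and the module is still registered in state_dict()/parameters().
        self._actor = actor


if __name__ == "__main__":
    agent = Agent(nn.Linear(4, 2))
    agent.actor = nn.Linear(4, 2)  # re-assignment also goes through the setter
    print(list(agent.state_dict().keys()))  # ['_actor.weight', '_actor.bias']
```

This is what lets the agents be built unwrapped, loaded from a checkpoint, and only then re-assigned with `agent.actor = fabric.setup_module(agent.actor)`: the assignment runs the setter, keeping the internal bookkeeping (target networks, unwrapped references) consistent.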
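
Second, the critic setters rebuild the target network from the unwrapped module (`critic.module` when wrapped by Fabric/DDP, the critic itself otherwise), so `copy.deepcopy` never touches a process group. A rough sketch under assumed names (`build_target` and `ema_update` are illustrative helpers, not sheeprl APIs); the update shown is the usual Polyak averaging with a small `tau`:

```python
import copy

import torch
from torch import nn


def build_target(critic: nn.Module) -> nn.Module:
    # Fabric/DDP wrappers expose the underlying network as `.module`; deepcopy the
    # plain module so no distributed state gets pickled.
    unwrapped = critic.module if hasattr(critic, "module") else critic
    target = copy.deepcopy(unwrapped)
    for p in target.parameters():
        p.requires_grad = False  # targets are updated only via EMA, never by the optimizer
    return target


@torch.no_grad()
def ema_update(online: nn.Module, target: nn.Module, tau: float) -> None:
    # target <- tau * online + (1 - tau) * target
    for p, tp in zip(online.parameters(), target.parameters()):
        tp.data.copy_(tau * p.data + (1 - tau) * tp.data)


if __name__ == "__main__":
    qf = nn.Sequential(nn.Linear(8, 64), nn.ReLU(), nn.Linear(64, 1))
    qf_target = build_target(qf)
    ema_update(qf, qf_target, tau=0.005)
```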
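
Third, the resume path is the same in every script: load the checkpoint with `fabric.load`, reload the Hydra config saved next to the original run (three `.parent` hops up from the checkpoint file), restore the state dicts on the bare modules and optimizers, and only then call `fabric.setup_module` / `fabric.setup_optimizers`. A condensed sketch; `load_resume_state` and `restore_and_setup` are hypothetical helpers, and only a single optimizer is shown (the real scripts restore the actor/alpha/encoder/decoder optimizers the same way):

```python
import pathlib
from typing import Any, Dict, Optional, Tuple

import torch
from lightning.fabric import Fabric
from omegaconf import DictConfig, OmegaConf


def load_resume_state(fabric: Fabric, cfg: DictConfig) -> Tuple[DictConfig, Optional[Dict[str, Any]]]:
    """Reload the original run config and checkpoint state, keeping this run's identity."""
    if not cfg.checkpoint.resume_from:
        return cfg, None
    root_dir, run_name = cfg.root_dir, cfg.run_name
    state = fabric.load(cfg.checkpoint.resume_from)  # returns the checkpoint dict
    ckpt_path = pathlib.Path(cfg.checkpoint.resume_from)
    cfg = OmegaConf.load(ckpt_path.parent.parent.parent / ".hydra" / "config.yaml")
    cfg.checkpoint.resume_from = str(ckpt_path)
    cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size
    cfg.root_dir, cfg.run_name = root_dir, run_name
    return cfg, state


def restore_and_setup(
    fabric: Fabric,
    agent: torch.nn.Module,
    qf_optimizer: torch.optim.Optimizer,
    state: Optional[Dict[str, Any]],
):
    # Restore on the bare module/optimizer first, then let Fabric wrap them,
    # mirroring the order used in the patched scripts.
    if state is not None:
        agent.load_state_dict(state["agent"])
        qf_optimizer.load_state_dict(state["qf_optimizer"])
    agent = fabric.setup_module(agent)
    qf_optimizer = fabric.setup_optimizers(qf_optimizer)
    return agent, qf_optimizer
```

The counters follow the same pattern: `start_step = state["update"] // fabric.world_size`, `policy_step = state["update"] * cfg.env.num_envs`, and `last_log` / `last_checkpoint` are read back so logging and checkpointing continue at the same cadence, which is why the saved state now also includes those two keys.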
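
Fourth, when `buffer.checkpoint=True` each rank saves its own `ReplayBuffer`, so on resume the scripts accept either a list with exactly `world_size` buffers (taking the entry at `global_rank`) or a single buffer, and raise otherwise. A generic sketch of that selection, with `select_resumed_buffer` as an illustrative helper rather than sheeprl code:

```python
from typing import Any, List, Optional, Union


def select_resumed_buffer(
    saved: Optional[Union[Any, List[Any]]], world_size: int, global_rank: int, fresh: Any
) -> Any:
    """Pick the replay buffer to continue from.

    `saved` plays the role of state["rb"]: one buffer per rank for multi-process
    runs, or a single buffer; a mismatched list length means the world size changed.
    """
    if saved is None:
        return fresh
    if isinstance(saved, list):
        if len(saved) != world_size:
            raise RuntimeError(f"Given {len(saved)}, but {world_size} processes are instantiated")
        return saved[global_rank]
    return saved
```

When the buffer is not checkpointed, the scripts instead shift `learning_starts` by `start_step`, so a resumed run first refills the empty buffer with fresh experience before training steps resume.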
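
Finally, PATCH 5/8 also restores the assignment `last_checkpoint = policy_step` inside the checkpoint branch of `sac_ae.py`: the earlier `last_checkpoint >= policy_step` was a bare comparison whose result was discarded, so `last_checkpoint` never advanced and, once the first interval elapsed, a checkpoint would have been written on every iteration. The gate itself reduces to a small predicate; `checkpoint_due` below is an illustrative name, not a sheeprl function:

```python
def checkpoint_due(policy_step: int, last_checkpoint: int, every: int, dry_run: bool = False) -> bool:
    # True once `every` policy steps have elapsed since the last saved checkpoint
    # (or always in dry runs); the caller must then set last_checkpoint = policy_step.
    return dry_run or (every > 0 and policy_step - last_checkpoint >= every)
```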