From 0b9a26524a142ffc528af16a0498bbd6b4b2d98b Mon Sep 17 00:00:00 2001 From: belerico Date: Sat, 16 Mar 2024 18:59:53 +0100 Subject: [PATCH 01/26] Fix PPO player with precision --- sheeprl/algos/ppo/ppo.py | 9 +++++---- sheeprl/algos/ppo/ppo_decoupled.py | 9 ++++++++- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/sheeprl/algos/ppo/ppo.py b/sheeprl/algos/ppo/ppo.py index 1fa8f187..8fe9577a 100644 --- a/sheeprl/algos/ppo/ppo.py +++ b/sheeprl/algos/ppo/ppo.py @@ -178,6 +178,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): observation_space, state["agent"] if cfg.checkpoint.resume_from else None, ) + player_agent = _FabricModule(agent.module, precision=fabric._precision) # Define the optimizer optimizer = hydra.utils.instantiate(cfg.algo.optimizer, params=agent.parameters(), _convert_="all") @@ -275,7 +276,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): torch_obs = { k: torch.as_tensor(normalized_obs[k], dtype=torch.float32, device=device) for k in obs_keys } - actions, logprobs, _, values = agent.module(torch_obs) + actions, logprobs, _, values = player_agent(torch_obs) if is_continuous: real_actions = torch.cat(actions, -1).cpu().numpy() else: @@ -303,7 +304,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): torch_v = torch_v / 255.0 - 0.5 real_next_obs[k][i] = torch_v with torch.no_grad(): - vals = agent.module.get_value(real_next_obs).cpu().numpy() + vals = player_agent.get_value(real_next_obs).cpu().numpy() rewards[truncated_envs] += vals.reshape(rewards[truncated_envs].shape) dones = np.logical_or(dones, truncated).reshape(cfg.env.num_envs, -1).astype(np.uint8) rewards = rewards.reshape(cfg.env.num_envs, -1) @@ -348,7 +349,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): with torch.no_grad(): normalized_obs = normalize_obs(next_obs, cfg.algo.cnn_keys.encoder, obs_keys) torch_obs = {k: torch.as_tensor(normalized_obs[k], dtype=torch.float32, device=device) for k in obs_keys} - next_values = agent.module.get_value(torch_obs) + _, _, _, next_values = player_agent(torch_obs) returns, advantages = gae( local_data["rewards"].to(torch.float64), local_data["values"], @@ -445,7 +446,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): envs.close() if fabric.is_global_zero and cfg.algo.run_test: - test(agent.module, fabric, cfg, log_dir) + test(player_agent, fabric, cfg, log_dir) if not cfg.model_manager.disabled and fabric.is_global_zero: from sheeprl.algos.ppo.utils import log_models diff --git a/sheeprl/algos/ppo/ppo_decoupled.py b/sheeprl/algos/ppo/ppo_decoupled.py index d03db22f..94c317ba 100644 --- a/sheeprl/algos/ppo/ppo_decoupled.py +++ b/sheeprl/algos/ppo/ppo_decoupled.py @@ -12,6 +12,7 @@ from lightning.fabric.plugins.collectives import TorchCollective from lightning.fabric.plugins.collectives.collective import CollectibleGroup from lightning.fabric.strategies import DDPStrategy +from lightning.fabric.strategies.single_device import SingleDeviceStrategy from torch.distributed.algorithms.join import Join from torch.utils.data import BatchSampler, RandomSampler from torchmetrics import SumMetric @@ -91,6 +92,12 @@ def player( "is_continuous": is_continuous, } agent = PPOAgent(**agent_args).to(device) + agent = SingleDeviceStrategy( + device=fabric.device, + accelerator=fabric.accelerator, + checkpoint_io=fabric._connector.checkpoint_io, + precision=fabric._precision, + ).setup_module(agent) if fabric.is_global_zero: save_configs(cfg, log_dir) @@ -267,7 +274,7 @@ def player( # Estimate returns with GAE (https://arxiv.org/abs/1506.02438) normalized_obs = 
normalize_obs(next_obs, cfg.algo.cnn_keys.encoder, obs_keys) torch_obs = {k: torch.as_tensor(normalized_obs[k], dtype=torch.float32, device=device) for k in obs_keys} - next_values = agent.get_value(torch_obs) + _, _, _, next_values = agent(torch_obs) returns, advantages = gae( local_data["rewards"].to(torch.float64), local_data["values"], From 8e5376f19e9535fbcd4470059de34c00103324af Mon Sep 17 00:00:00 2001 From: belerico Date: Tue, 19 Mar 2024 17:06:57 +0100 Subject: [PATCH 02/26] Fix PPO agent to run with the correct precision plugin --- sheeprl/algos/ppo/agent.py | 47 ++++++---- sheeprl/algos/ppo/ppo.py | 141 +++++++++++++++-------------- sheeprl/algos/ppo/ppo_decoupled.py | 9 +- sheeprl/algos/ppo/utils.py | 5 +- 4 files changed, 104 insertions(+), 98 deletions(-) diff --git a/sheeprl/algos/ppo/agent.py b/sheeprl/algos/ppo/agent.py index 6236a3ed..95964672 100644 --- a/sheeprl/algos/ppo/agent.py +++ b/sheeprl/algos/ppo/agent.py @@ -7,7 +7,6 @@ import torch import torch.nn as nn from lightning import Fabric -from lightning.fabric.wrappers import _FabricModule from torch import Tensor from torch.distributions import Distribution, Independent, Normal @@ -63,6 +62,18 @@ def forward(self, obs: Dict[str, Tensor]) -> Tensor: return self.model(x) +class PPOActor(nn.Module): + def __init__(self, actor_backbone: torch.nn.Module, actor_heads: torch.nn.ModuleList, is_continuous: bool) -> None: + super().__init__() + self.actor_backbone = actor_backbone + self.actor_heads = actor_heads + self.is_continuous = is_continuous + + def forward(self, x: Tensor) -> List[Tensor]: + x = self.actor_backbone(x) + return [head(x) for head in self.actor_heads] + + class PPOAgent(nn.Module): def __init__( self, @@ -78,6 +89,7 @@ def __init__( is_continuous: bool = False, ): super().__init__() + self.is_continuous = is_continuous self.distribution_cfg = distribution_cfg self.actions_dim = actions_dim in_channels = sum([prod(obs_space[k].shape[:-2]) for k in cnn_keys]) @@ -101,7 +113,6 @@ def __init__( else None ) self.feature_extractor = MultiEncoder(cnn_encoder, mlp_encoder) - self.is_continuous = is_continuous features_dim = self.feature_extractor.output_dim self.critic = MLP( input_dims=features_dim, @@ -115,7 +126,7 @@ def __init__( else None ), ) - self.actor_backbone = ( + actor_backbone = ( MLP( input_dims=features_dim, output_dim=None, @@ -133,21 +144,19 @@ def __init__( else nn.Identity() ) if is_continuous: - self.actor_heads = nn.ModuleList([nn.Linear(actor_cfg.dense_units, sum(actions_dim) * 2)]) + actor_heads = nn.ModuleList([nn.Linear(actor_cfg.dense_units, sum(actions_dim) * 2)]) else: - self.actor_heads = nn.ModuleList( - [nn.Linear(actor_cfg.dense_units, action_dim) for action_dim in actions_dim] - ) + actor_heads = nn.ModuleList([nn.Linear(actor_cfg.dense_units, action_dim) for action_dim in actions_dim]) + self.actor = PPOActor(actor_backbone, actor_heads, is_continuous) def forward( self, obs: Dict[str, Tensor], actions: Optional[List[Tensor]] = None ) -> Tuple[Sequence[Tensor], Tensor, Tensor, Tensor]: feat = self.feature_extractor(obs) - out: Tensor = self.actor_backbone(feat) - pre_dist: List[Tensor] = [head(out) for head in self.actor_heads] + actor_out: List[Tensor] = self.actor(feat.detach()) values = self.critic(feat) if self.is_continuous: - mean, log_std = torch.chunk(pre_dist[0], chunks=2, dim=-1) + mean, log_std = torch.chunk(actor_out[0], chunks=2, dim=-1) std = log_std.exp() normal = Independent( Normal(mean, std, validate_args=self.distribution_cfg.validate_args), @@ -170,7 
+179,7 @@ def forward( if actions is None: should_append = True actions: List[Tensor] = [] - for i, logits in enumerate(pre_dist): + for i, logits in enumerate(actor_out): actions_dist.append( OneHotCategoricalValidateArgs(logits=logits, validate_args=self.distribution_cfg.validate_args) ) @@ -191,15 +200,14 @@ def get_value(self, obs: Dict[str, Tensor]) -> Tensor: def get_greedy_actions(self, obs: Dict[str, Tensor]) -> Sequence[Tensor]: feat = self.feature_extractor(obs) - out = self.actor_backbone(feat) - pre_dist: List[Tensor] = [head(out) for head in self.actor_heads] - if self.is_continuous: - return [torch.chunk(pre_dist[0], 2, -1)[0]] + actor_out: List[Tensor] = self.actor(feat) + if self.actor.is_continuous: + return [torch.chunk(actor_out[0], 2, -1)[0]] else: return tuple( [ OneHotCategoricalValidateArgs(logits=logits, validate_args=self.distribution_cfg.validate_args).mode - for logits in pre_dist + for logits in actor_out ] ) @@ -211,7 +219,7 @@ def build_agent( cfg: Dict[str, Any], obs_space: gymnasium.spaces.Dict, agent_state: Optional[Dict[str, Tensor]] = None, -) -> _FabricModule: +) -> PPOAgent: agent = PPOAgent( actions_dim=actions_dim, obs_space=obs_space, @@ -226,6 +234,7 @@ def build_agent( ) if agent_state: agent.load_state_dict(agent_state) - agent = fabric.setup_module(agent) - + agent.feature_extractor = fabric.setup_module(agent.feature_extractor) + agent.critic = fabric.setup_module(agent.critic) + agent.actor = fabric.setup_module(agent.actor) return agent diff --git a/sheeprl/algos/ppo/ppo.py b/sheeprl/algos/ppo/ppo.py index 8fe9577a..c354ad0e 100644 --- a/sheeprl/algos/ppo/ppo.py +++ b/sheeprl/algos/ppo/ppo.py @@ -24,7 +24,7 @@ from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import gae, normalize_tensor, polynomial_decay, save_configs +from sheeprl.utils.utils import gae, normalize_tensor, polynomial_decay, save_configs, unwrap_fabric def train( @@ -178,7 +178,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): observation_space, state["agent"] if cfg.checkpoint.resume_from else None, ) - player_agent = _FabricModule(agent.module, precision=fabric._precision) # Define the optimizer optimizer = hydra.utils.instantiate(cfg.algo.optimizer, params=agent.parameters(), _convert_="all") @@ -264,92 +263,93 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): step_data[k] = next_obs[k][np.newaxis] for update in range(start_step, num_updates + 1): - for _ in range(0, cfg.algo.rollout_steps): - policy_step += cfg.env.num_envs * world_size + with torch.inference_mode(): + for _ in range(0, cfg.algo.rollout_steps): + policy_step += cfg.env.num_envs * world_size - # Measure environment interaction time: this considers both the model forward - # to get the action given the observation and the time taken into the environment - with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): - with torch.no_grad(): + # Measure environment interaction time: this considers both the model forward + # to get the action given the observation and the time taken into the environment + with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): # Sample an action given the observation received by the environment normalized_obs = normalize_obs(next_obs, cfg.algo.cnn_keys.encoder, obs_keys) torch_obs = { k: torch.as_tensor(normalized_obs[k], dtype=torch.float32, device=device) for k in obs_keys } - actions, logprobs, _, values = 
player_agent(torch_obs) + actions, logprobs, _, values = agent(torch_obs) if is_continuous: real_actions = torch.cat(actions, -1).cpu().numpy() else: real_actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1).cpu().numpy() actions = torch.cat(actions, -1).cpu().numpy() - # Single environment step - obs, rewards, dones, truncated, info = envs.step(real_actions.reshape(envs.action_space.shape)) - truncated_envs = np.nonzero(truncated)[0] - if len(truncated_envs) > 0: - real_next_obs = { - k: torch.empty( - len(truncated_envs), - *observation_space[k].shape, - dtype=torch.float32, - device=device, - ) - for k in obs_keys - } - for i, truncated_env in enumerate(truncated_envs): - for k, v in info["final_observation"][truncated_env].items(): - torch_v = torch.as_tensor(v, dtype=torch.float32, device=device) - if k in cfg.algo.cnn_keys.encoder: - torch_v = torch_v.view(-1, *v.shape[-2:]) - torch_v = torch_v / 255.0 - 0.5 - real_next_obs[k][i] = torch_v - with torch.no_grad(): - vals = player_agent.get_value(real_next_obs).cpu().numpy() - rewards[truncated_envs] += vals.reshape(rewards[truncated_envs].shape) - dones = np.logical_or(dones, truncated).reshape(cfg.env.num_envs, -1).astype(np.uint8) - rewards = rewards.reshape(cfg.env.num_envs, -1) - - # Update the step data - step_data["dones"] = dones[np.newaxis] - step_data["values"] = values.cpu().numpy()[np.newaxis] - step_data["actions"] = actions[np.newaxis] - step_data["logprobs"] = logprobs.cpu().numpy()[np.newaxis] - step_data["rewards"] = rewards[np.newaxis] - if cfg.buffer.memmap: - step_data["returns"] = np.zeros_like(rewards, shape=(1, *rewards.shape)) - step_data["advantages"] = np.zeros_like(rewards, shape=(1, *rewards.shape)) - - # Append data to buffer - rb.add(step_data, validate_args=cfg.buffer.validate_args) - - # Update the observation and dones - next_obs = {} - for k in obs_keys: - _obs = obs[k] - if k in cfg.algo.cnn_keys.encoder: - _obs = _obs.reshape(cfg.env.num_envs, -1, *_obs.shape[-2:]) - step_data[k] = _obs[np.newaxis] - next_obs[k] = _obs - - if cfg.metric.log_level > 0 and "final_info" in info: - for i, agent_ep_info in enumerate(info["final_info"]): - if agent_ep_info is not None: - ep_rew = agent_ep_info["episode"]["r"] - ep_len = agent_ep_info["episode"]["l"] - if aggregator and "Rewards/rew_avg" in aggregator: - aggregator.update("Rewards/rew_avg", ep_rew) - if aggregator and "Game/ep_len_avg" in aggregator: - aggregator.update("Game/ep_len_avg", ep_len) - fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") + # Single environment step + obs, rewards, dones, truncated, info = envs.step(real_actions.reshape(envs.action_space.shape)) + truncated_envs = np.nonzero(truncated)[0] + if len(truncated_envs) > 0: + real_next_obs = { + k: torch.empty( + len(truncated_envs), + *observation_space[k].shape, + dtype=torch.float32, + device=device, + ) + for k in obs_keys + } + for i, truncated_env in enumerate(truncated_envs): + for k, v in info["final_observation"][truncated_env].items(): + torch_v = torch.as_tensor(v, dtype=torch.float32, device=device) + if k in cfg.algo.cnn_keys.encoder: + torch_v = torch_v.view(-1, *v.shape[-2:]) + torch_v = torch_v / 255.0 - 0.5 + real_next_obs[k][i] = torch_v + vals = agent.get_value(real_next_obs) + # _, _, _, vals = agent(real_next_obs) + rewards[truncated_envs] += vals.cpu().numpy().reshape(rewards[truncated_envs].shape) + dones = np.logical_or(dones, truncated).reshape(cfg.env.num_envs, -1).astype(np.uint8) + rewards = 
rewards.reshape(cfg.env.num_envs, -1) + + # Update the step data + step_data["dones"] = dones[np.newaxis] + step_data["values"] = values.cpu().numpy()[np.newaxis] + step_data["actions"] = actions[np.newaxis] + step_data["logprobs"] = logprobs.cpu().numpy()[np.newaxis] + step_data["rewards"] = rewards[np.newaxis] + if cfg.buffer.memmap: + step_data["returns"] = np.zeros_like(rewards, shape=(1, *rewards.shape)) + step_data["advantages"] = np.zeros_like(rewards, shape=(1, *rewards.shape)) + + # Append data to buffer + rb.add(step_data, validate_args=cfg.buffer.validate_args) + + # Update the observation and dones + next_obs = {} + for k in obs_keys: + _obs = obs[k] + if k in cfg.algo.cnn_keys.encoder: + _obs = _obs.reshape(cfg.env.num_envs, -1, *_obs.shape[-2:]) + step_data[k] = _obs[np.newaxis] + next_obs[k] = _obs + + if cfg.metric.log_level > 0 and "final_info" in info: + for i, agent_ep_info in enumerate(info["final_info"]): + if agent_ep_info is not None: + ep_rew = agent_ep_info["episode"]["r"] + ep_len = agent_ep_info["episode"]["l"] + if aggregator and "Rewards/rew_avg" in aggregator: + aggregator.update("Rewards/rew_avg", ep_rew) + if aggregator and "Game/ep_len_avg" in aggregator: + aggregator.update("Game/ep_len_avg", ep_len) + fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") # Transform the data into PyTorch Tensors local_data = rb.to_tensor(dtype=None, device=device, from_numpy=cfg.buffer.from_numpy) # Estimate returns with GAE (https://arxiv.org/abs/1506.02438) - with torch.no_grad(): + with torch.inference_mode(): normalized_obs = normalize_obs(next_obs, cfg.algo.cnn_keys.encoder, obs_keys) torch_obs = {k: torch.as_tensor(normalized_obs[k], dtype=torch.float32, device=device) for k in obs_keys} - _, _, _, next_values = player_agent(torch_obs) + next_values = agent.get_value(torch_obs) + # _, _, _, next_values = agent(torch_obs) returns, advantages = gae( local_data["rewards"].to(torch.float64), local_data["values"], @@ -446,7 +446,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): envs.close() if fabric.is_global_zero and cfg.algo.run_test: - test(player_agent, fabric, cfg, log_dir) + test_agent = unwrap_fabric(agent) + test(test_agent, fabric, cfg, log_dir) if not cfg.model_manager.disabled and fabric.is_global_zero: from sheeprl.algos.ppo.utils import log_models diff --git a/sheeprl/algos/ppo/ppo_decoupled.py b/sheeprl/algos/ppo/ppo_decoupled.py index 94c317ba..37f8d9cc 100644 --- a/sheeprl/algos/ppo/ppo_decoupled.py +++ b/sheeprl/algos/ppo/ppo_decoupled.py @@ -12,7 +12,7 @@ from lightning.fabric.plugins.collectives import TorchCollective from lightning.fabric.plugins.collectives.collective import CollectibleGroup from lightning.fabric.strategies import DDPStrategy -from lightning.fabric.strategies.single_device import SingleDeviceStrategy +from lightning.fabric.wrappers import _FabricModule from torch.distributed.algorithms.join import Join from torch.utils.data import BatchSampler, RandomSampler from torchmetrics import SumMetric @@ -92,12 +92,7 @@ def player( "is_continuous": is_continuous, } agent = PPOAgent(**agent_args).to(device) - agent = SingleDeviceStrategy( - device=fabric.device, - accelerator=fabric.accelerator, - checkpoint_io=fabric._connector.checkpoint_io, - precision=fabric._precision, - ).setup_module(agent) + agent = _FabricModule(agent, precision=fabric._precision) if fabric.is_global_zero: save_configs(cfg, log_dir) diff --git a/sheeprl/algos/ppo/utils.py b/sheeprl/algos/ppo/utils.py index 21775046..9098b4a6 100644 --- 
a/sheeprl/algos/ppo/utils.py +++ b/sheeprl/algos/ppo/utils.py @@ -41,10 +41,11 @@ def test(agent: PPOAgent, fabric: Fabric, cfg: Dict[str, Any], log_dir: str): while not done: # Act greedly through the environment + actions = agent.get_greedy_actions(obs) if agent.is_continuous: - actions = torch.cat(agent.get_greedy_actions(obs), dim=-1) + actions = torch.cat(actions, dim=-1) else: - actions = torch.cat([act.argmax(dim=-1) for act in agent.get_greedy_actions(obs)], dim=-1) + actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1) # Single environment step o, reward, done, truncated, _ = env.step(actions.cpu().numpy().reshape(env.action_space.shape)) From ad6e0000baf7192c3d626cc29e85bdac2f13855c Mon Sep 17 00:00:00 2001 From: belerico Date: Tue, 19 Mar 2024 17:24:35 +0100 Subject: [PATCH 03/26] Add detach_actor + fix creation of testing agent --- sheeprl/algos/ppo/agent.py | 6 +++++- sheeprl/algos/ppo/ppo.py | 2 ++ sheeprl/algos/ppo/utils.py | 2 +- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/sheeprl/algos/ppo/agent.py b/sheeprl/algos/ppo/agent.py index 95964672..bae09648 100644 --- a/sheeprl/algos/ppo/agent.py +++ b/sheeprl/algos/ppo/agent.py @@ -87,6 +87,7 @@ def __init__( screen_size: int, distribution_cfg: Dict[str, Any], is_continuous: bool = False, + detach_actor: bool = False, ): super().__init__() self.is_continuous = is_continuous @@ -148,13 +149,16 @@ def __init__( else: actor_heads = nn.ModuleList([nn.Linear(actor_cfg.dense_units, action_dim) for action_dim in actions_dim]) self.actor = PPOActor(actor_backbone, actor_heads, is_continuous) + self.detach_actor = detach_actor def forward( self, obs: Dict[str, Tensor], actions: Optional[List[Tensor]] = None ) -> Tuple[Sequence[Tensor], Tensor, Tensor, Tensor]: feat = self.feature_extractor(obs) - actor_out: List[Tensor] = self.actor(feat.detach()) values = self.critic(feat) + if self.detach_actor: + feat = feat.detach() + actor_out: List[Tensor] = self.actor(feat) if self.is_continuous: mean, log_std = torch.chunk(actor_out[0], chunks=2, dim=-1) std = log_std.exp() diff --git a/sheeprl/algos/ppo/ppo.py b/sheeprl/algos/ppo/ppo.py index c354ad0e..a63b52c1 100644 --- a/sheeprl/algos/ppo/ppo.py +++ b/sheeprl/algos/ppo/ppo.py @@ -447,6 +447,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): envs.close() if fabric.is_global_zero and cfg.algo.run_test: test_agent = unwrap_fabric(agent) + test_agent.feature_extractor = _FabricModule(test_agent.feature_extractor, fabric._precision) + test_agent.actor = _FabricModule(test_agent.actor, fabric._precision) test(test_agent, fabric, cfg, log_dir) if not cfg.model_manager.disabled and fabric.is_global_zero: diff --git a/sheeprl/algos/ppo/utils.py b/sheeprl/algos/ppo/utils.py index 9098b4a6..06ac2b8d 100644 --- a/sheeprl/algos/ppo/utils.py +++ b/sheeprl/algos/ppo/utils.py @@ -23,7 +23,7 @@ @torch.no_grad() -def test(agent: PPOAgent, fabric: Fabric, cfg: Dict[str, Any], log_dir: str): +def test(agent: PPOAgent | _FabricModule, fabric: Fabric, cfg: Dict[str, Any], log_dir: str): env = make_env(cfg, None, 0, log_dir, "test", vector_env_idx=0)() agent.eval() done = False From b9df6a9978520676dcfe3149bc740d8894356391 Mon Sep 17 00:00:00 2001 From: belerico Date: Tue, 19 Mar 2024 17:25:51 +0100 Subject: [PATCH 04/26] FIx SAC to use the correct precision plugin --- sheeprl/algos/sac/agent.py | 8 ++++++++ sheeprl/algos/sac/sac.py | 8 +++++--- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/sheeprl/algos/sac/agent.py b/sheeprl/algos/sac/agent.py index 
a6687617..6d877365 100644 --- a/sheeprl/algos/sac/agent.py +++ b/sheeprl/algos/sac/agent.py @@ -238,6 +238,11 @@ def actor(self, actor: Union[SACActor, _FabricModule]) -> None: def qfs_target(self) -> nn.ModuleList: return self._qfs_target + @qfs_target.setter + def qfs_target(self, qfs_target: nn.ModuleList) -> None: + self._qfs_target = qfs_target + return + @property def alpha(self) -> float: return self._log_alpha.exp().item() @@ -305,5 +310,8 @@ def build_agent( agent.load_state_dict(agent_state) agent.actor = fabric.setup_module(agent.actor) agent.critics = [fabric.setup_module(critic) for critic in agent.critics] + agent.qfs_target = nn.ModuleList( + [_FabricModule(target, precision=fabric._precision) for target in agent._qfs_target] + ) return agent diff --git a/sheeprl/algos/sac/sac.py b/sheeprl/algos/sac/sac.py index 52521b1c..8e8c8524 100644 --- a/sheeprl/algos/sac/sac.py +++ b/sheeprl/algos/sac/sac.py @@ -11,6 +11,7 @@ import torch from lightning.fabric import Fabric from lightning.fabric.plugins.collectives.collective import CollectibleGroup +from lightning.fabric.wrappers import _FabricModule from torch import Tensor from torch.optim import Optimizer from torch.utils.data.distributed import DistributedSampler @@ -146,6 +147,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): agent = build_agent( fabric, cfg, observation_space, action_space, state["agent"] if cfg.checkpoint.resume_from else None ) + actor = _FabricModule(agent.actor.module, precision=fabric._precision) # Optimizers qf_optimizer = hydra.utils.instantiate(cfg.algo.critic.optimizer, params=agent.qfs.parameters(), _convert_="all") @@ -239,9 +241,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): actions = envs.action_space.sample() else: # Sample an action given the observation received by the environment - with torch.no_grad(): + with torch.inference_mode(): torch_obs = torch.as_tensor(obs, dtype=torch.float32, device=device) - actions, _ = agent.actor.module(torch_obs) + actions, _ = actor(torch_obs) actions = actions.cpu().numpy() next_obs, rewards, dones, truncated, infos = envs.step(actions) next_obs = np.concatenate([next_obs[k] for k in cfg.algo.mlp_keys.encoder], axis=-1) @@ -389,7 +391,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): envs.close() if fabric.is_global_zero and cfg.algo.run_test: - test(agent.actor.module, fabric, cfg, log_dir) + test(actor, fabric, cfg, log_dir) if not cfg.model_manager.disabled and fabric.is_global_zero: from sheeprl.algos.sac.utils import log_models From 4712fcb73330c18d4ddd4471da2dc83a828abd3c Mon Sep 17 00:00:00 2001 From: belerico Date: Thu, 21 Mar 2024 11:27:05 +0100 Subject: [PATCH 05/26] Add get_single_devie_fabric method --- sheeprl/algos/ppo/agent.py | 37 ++++++++++++++++++++---------- sheeprl/algos/ppo/ppo.py | 17 +++++--------- sheeprl/algos/ppo/ppo_decoupled.py | 35 ++++++++++++++-------------- sheeprl/algos/sac/sac.py | 5 ++-- sheeprl/algos/sac/sac_decoupled.py | 12 ++++++---- sheeprl/utils/fabric.py | 23 +++++++++++++++++++ 6 files changed, 81 insertions(+), 48 deletions(-) create mode 100644 sheeprl/utils/fabric.py diff --git a/sheeprl/algos/ppo/agent.py b/sheeprl/algos/ppo/agent.py index bae09648..03d1aed9 100644 --- a/sheeprl/algos/ppo/agent.py +++ b/sheeprl/algos/ppo/agent.py @@ -1,5 +1,6 @@ from __future__ import annotations +import copy from math import prod from typing import Any, Dict, List, Optional, Sequence, Tuple @@ -12,6 +13,7 @@ from sheeprl.models.models import MLP, MultiEncoder, NatureCNN from sheeprl.utils.distribution import 
OneHotCategoricalValidateArgs +from sheeprl.utils.fabric import get_single_device_fabric class CNNEncoder(nn.Module): @@ -87,7 +89,6 @@ def __init__( screen_size: int, distribution_cfg: Dict[str, Any], is_continuous: bool = False, - detach_actor: bool = False, ): super().__init__() self.is_continuous = is_continuous @@ -149,15 +150,12 @@ def __init__( else: actor_heads = nn.ModuleList([nn.Linear(actor_cfg.dense_units, action_dim) for action_dim in actions_dim]) self.actor = PPOActor(actor_backbone, actor_heads, is_continuous) - self.detach_actor = detach_actor def forward( - self, obs: Dict[str, Tensor], actions: Optional[List[Tensor]] = None + self, obs: Dict[str, Tensor], actions: Optional[List[Tensor]] = None, greedy: bool = False ) -> Tuple[Sequence[Tensor], Tensor, Tensor, Tensor]: feat = self.feature_extractor(obs) values = self.critic(feat) - if self.detach_actor: - feat = feat.detach() actor_out: List[Tensor] = self.actor(feat) if self.is_continuous: mean, log_std = torch.chunk(actor_out[0], chunks=2, dim=-1) @@ -168,7 +166,10 @@ def forward( validate_args=self.distribution_cfg.validate_args, ) if actions is None: - actions = normal.sample() + if greedy: + actions = mean + else: + actions = normal.sample() else: # always composed by a tuple of one element containing all the # continuous actions @@ -189,7 +190,10 @@ def forward( ) actions_entropies.append(actions_dist[-1].entropy()) if should_append: - actions.append(actions_dist[-1].sample()) + if greedy: + actions.append(actions_dist[-1].mode) + else: + actions.append(actions_dist[-1].sample()) actions_logprobs.append(actions_dist[-1].log_prob(actions[i])) return ( tuple(actions), @@ -223,7 +227,7 @@ def build_agent( cfg: Dict[str, Any], obs_space: gymnasium.spaces.Dict, agent_state: Optional[Dict[str, Tensor]] = None, -) -> PPOAgent: +) -> Tuple[PPOAgent, PPOAgent]: agent = PPOAgent( actions_dim=actions_dim, obs_space=obs_space, @@ -238,7 +242,16 @@ def build_agent( ) if agent_state: agent.load_state_dict(agent_state) - agent.feature_extractor = fabric.setup_module(agent.feature_extractor) - agent.critic = fabric.setup_module(agent.critic) - agent.actor = fabric.setup_module(agent.actor) - return agent + player_agent = copy.deepcopy(agent) + + # Setup training agent + agent = fabric.setup_module(agent) + + # Setup player agent + fabric_player = get_single_device_fabric(fabric) + player_agent = fabric_player.setup_module(player_agent) + + # Tie weights between the agent and the player + for agent_p, player_p in zip(agent.parameters(), player_agent.parameters()): + player_p.data = agent_p.data + return agent, player_agent diff --git a/sheeprl/algos/ppo/ppo.py b/sheeprl/algos/ppo/ppo.py index a63b52c1..9d73f41d 100644 --- a/sheeprl/algos/ppo/ppo.py +++ b/sheeprl/algos/ppo/ppo.py @@ -24,7 +24,7 @@ from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import gae, normalize_tensor, polynomial_decay, save_configs, unwrap_fabric +from sheeprl.utils.utils import gae, normalize_tensor, polynomial_decay, save_configs def train( @@ -170,7 +170,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): else (envs.single_action_space.nvec.tolist() if is_multidiscrete else [envs.single_action_space.n]) ) # Create the actor and critic models - agent = build_agent( + agent, player = build_agent( fabric, actions_dim, is_continuous, @@ -275,7 +275,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): torch_obs = { k: 
torch.as_tensor(normalized_obs[k], dtype=torch.float32, device=device) for k in obs_keys } - actions, logprobs, _, values = agent(torch_obs) + actions, logprobs, _, values = player(torch_obs) if is_continuous: real_actions = torch.cat(actions, -1).cpu().numpy() else: @@ -302,8 +302,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): torch_v = torch_v.view(-1, *v.shape[-2:]) torch_v = torch_v / 255.0 - 0.5 real_next_obs[k][i] = torch_v - vals = agent.get_value(real_next_obs) - # _, _, _, vals = agent(real_next_obs) + _, _, _, vals = player(real_next_obs) rewards[truncated_envs] += vals.cpu().numpy().reshape(rewards[truncated_envs].shape) dones = np.logical_or(dones, truncated).reshape(cfg.env.num_envs, -1).astype(np.uint8) rewards = rewards.reshape(cfg.env.num_envs, -1) @@ -348,8 +347,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): with torch.inference_mode(): normalized_obs = normalize_obs(next_obs, cfg.algo.cnn_keys.encoder, obs_keys) torch_obs = {k: torch.as_tensor(normalized_obs[k], dtype=torch.float32, device=device) for k in obs_keys} - next_values = agent.get_value(torch_obs) - # _, _, _, next_values = agent(torch_obs) + _, _, _, next_values = player(torch_obs) returns, advantages = gae( local_data["rewards"].to(torch.float64), local_data["values"], @@ -446,10 +444,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): envs.close() if fabric.is_global_zero and cfg.algo.run_test: - test_agent = unwrap_fabric(agent) - test_agent.feature_extractor = _FabricModule(test_agent.feature_extractor, fabric._precision) - test_agent.actor = _FabricModule(test_agent.actor, fabric._precision) - test(test_agent, fabric, cfg, log_dir) + test(player, fabric, cfg, log_dir) if not cfg.model_manager.disabled and fabric.is_global_zero: from sheeprl.algos.ppo.utils import log_models diff --git a/sheeprl/algos/ppo/ppo_decoupled.py b/sheeprl/algos/ppo/ppo_decoupled.py index 37f8d9cc..c506f66e 100644 --- a/sheeprl/algos/ppo/ppo_decoupled.py +++ b/sheeprl/algos/ppo/ppo_decoupled.py @@ -12,7 +12,6 @@ from lightning.fabric.plugins.collectives import TorchCollective from lightning.fabric.plugins.collectives.collective import CollectibleGroup from lightning.fabric.strategies import DDPStrategy -from lightning.fabric.wrappers import _FabricModule from torch.distributed.algorithms.join import Join from torch.utils.data import BatchSampler, RandomSampler from torchmetrics import SumMetric @@ -22,6 +21,7 @@ from sheeprl.algos.ppo.utils import normalize_obs, test from sheeprl.data.buffers import ReplayBuffer from sheeprl.utils.env import make_env +from sheeprl.utils.fabric import get_single_device_fabric from sheeprl.utils.logger import get_log_dir from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm @@ -29,7 +29,7 @@ from sheeprl.utils.utils import gae, normalize_tensor, polynomial_decay, save_configs -@torch.no_grad() +@torch.inference_mode() def player( fabric: Fabric, cfg: Dict[str, Any], world_collective: TorchCollective, player_trainer_collective: TorchCollective ): @@ -92,7 +92,8 @@ def player( "is_continuous": is_continuous, } agent = PPOAgent(**agent_args).to(device) - agent = _FabricModule(agent, precision=fabric._precision) + fabric_player = get_single_device_fabric(fabric) + agent = fabric_player.setup_module(agent, move_to_device=False) if fabric.is_global_zero: save_configs(cfg, log_dir) @@ -192,18 +193,17 @@ def player( # Measure environment interaction time: this considers both the model forward # to get the action given the observation and the time 
taken into the environment with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): - with torch.no_grad(): - # Sample an action given the observation received by the environment - normalized_obs = normalize_obs(next_obs, cfg.algo.cnn_keys.encoder, obs_keys) - torch_obs = { - k: torch.as_tensor(normalized_obs[k], dtype=torch.float32, device=device) for k in obs_keys - } - actions, logprobs, _, values = agent(torch_obs) - if is_continuous: - real_actions = torch.cat(actions, -1).cpu().numpy() - else: - real_actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1).cpu().numpy() - actions = torch.cat(actions, -1).cpu().numpy() + # Sample an action given the observation received by the environment + normalized_obs = normalize_obs(next_obs, cfg.algo.cnn_keys.encoder, obs_keys) + torch_obs = { + k: torch.as_tensor(normalized_obs[k], dtype=torch.float32, device=device) for k in obs_keys + } + actions, logprobs, _, values = agent(torch_obs) + if is_continuous: + real_actions = torch.cat(actions, -1).cpu().numpy() + else: + real_actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1).cpu().numpy() + actions = torch.cat(actions, -1).cpu().numpy() # Single environment step obs, rewards, dones, truncated, info = envs.step(real_actions.reshape(envs.action_space.shape)) @@ -225,9 +225,8 @@ def player( torch_v = torch_v.view(-1, *v.shape[-2:]) torch_v = torch_v / 255.0 - 0.5 real_next_obs[k][i] = torch_v - with torch.no_grad(): - vals = agent.get_value(real_next_obs).cpu().numpy() - rewards[truncated_envs] += vals.reshape(rewards[truncated_envs].shape) + _, _, _, vals = agent(real_next_obs) + rewards[truncated_envs] += vals.cpu().numpy().reshape(rewards[truncated_envs].shape) dones = np.logical_or(dones, truncated).reshape(cfg.env.num_envs, -1).astype(np.uint8) rewards = rewards.reshape(cfg.env.num_envs, -1) diff --git a/sheeprl/algos/sac/sac.py b/sheeprl/algos/sac/sac.py index 8e8c8524..73f1ca43 100644 --- a/sheeprl/algos/sac/sac.py +++ b/sheeprl/algos/sac/sac.py @@ -11,7 +11,6 @@ import torch from lightning.fabric import Fabric from lightning.fabric.plugins.collectives.collective import CollectibleGroup -from lightning.fabric.wrappers import _FabricModule from torch import Tensor from torch.optim import Optimizer from torch.utils.data.distributed import DistributedSampler @@ -23,6 +22,7 @@ from sheeprl.algos.sac.utils import test from sheeprl.data.buffers import ReplayBuffer from sheeprl.utils.env import make_env +from sheeprl.utils.fabric import get_single_device_fabric from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm @@ -147,7 +147,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): agent = build_agent( fabric, cfg, observation_space, action_space, state["agent"] if cfg.checkpoint.resume_from else None ) - actor = _FabricModule(agent.actor.module, precision=fabric._precision) + fabric_player = get_single_device_fabric(fabric) + actor = fabric_player.setup_module(agent.actor.module) # Optimizers qf_optimizer = hydra.utils.instantiate(cfg.algo.critic.optimizer, params=agent.qfs.parameters(), _convert_="all") diff --git a/sheeprl/algos/sac/sac_decoupled.py b/sheeprl/algos/sac/sac_decoupled.py index 8eacf018..44a23569 100644 --- a/sheeprl/algos/sac/sac_decoupled.py +++ b/sheeprl/algos/sac/sac_decoupled.py @@ -21,6 +21,7 @@ from sheeprl.algos.sac.utils import test from sheeprl.data.buffers import ReplayBuffer from sheeprl.utils.env import make_env +from 
sheeprl.utils.fabric import get_single_device_fabric
 from sheeprl.utils.logger import get_log_dir
 from sheeprl.utils.metric import MetricAggregator
 from sheeprl.utils.registry import register_algorithm
@@ -28,7 +29,7 @@
 from sheeprl.utils.utils import save_configs
 
 
-@torch.no_grad()
+@torch.inference_mode()
 def player(
     fabric: Fabric, cfg: Dict[str, Any], world_collective: TorchCollective, player_trainer_collective: TorchCollective
 ):
@@ -96,6 +97,8 @@ def player(
         action_low=action_space.low,
         action_high=action_space.high,
     ).to(device)
+    fabric_player = get_single_device_fabric(fabric)
+    actor = fabric_player.setup_module(actor, move_to_device=False)
     flattened_parameters = torch.empty_like(
         torch.nn.utils.convert_parameters.parameters_to_vector(actor.parameters()), device=device
     )
@@ -177,10 +180,9 @@ def player(
                 actions = envs.action_space.sample()
             else:
                 # Sample an action given the observation received by the environment
-                with torch.no_grad():
-                    torch_obs = torch.as_tensor(obs, dtype=torch.float32, device=device)
-                    actions, _ = actor(torch_obs)
-                    actions = actions.cpu().numpy()
+                torch_obs = torch.as_tensor(obs, dtype=torch.float32, device=device)
+                actions, _ = actor(torch_obs)
+                actions = actions.cpu().numpy()
             next_obs, rewards, dones, truncated, infos = envs.step(actions)
             next_obs = np.concatenate([next_obs[k] for k in cfg.algo.mlp_keys.encoder], axis=-1)
             dones = np.logical_or(dones, truncated).reshape(cfg.env.num_envs, -1).astype(np.uint8)
diff --git a/sheeprl/utils/fabric.py b/sheeprl/utils/fabric.py
new file mode 100644
index 00000000..fb7b9242
--- /dev/null
+++ b/sheeprl/utils/fabric.py
@@ -0,0 +1,23 @@
+from lightning.fabric import Fabric
+from lightning.fabric.accelerators import XLAAccelerator
+from lightning.fabric.strategies import SingleDeviceStrategy, SingleDeviceXLAStrategy
+
+
+def get_single_device_fabric(fabric: Fabric) -> Fabric:
+    """Get a single device fabric. The returned fabric will share the same accelerator,
+    precision and device as the input fabric.
+
+    Args:
+        fabric (Fabric): The fabric to use as a base.
+
+    Returns:
+        Fabric: A new fabric with the same device, precision and accelerator as the input fabric.
+ """ + strategy_cls = SingleDeviceXLAStrategy if isinstance(fabric.accelerator, XLAAccelerator) else SingleDeviceStrategy + strategy = strategy_cls( + device=fabric.device, + accelerator=fabric.accelerator, + checkpoint_io=None, + precision=fabric._precision, + ) + return Fabric(strategy=strategy) From db2f021930440e3c521fc4a39ad3af137c9d4ce1 Mon Sep 17 00:00:00 2001 From: belerico Date: Fri, 22 Mar 2024 10:09:30 +0100 Subject: [PATCH 06/26] Fix DrOQ agent to handle fp16 --- sheeprl/algos/droq/agent.py | 8 ++++++++ sheeprl/algos/droq/droq.py | 7 +++++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/sheeprl/algos/droq/agent.py b/sheeprl/algos/droq/agent.py index 7a02da77..1c89554f 100644 --- a/sheeprl/algos/droq/agent.py +++ b/sheeprl/algos/droq/agent.py @@ -154,6 +154,11 @@ def actor(self, actor: Union[SACActor, _FabricModule]) -> None: def qfs_target(self) -> nn.ModuleList: return self._qfs_target + @qfs_target.setter + def qfs_target(self, qfs_target: nn.ModuleList) -> None: + self._qfs_target = qfs_target + return + @property def alpha(self) -> float: return self._log_alpha.exp().item() @@ -237,5 +242,8 @@ def build_agent( agent.load_state_dict(agent_state) agent.actor = fabric.setup_module(agent.actor) agent.critics = [fabric.setup_module(critic) for critic in agent.critics] + agent.qfs_target = nn.ModuleList( + [_FabricModule(target, precision=fabric._precision) for target in agent._qfs_target] + ) return agent diff --git a/sheeprl/algos/droq/droq.py b/sheeprl/algos/droq/droq.py index 57510355..2d7f8c1a 100644 --- a/sheeprl/algos/droq/droq.py +++ b/sheeprl/algos/droq/droq.py @@ -21,6 +21,7 @@ from sheeprl.algos.sac.sac import test from sheeprl.data.buffers import ReplayBuffer from sheeprl.utils.env import make_env +from sheeprl.utils.fabric import get_single_device_fabric from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm @@ -194,6 +195,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): agent = build_agent( fabric, cfg, observation_space, action_space, state["agent"] if cfg.checkpoint.resume_from else None ) + fabric_player = get_single_device_fabric(fabric) + actor = fabric_player.setup_module(agent.actor.module) # Optimizers qf_optimizer = hydra.utils.instantiate(cfg.algo.critic.optimizer, params=agent.qfs.parameters(), _convert_="all") @@ -285,7 +288,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): with torch.no_grad(): # Sample an action given the observation received by the environment - actions, _ = agent.actor.module(torch.from_numpy(obs).to(device)) + actions, _ = actor(torch.from_numpy(obs).to(device)) actions = actions.cpu().numpy() next_obs, rewards, dones, truncated, infos = envs.step(actions.reshape(envs.action_space.shape)) dones = np.logical_or(dones, truncated) @@ -385,7 +388,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): envs.close() if fabric.is_global_zero and cfg.algo.run_test: - test(agent.actor.module, fabric, cfg, log_dir) + test(actor, fabric, cfg, log_dir) if not cfg.model_manager.disabled and fabric.is_global_zero: from sheeprl.algos.sac.utils import log_models From 2caf03b69500552ebe6cb11628070adff338108b Mon Sep 17 00:00:00 2001 From: belerico Date: Fri, 22 Mar 2024 10:09:45 +0100 Subject: [PATCH 07/26] Fix SACAE agent to handle fp16 --- sheeprl/algos/sac_ae/agent.py | 14 +++++++++++++- sheeprl/algos/sac_ae/sac_ae.py | 9 ++++++--- 2 files changed, 
19 insertions(+), 4 deletions(-) diff --git a/sheeprl/algos/sac_ae/agent.py b/sheeprl/algos/sac_ae/agent.py index 727e0cb3..46ac1fc4 100644 --- a/sheeprl/algos/sac_ae/agent.py +++ b/sheeprl/algos/sac_ae/agent.py @@ -71,6 +71,9 @@ def conv_output_shape(self) -> Size: return self._conv_output_shape def forward(self, obs: Dict[str, Tensor], *, detach_encoder_features: bool = False, **kwargs) -> Tensor: + dtypes = [obs[k].dtype for k in self.keys] + if dtypes.count(dtypes[0]) != len(dtypes): + raise ValueError("All the tensors must have the same dtype: {}".format(dtypes)) x = torch.cat([obs[k] for k in self.keys], dim=-3) x = cnn_forward(self.model, x, x.shape[-3:], (-1,)) if detach_encoder_features: @@ -102,7 +105,10 @@ def __init__( self.input_dim = input_dim def forward(self, obs: Dict[str, Tensor], *args, detach_encoder_features: bool = False, **kwargs) -> Tensor: - x = torch.cat([obs[k] for k in self.keys], dim=-1).type(torch.float32) + dtypes = [obs[k].dtype for k in self.keys] + if dtypes.count(dtypes[0]) != len(dtypes): + raise ValueError("All the tensors must have the same dtype: {}".format(dtypes)) + x = torch.cat([obs[k] for k in self.keys], dim=-1) x = self.model(x) if detach_encoder_features: x = x.detach() @@ -406,6 +412,11 @@ def actor(self, actor: Union[SACAEContinuousActor, _FabricModule]) -> None: def critic_target(self) -> SACAECritic: return self._critic_target + @critic_target.setter + def critic_target(self, critic_target: SACAECritic | _FabricModule) -> None: + self._critic_target = critic_target + return + @property def alpha(self) -> float: return self._log_alpha.exp().item() @@ -561,5 +572,6 @@ def build_agent( decoder = fabric.setup_module(decoder) agent.actor = fabric.setup_module(agent.actor) agent.critic = fabric.setup_module(agent.critic) + agent.critic_target = _FabricModule(agent.critic_target, precision=fabric._precision) return agent, encoder, decoder diff --git a/sheeprl/algos/sac_ae/sac_ae.py b/sheeprl/algos/sac_ae/sac_ae.py index b59521fb..09c47b56 100644 --- a/sheeprl/algos/sac_ae/sac_ae.py +++ b/sheeprl/algos/sac_ae/sac_ae.py @@ -26,6 +26,7 @@ from sheeprl.data.buffers import ReplayBuffer from sheeprl.models.models import MultiDecoder, MultiEncoder from sheeprl.utils.env import make_env +from sheeprl.utils.fabric import get_single_device_fabric from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm @@ -207,6 +208,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): state["encoder"] if cfg.checkpoint.resume_from else None, state["decoder"] if cfg.checkpoint.resume_from else None, ) + fabric_player = get_single_device_fabric(fabric) + actor = fabric_player.setup_module(agent.actor.module) # Optimizers qf_optimizer = hydra.utils.instantiate(cfg.algo.critic.optimizer, params=agent.critic.parameters(), _convert_="all") @@ -313,10 +316,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if update < learning_starts: actions = envs.action_space.sample() else: - with torch.no_grad(): + with torch.inference_mode(): normalized_obs = {k: v / 255 if k in cfg.algo.cnn_keys.encoder else v for k, v in obs.items()} torch_obs = {k: torch.from_numpy(v).to(device).float() for k, v in normalized_obs.items()} - actions, _ = agent.actor.module(torch_obs) + actions, _ = actor(torch_obs) actions = actions.cpu().numpy() next_obs, rewards, dones, truncated, infos = envs.step(actions.reshape(envs.action_space.shape)) dones = np.logical_or(dones, truncated) @@ -471,7 +474,7 @@ 
def main(fabric: Fabric, cfg: Dict[str, Any]): envs.close() if fabric.is_global_zero and cfg.algo.run_test: - test(agent.actor.module, fabric, cfg, log_dir) + test(actor, fabric, cfg, log_dir) if not cfg.model_manager.disabled and fabric.is_global_zero: from sheeprl.algos.sac_ae.utils import log_models From 8191e8e5841c41581179354b761351563345fd1f Mon Sep 17 00:00:00 2001 From: belerico Date: Fri, 22 Mar 2024 12:50:50 +0100 Subject: [PATCH 08/26] Fix PPO recurrent to handle fp16 --- sheeprl/algos/ppo_recurrent/agent.py | 47 ++++-- sheeprl/algos/ppo_recurrent/ppo_recurrent.py | 164 +++++++++---------- sheeprl/models/models.py | 1 + sheeprl/utils/fabric.py | 7 +- 4 files changed, 122 insertions(+), 97 deletions(-) diff --git a/sheeprl/algos/ppo_recurrent/agent.py b/sheeprl/algos/ppo_recurrent/agent.py index ac914e07..438300bc 100644 --- a/sheeprl/algos/ppo_recurrent/agent.py +++ b/sheeprl/algos/ppo_recurrent/agent.py @@ -1,3 +1,4 @@ +import copy from math import prod from typing import Any, Dict, List, Optional, Sequence, Tuple, Union @@ -8,9 +9,10 @@ from torch import Tensor from torch.distributions import Independent, Normal -from sheeprl.algos.ppo.agent import CNNEncoder, MLPEncoder +from sheeprl.algos.ppo.agent import CNNEncoder, MLPEncoder, PPOActor from sheeprl.models.models import MLP, MultiEncoder from sheeprl.utils.distribution import OneHotCategoricalValidateArgs +from sheeprl.utils.fabric import get_single_device_fabric class RecurrentModel(nn.Module): @@ -26,9 +28,11 @@ def __init__( activation=eval(pre_rnn_mlp_cfg.activation), layer_args={"bias": pre_rnn_mlp_cfg.bias}, norm_layer=[nn.LayerNorm] if pre_rnn_mlp_cfg.layer_norm else None, - norm_args=[{"normalized_shape": pre_rnn_mlp_cfg.dense_units, "eps": 1e-3}] - if pre_rnn_mlp_cfg.layer_norm - else None, + norm_args=( + [{"normalized_shape": pre_rnn_mlp_cfg.dense_units, "eps": 1e-3}] + if pre_rnn_mlp_cfg.layer_norm + else None + ), ) else: self._pre_mlp = nn.Identity() @@ -45,9 +49,11 @@ def __init__( activation=eval(post_rnn_mlp_cfg.activation), layer_args={"bias": post_rnn_mlp_cfg.bias}, norm_layer=[nn.LayerNorm] if post_rnn_mlp_cfg.layer_norm else None, - norm_args=[{"normalized_shape": post_rnn_mlp_cfg.dense_units, "eps": 1e-3}] - if post_rnn_mlp_cfg.layer_norm - else None, + norm_args=( + [{"normalized_shape": post_rnn_mlp_cfg.dense_units, "eps": 1e-3}] + if post_rnn_mlp_cfg.layer_norm + else None + ), ) self._output_dim = post_rnn_mlp_cfg.dense_units else: @@ -165,6 +171,7 @@ def __init__( self.actor_heads = nn.ModuleList( [nn.Linear(actor_cfg.dense_units, action_dim) for action_dim in actions_dim] ) + self.actor = PPOActor(self.actor_backbone, self.actor_heads, is_continuous) # Initial recurrent states for both the actor and critic rnn self._initial_states: Tensor = self.reset_hidden_states() @@ -245,8 +252,7 @@ def get_sampled_actions( ) def get_pre_dist(self, input: Tensor) -> Union[Tuple[Tensor, ...], Tuple[Tensor, Tensor]]: - features = self.actor_backbone(input) - pre_dist: List[Tensor] = [head(features) for head in self.actor_heads] + pre_dist: List[Tensor] = self.actor(input) if self.is_continuous: mean, log_std = torch.chunk(pre_dist[0], chunks=2, dim=-1) std = log_std.exp() @@ -296,7 +302,7 @@ def build_agent( cfg: Dict[str, Any], obs_space: gymnasium.spaces.Dict, agent_state: Optional[Dict[str, Tensor]] = None, -) -> RecurrentPPOAgent: +) -> Tuple[RecurrentPPOAgent, RecurrentPPOAgent]: agent = RecurrentPPOAgent( actions_dim=actions_dim, obs_space=obs_space, @@ -314,6 +320,23 @@ def build_agent( ) if agent_state: 
agent.load_state_dict(agent_state) - agent = fabric.setup_module(agent) + player_agent = copy.deepcopy(agent) - return agent + # Setup training agent + agent.feature_extractor = fabric.setup_module(agent.feature_extractor) + agent.rnn = fabric.setup_module(agent.rnn) + agent.critic = fabric.setup_module(agent.critic) + agent.actor = fabric.setup_module(agent.actor) + + # Setup player agent + fabric_player = get_single_device_fabric(fabric) + player_agent.feature_extractor = fabric_player.setup_module(player_agent.feature_extractor) + player_agent.rnn = fabric_player.setup_module(player_agent.rnn) + player_agent.critic = fabric_player.setup_module(player_agent.critic) + player_agent.actor = fabric_player.setup_module(player_agent.actor) + + # Tie weights between the agent and the player + for agent_p, player_p in zip(agent.parameters(), player_agent.parameters()): + player_p.data = agent_p.data + + return agent, player_agent diff --git a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py index ab71a365..70a1481f 100644 --- a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py +++ b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py @@ -179,7 +179,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ) # Define the agent and the optimizer - agent = build_agent( + agent, player = build_agent( fabric, actions_dim, is_continuous, @@ -275,18 +275,18 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): torch_prev_actions = torch.zeros(1, cfg.env.num_envs, sum(actions_dim), device=device, dtype=torch.float32) for update in range(start_step, num_updates + 1): - for _ in range(0, cfg.algo.rollout_steps): - policy_step += cfg.env.num_envs * world_size + with torch.inference_mode(): + for _ in range(0, cfg.algo.rollout_steps): + policy_step += cfg.env.num_envs * world_size - # Measure environment interaction time: this considers both the model forward - # to get the action given the observation and the time taken into the environment - with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): - with torch.no_grad(): + # Measure environment interaction time: this considers both the model forward + # to get the action given the observation and the time taken into the environment + with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): # Sample an action given the observation received by the environment # [Seq_len, Batch_size, D] --> [1, num_envs, D] normalized_obs = normalize_obs(obs, cfg.algo.cnn_keys.encoder, obs_keys) torch_obs = {k: torch.as_tensor(v, device=device).float() for k, v in normalized_obs.items()} - actions, logprobs, _, values, states = agent.module( + actions, logprobs, _, values, states = player( torch_obs, prev_actions=torch_prev_actions, prev_states=prev_states ) if is_continuous: @@ -296,90 +296,89 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): torch_actions = torch.cat(actions, dim=-1) actions = torch_actions.cpu().numpy() - # Single environment step - next_obs, rewards, dones, truncated, info = envs.step(real_actions.reshape(envs.action_space.shape)) - truncated_envs = np.nonzero(truncated)[0] - if len(truncated_envs) > 0: - real_next_obs = { - k: torch.empty( - 1, - len(truncated_envs), - *observation_space[k].shape, - dtype=torch.float32, - device=device, - ) - for k in obs_keys - } # [Seq_len, Batch_size, D] --> [1, num_truncated_envs, D] - for i, truncated_env in enumerate(truncated_envs): - for k, v in info["final_observation"][truncated_env].items(): - torch_v = torch.as_tensor(v, dtype=torch.float32, 
device=device) - if k in cfg.algo.cnn_keys.encoder: - torch_v = torch_v.view(1, 1, -1, *torch_v.shape[-2:]) / 255.0 - 0.5 - real_next_obs[k][0, i] = torch_v - with torch.no_grad(): - feat = agent.module.feature_extractor(real_next_obs) - rnn_out, _ = agent.module.rnn( + # Single environment step + next_obs, rewards, dones, truncated, info = envs.step(real_actions.reshape(envs.action_space.shape)) + truncated_envs = np.nonzero(truncated)[0] + if len(truncated_envs) > 0: + real_next_obs = { + k: torch.empty( + 1, + len(truncated_envs), + *observation_space[k].shape, + dtype=torch.float32, + device=device, + ) + for k in obs_keys + } # [Seq_len, Batch_size, D] --> [1, num_truncated_envs, D] + for i, truncated_env in enumerate(truncated_envs): + for k, v in info["final_observation"][truncated_env].items(): + torch_v = torch.as_tensor(v, dtype=torch.float32, device=device) + if k in cfg.algo.cnn_keys.encoder: + torch_v = torch_v.view(1, 1, -1, *torch_v.shape[-2:]) / 255.0 - 0.5 + real_next_obs[k][0, i] = torch_v + feat = player.feature_extractor(real_next_obs) + rnn_out, _ = player.rnn( torch.cat((feat, torch_actions[:, truncated_envs, :]), dim=-1), tuple(s[:, truncated_envs, ...] for s in states), ) - vals = agent.module.get_values(rnn_out).view(rewards[truncated_envs].shape).cpu().numpy() + vals = player.get_values(rnn_out).view(rewards[truncated_envs].shape).cpu().numpy() rewards[truncated_envs] += vals.reshape(rewards[truncated_envs].shape) - dones = np.logical_or(dones, truncated).reshape(1, cfg.env.num_envs, -1).astype(np.float32) - rewards = rewards.reshape(1, cfg.env.num_envs, -1).astype(np.float32) - - step_data["dones"] = dones.reshape(1, cfg.env.num_envs, -1) - step_data["values"] = values.cpu().numpy().reshape(1, cfg.env.num_envs, -1) - step_data["actions"] = actions.reshape(1, cfg.env.num_envs, -1) - step_data["rewards"] = rewards.reshape(1, cfg.env.num_envs, -1) - step_data["logprobs"] = logprobs.cpu().numpy() - step_data["prev_hx"] = prev_states[0].cpu().numpy().reshape(1, cfg.env.num_envs, -1) - step_data["prev_cx"] = prev_states[1].cpu().numpy().reshape(1, cfg.env.num_envs, -1) - step_data["prev_actions"] = prev_actions.reshape(1, cfg.env.num_envs, -1) - if cfg.buffer.memmap: - step_data["returns"] = np.zeros_like(rewards) - step_data["advantages"] = np.zeros_like(rewards) - - # Append data to buffer - rb.add(step_data, validate_args=cfg.buffer.validate_args) - - # Update actions - prev_actions = (1 - dones) * actions - torch_prev_actions = torch.from_numpy(prev_actions).to(device).float() - - # Update the observation - obs = next_obs - for k in obs_keys: - obs[k] = obs[k][np.newaxis] - if k in cfg.algo.cnn_keys.encoder: - obs[k] = obs[k].reshape(1, cfg.env.num_envs, -1, *obs[k].shape[-2:]) - step_data[k] = obs[k] - - # Reset the states if the episode is done - if cfg.algo.reset_recurrent_state_on_done: - prev_states = tuple([(1 - torch.as_tensor(dones, device=device)) * s for s in states]) - else: - prev_states = states - - if cfg.metric.log_level > 0 and "final_info" in info: - for i, agent_ep_info in enumerate(info["final_info"]): - if agent_ep_info is not None: - ep_rew = agent_ep_info["episode"]["r"] - ep_len = agent_ep_info["episode"]["l"] - if aggregator and not aggregator.disabled: - aggregator.update("Rewards/rew_avg", ep_rew) - aggregator.update("Game/ep_len_avg", ep_len) - fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") + dones = np.logical_or(dones, truncated).reshape(1, cfg.env.num_envs, -1).astype(np.float32) + rewards = 
rewards.reshape(1, cfg.env.num_envs, -1).astype(np.float32) + + step_data["dones"] = dones.reshape(1, cfg.env.num_envs, -1) + step_data["values"] = values.cpu().numpy().reshape(1, cfg.env.num_envs, -1) + step_data["actions"] = actions.reshape(1, cfg.env.num_envs, -1) + step_data["rewards"] = rewards.reshape(1, cfg.env.num_envs, -1) + step_data["logprobs"] = logprobs.cpu().numpy() + step_data["prev_hx"] = prev_states[0].cpu().numpy().reshape(1, cfg.env.num_envs, -1) + step_data["prev_cx"] = prev_states[1].cpu().numpy().reshape(1, cfg.env.num_envs, -1) + step_data["prev_actions"] = prev_actions.reshape(1, cfg.env.num_envs, -1) + if cfg.buffer.memmap: + step_data["returns"] = np.zeros_like(rewards) + step_data["advantages"] = np.zeros_like(rewards) + + # Append data to buffer + rb.add(step_data, validate_args=cfg.buffer.validate_args) + + # Update actions + prev_actions = (1 - dones) * actions + torch_prev_actions = torch.from_numpy(prev_actions).to(device).float() + + # Update the observation + obs = next_obs + for k in obs_keys: + obs[k] = obs[k][np.newaxis] + if k in cfg.algo.cnn_keys.encoder: + obs[k] = obs[k].reshape(1, cfg.env.num_envs, -1, *obs[k].shape[-2:]) + step_data[k] = obs[k] + + # Reset the states if the episode is done + if cfg.algo.reset_recurrent_state_on_done: + prev_states = tuple([(1 - torch.as_tensor(dones, device=device)) * s for s in states]) + else: + prev_states = states + + if cfg.metric.log_level > 0 and "final_info" in info: + for i, agent_ep_info in enumerate(info["final_info"]): + if agent_ep_info is not None: + ep_rew = agent_ep_info["episode"]["r"] + ep_len = agent_ep_info["episode"]["l"] + if aggregator and not aggregator.disabled: + aggregator.update("Rewards/rew_avg", ep_rew) + aggregator.update("Game/ep_len_avg", ep_len) + fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") # Transform the data into PyTorch Tensors local_data = rb.to_tensor(dtype=None, device=device, from_numpy=cfg.buffer.from_numpy) # Estimate returns with GAE (https://arxiv.org/abs/1506.02438) - with torch.no_grad(): + with torch.inference_mode(): normalized_obs = normalize_obs(obs, cfg.algo.cnn_keys.encoder, obs_keys) torch_obs = {k: torch.as_tensor(v, device=device).float() for k, v in normalized_obs.items()} - feat = agent.module.feature_extractor(torch_obs) - rnn_out, _ = agent.module.rnn(torch.cat((feat, torch_actions), dim=-1), states) - next_values = agent.module.get_values(rnn_out) + feat = player.feature_extractor(torch_obs) + rnn_out, _ = player.rnn(torch.cat((feat, torch_actions), dim=-1), states) + next_values = player.get_values(rnn_out) returns, advantages = gae( local_data["rewards"].to(torch.float64), local_data["values"], @@ -425,7 +424,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): sequences[k].extend(seq) # Regardless of the key, the shapes are the same lengths.extend([s.shape[0] for s in seq]) - else: sequences = episodes @@ -509,7 +507,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): envs.close() if fabric.is_global_zero and cfg.algo.run_test: - test(agent.module, fabric, cfg, log_dir) + test(player, fabric, cfg, log_dir) if not cfg.model_manager.disabled and fabric.is_global_zero: from sheeprl.algos.ppo.utils import log_models diff --git a/sheeprl/models/models.py b/sheeprl/models/models.py index df774c13..089c45ce 100644 --- a/sheeprl/models/models.py +++ b/sheeprl/models/models.py @@ -1,6 +1,7 @@ """ Adapted from: https://github.com/thu-ml/tianshou/blob/master/tianshou/utils/net/common.py """ + import warnings from math import 
prod from typing import Dict, Optional, Sequence, Union, no_type_check diff --git a/sheeprl/utils/fabric.py b/sheeprl/utils/fabric.py index fb7b9242..bd6a60cf 100644 --- a/sheeprl/utils/fabric.py +++ b/sheeprl/utils/fabric.py @@ -5,13 +5,16 @@ def get_single_device_fabric(fabric: Fabric) -> Fabric: """Get a single device fabric. The returned fabric will share the same accelerator, - precision and device as the input fabric. + precision and device as the input fabric. This is useful when you want to create a new + fabric with the same device as the input fabric, but with a strategy running on a single + device. Args: fabric (Fabric): The fabric to use as a base. Returns: - Fabric: A new fabric with the same device, precision and accelerator as the input fabric. + Fabric: A new fabric with the same device, precision and accelerator as the input fabric but with + a single-device strategy. """ strategy_cls = SingleDeviceXLAStrategy if isinstance(fabric.accelerator, XLAAccelerator) else SingleDeviceStrategy strategy = strategy_cls( From 68f6805ec7c39c4070c752a3f7e1e2c2171f3859 Mon Sep 17 00:00:00 2001 From: belerico Date: Fri, 22 Mar 2024 12:54:26 +0100 Subject: [PATCH 09/26] Fix Join context --- sheeprl/algos/ppo_recurrent/ppo_recurrent.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py index 70a1481f..b16a5459 100644 --- a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py +++ b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py @@ -43,7 +43,18 @@ def train( batch_size = batch_size if batch_size > 0 else num_sequences else: batch_size = 1 - with Join([agent._forward_module]) if fabric.world_size > 1 else nullcontext(): + with ( + Join( + [ + agent.feature_extractor._forward_module, + agent.rnn._forward_module, + agent.actor._forward_module, + agent.critic._forward_module, + ] + ) + if fabric.world_size > 1 + else nullcontext() + ): for _ in range(cfg.algo.update_epochs): sampler = BatchSampler( RandomSampler(range(num_sequences)), From 39513db12d5ce3d4a19a89558f5ea449c160542a Mon Sep 17 00:00:00 2001 From: belerico Date: Fri, 22 Mar 2024 14:23:44 +0100 Subject: [PATCH 10/26] Fix PPO evaluate --- sheeprl/algos/ppo/agent.py | 8 ++++---- sheeprl/algos/ppo/evaluate.py | 2 +- sheeprl/algos/ppo_recurrent/agent.py | 14 +++++++------- sheeprl/algos/ppo_recurrent/evaluate.py | 2 +- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/sheeprl/algos/ppo/agent.py b/sheeprl/algos/ppo/agent.py index 03d1aed9..dac3e758 100644 --- a/sheeprl/algos/ppo/agent.py +++ b/sheeprl/algos/ppo/agent.py @@ -242,16 +242,16 @@ def build_agent( ) if agent_state: agent.load_state_dict(agent_state) - player_agent = copy.deepcopy(agent) + player = copy.deepcopy(agent) # Setup training agent agent = fabric.setup_module(agent) # Setup player agent fabric_player = get_single_device_fabric(fabric) - player_agent = fabric_player.setup_module(player_agent) + player = fabric_player.setup_module(player) # Tie weights between the agent and the player - for agent_p, player_p in zip(agent.parameters(), player_agent.parameters()): + for agent_p, player_p in zip(agent.parameters(), player.parameters()): player_p.data = agent_p.data - return agent, player_agent + return agent, player diff --git a/sheeprl/algos/ppo/evaluate.py b/sheeprl/algos/ppo/evaluate.py index 6d66c01a..4725c2bd 100644 --- a/sheeprl/algos/ppo/evaluate.py +++ b/sheeprl/algos/ppo/evaluate.py @@ -49,7 +49,7 @@ def evaluate_ppo(fabric: Fabric, cfg: 
Dict[str, Any], state: Dict[str, Any]): else (env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n]) ) # Create the actor and critic models - agent = build_agent(fabric, actions_dim, is_continuous, cfg, observation_space, state["agent"]) + agent, _ = build_agent(fabric, actions_dim, is_continuous, cfg, observation_space, state["agent"]) test(agent, fabric, cfg, log_dir) diff --git a/sheeprl/algos/ppo_recurrent/agent.py b/sheeprl/algos/ppo_recurrent/agent.py index 438300bc..84a64c5d 100644 --- a/sheeprl/algos/ppo_recurrent/agent.py +++ b/sheeprl/algos/ppo_recurrent/agent.py @@ -320,7 +320,7 @@ def build_agent( ) if agent_state: agent.load_state_dict(agent_state) - player_agent = copy.deepcopy(agent) + player = copy.deepcopy(agent) # Setup training agent agent.feature_extractor = fabric.setup_module(agent.feature_extractor) @@ -330,13 +330,13 @@ def build_agent( # Setup player agent fabric_player = get_single_device_fabric(fabric) - player_agent.feature_extractor = fabric_player.setup_module(player_agent.feature_extractor) - player_agent.rnn = fabric_player.setup_module(player_agent.rnn) - player_agent.critic = fabric_player.setup_module(player_agent.critic) - player_agent.actor = fabric_player.setup_module(player_agent.actor) + player.feature_extractor = fabric_player.setup_module(player.feature_extractor) + player.rnn = fabric_player.setup_module(player.rnn) + player.critic = fabric_player.setup_module(player.critic) + player.actor = fabric_player.setup_module(player.actor) # Tie weights between the agent and the player - for agent_p, player_p in zip(agent.parameters(), player_agent.parameters()): + for agent_p, player_p in zip(agent.parameters(), player.parameters()): player_p.data = agent_p.data - return agent, player_agent + return agent, player diff --git a/sheeprl/algos/ppo_recurrent/evaluate.py b/sheeprl/algos/ppo_recurrent/evaluate.py index 0c5a0ed1..12f57dba 100644 --- a/sheeprl/algos/ppo_recurrent/evaluate.py +++ b/sheeprl/algos/ppo_recurrent/evaluate.py @@ -49,5 +49,5 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): else (env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n]) ) # Create the actor and critic models - agent = build_agent(fabric, actions_dim, is_continuous, cfg, observation_space, state["agent"]) + agent, _ = build_agent(fabric, actions_dim, is_continuous, cfg, observation_space, state["agent"]) test(agent, fabric, cfg, log_dir) From 491bc569e7e9332562b4472a558db3baaee6c789 Mon Sep 17 00:00:00 2001 From: belerico Date: Fri, 22 Mar 2024 14:35:03 +0100 Subject: [PATCH 11/26] Fix Dreamer-V1 player to handle fp16 --- sheeprl/algos/dreamer_v1/agent.py | 18 ++++++++++-------- sheeprl/algos/dreamer_v1/dreamer_v1.py | 17 +++++++++++++---- sheeprl/algos/dreamer_v1/evaluate.py | 2 +- 3 files changed, 24 insertions(+), 13 deletions(-) diff --git a/sheeprl/algos/dreamer_v1/agent.py b/sheeprl/algos/dreamer_v1/agent.py index 8c6bc351..5d127b64 100644 --- a/sheeprl/algos/dreamer_v1/agent.py +++ b/sheeprl/algos/dreamer_v1/agent.py @@ -17,6 +17,7 @@ from sheeprl.algos.dreamer_v2.agent import MinedojoActor as DV2MinedojoActor from sheeprl.algos.dreamer_v2.agent import MLPDecoder, MLPEncoder from sheeprl.models.models import MLP, MultiDecoder, MultiEncoder +from sheeprl.utils.fabric import get_single_device_fabric from sheeprl.utils.utils import init_weights # In order to use the hydra.utils.get_class method, in this way the user can @@ -221,6 +222,7 @@ class PlayerDV1(nn.Module): """The model of the DreamerV1 
player. Args: + fabric (Fabric): the fabric object. encoder (nn.Module): the encoder. recurrent_model (nn.Module): the recurrent model. representation_model (nn.Module): the representation model. @@ -229,13 +231,13 @@ class PlayerDV1(nn.Module): num_envs (int): the number of environments. stochastic_size (int): the size of the stochastic state. recurrent_state_size (int): the size of the recurrent state. - device (torch.device): the device to work on. actor_type (str, optional): which actor the player is using ('task' or 'exploration'). Default to None. """ def __init__( self, + fabric: Fabric, encoder: nn.Module, recurrent_model: nn.Module, representation_model: nn.Module, @@ -244,22 +246,22 @@ def __init__( num_envs: int, stochastic_size: int, recurrent_state_size: int, - device: torch.device, actor_type: str | None = None, ) -> None: super().__init__() - self.encoder = encoder - self.recurrent_model = recurrent_model - self.representation_model = representation_model - self.actor = actor - self.device = device + fabric_player = get_single_device_fabric(fabric) + self.encoder = fabric_player.setup_module(encoder) + self.recurrent_model = fabric_player.setup_module(recurrent_model) + self.representation_model = fabric_player.setup_module(representation_model) + self.actor = fabric_player.setup_module(actor) + self.device = fabric_player.device self.actions_dim = actions_dim self.stochastic_size = stochastic_size self.recurrent_state_size = recurrent_state_size self.num_envs = num_envs self.validate_args = self.actor.distribution_cfg.validate_args - self.init_states() self.actor_type = actor_type + self.init_states() def init_states(self, reset_envs: Optional[Sequence[int]] = None) -> None: """Initialize the states and the actions for the ended environments. 
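The pattern introduced by this patch, and reused below for Dreamer-V2/V3 and the P2E algorithms, is that the player no longer receives a raw `device`: each of its sub-modules is wrapped through the Fabric returned by `get_single_device_fabric`, so the player shares the training precision plugin (e.g. fp16) while running on a single-device strategy. A minimal sketch of that wrapping, assuming `get_single_device_fabric` from `sheeprl/utils/fabric.py` as shown above; the CPU accelerator, the `bf16-mixed` precision and the toy `nn.Linear` encoder are illustrative placeholders, not part of the patch:

# Sketch only: accelerator, precision and the Linear module below are assumptions for illustration.
import torch
import torch.nn as nn
from lightning.fabric import Fabric

from sheeprl.utils.fabric import get_single_device_fabric

# Training fabric: the precision plugin chosen here is the one the player must also use.
fabric = Fabric(accelerator="cpu", precision="bf16-mixed")
fabric.launch()

# Stand-in for the encoder/recurrent/representation/actor modules wrapped by the player.
encoder = nn.Linear(4, 8)

# Single-device fabric sharing accelerator, precision and device with the training fabric.
player_fabric = get_single_device_fabric(fabric)
player_encoder = player_fabric.setup_module(encoder)

# The forward pass now runs under the same autocast/precision context as training.
x = torch.randn(2, 4)
y = player_encoder(x)
print(y.dtype)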
diff --git a/sheeprl/algos/dreamer_v1/dreamer_v1.py b/sheeprl/algos/dreamer_v1/dreamer_v1.py index a6da7b23..02324494 100644 --- a/sheeprl/algos/dreamer_v1/dreamer_v1.py +++ b/sheeprl/algos/dreamer_v1/dreamer_v1.py @@ -231,7 +231,10 @@ def train( world_model_grads = None if cfg.algo.world_model.clip_gradients is not None and cfg.algo.world_model.clip_gradients > 0: world_model_grads = fabric.clip_gradients( - module=world_model, optimizer=world_optimizer, max_norm=cfg.algo.world_model.clip_gradients + module=world_model, + optimizer=world_optimizer, + max_norm=cfg.algo.world_model.clip_gradients, + error_if_nonfinite=False, ) world_optimizer.step() @@ -337,7 +340,10 @@ def train( actor_grads = None if cfg.algo.actor.clip_gradients is not None and cfg.algo.actor.clip_gradients > 0: actor_grads = fabric.clip_gradients( - module=actor, optimizer=actor_optimizer, max_norm=cfg.algo.actor.clip_gradients + module=actor, + optimizer=actor_optimizer, + max_norm=cfg.algo.actor.clip_gradients, + error_if_nonfinite=False, ) actor_optimizer.step() @@ -362,7 +368,10 @@ def train( critic_grads = None if cfg.algo.critic.clip_gradients is not None and cfg.algo.critic.clip_gradients > 0: critic_grads = fabric.clip_gradients( - module=critic, optimizer=critic_optimizer, max_norm=cfg.algo.critic.clip_gradients + module=critic, + optimizer=critic_optimizer, + max_norm=cfg.algo.critic.clip_gradients, + error_if_nonfinite=False, ) critic_optimizer.step() @@ -472,6 +481,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): state["critic"] if cfg.checkpoint.resume_from else None, ) player = PlayerDV1( + fabric, world_model.encoder.module, world_model.rssm.recurrent_model.module, world_model.rssm.representation_model.module, @@ -480,7 +490,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): cfg.env.num_envs, cfg.algo.world_model.stochastic_size, cfg.algo.world_model.recurrent_model.recurrent_state_size, - fabric.device, ) # Optimizers diff --git a/sheeprl/algos/dreamer_v1/evaluate.py b/sheeprl/algos/dreamer_v1/evaluate.py index f69891fa..e33d1f54 100644 --- a/sheeprl/algos/dreamer_v1/evaluate.py +++ b/sheeprl/algos/dreamer_v1/evaluate.py @@ -54,6 +54,7 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): state["actor"], ) player = PlayerDV1( + fabric, world_model.encoder.module, world_model.rssm.recurrent_model.module, world_model.rssm.representation_model.module, @@ -62,7 +63,6 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): cfg.env.num_envs, cfg.algo.world_model.stochastic_size, cfg.algo.world_model.recurrent_model.recurrent_state_size, - fabric.device, ) test(player, fabric, cfg, log_dir, sample_actions=False) From 3294fbd1ead0db531e47091cef68344dc8dce8c9 Mon Sep 17 00:00:00 2001 From: belerico Date: Fri, 22 Mar 2024 15:13:52 +0100 Subject: [PATCH 12/26] Dreamer-V1 inference mode --- sheeprl/algos/dreamer_v1/agent.py | 12 +-- sheeprl/algos/dreamer_v1/dreamer_v1.py | 140 ++++++++++++------------- 2 files changed, 76 insertions(+), 76 deletions(-) diff --git a/sheeprl/algos/dreamer_v1/agent.py b/sheeprl/algos/dreamer_v1/agent.py index 5d127b64..d541a2e4 100644 --- a/sheeprl/algos/dreamer_v1/agent.py +++ b/sheeprl/algos/dreamer_v1/agent.py @@ -249,12 +249,12 @@ def __init__( actor_type: str | None = None, ) -> None: super().__init__() - fabric_player = get_single_device_fabric(fabric) - self.encoder = fabric_player.setup_module(encoder) - self.recurrent_model = fabric_player.setup_module(recurrent_model) - self.representation_model = 
fabric_player.setup_module(representation_model) - self.actor = fabric_player.setup_module(actor) - self.device = fabric_player.device + single_device_fabric = get_single_device_fabric(fabric) + self.encoder = single_device_fabric.setup_module(encoder) + self.recurrent_model = single_device_fabric.setup_module(recurrent_model) + self.representation_model = single_device_fabric.setup_module(representation_model) + self.actor = single_device_fabric.setup_module(actor) + self.device = single_device_fabric.device self.actions_dim = actions_dim self.stochastic_size = stochastic_size self.recurrent_state_size = recurrent_state_size diff --git a/sheeprl/algos/dreamer_v1/dreamer_v1.py b/sheeprl/algos/dreamer_v1/dreamer_v1.py index 02324494..4c582609 100644 --- a/sheeprl/algos/dreamer_v1/dreamer_v1.py +++ b/sheeprl/algos/dreamer_v1/dreamer_v1.py @@ -595,26 +595,26 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): for update in range(start_step, num_updates + 1): policy_step += cfg.env.num_envs * world_size - # Measure environment interaction time: this considers both the model forward - # to get the action given the observation and the time taken into the environment - with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): - # Sample an action given the observation received by the environment - if ( - update <= learning_starts - and cfg.checkpoint.resume_from is None - and "minedojo" not in cfg.env.wrapper._target_.lower() - ): - real_actions = actions = np.array(envs.action_space.sample()) - if not is_continuous: - actions = np.concatenate( - [ - F.one_hot(torch.as_tensor(act), act_dim).numpy() - for act, act_dim in zip(actions.reshape(len(actions_dim), -1), actions_dim) - ], - axis=-1, - ) - else: - with torch.no_grad(): + with torch.inference_mode(): + # Measure environment interaction time: this considers both the model forward + # to get the action given the observation and the time taken into the environment + with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): + # Sample an action given the observation received by the environment + if ( + update <= learning_starts + and cfg.checkpoint.resume_from is None + and "minedojo" not in cfg.env.wrapper._target_.lower() + ): + real_actions = actions = np.array(envs.action_space.sample()) + if not is_continuous: + actions = np.concatenate( + [ + F.one_hot(torch.as_tensor(act), act_dim).numpy() + for act, act_dim in zip(actions.reshape(len(actions_dim), -1), actions_dim) + ], + axis=-1, + ) + else: normalized_obs = {} for k in obs_keys: torch_obs = torch.as_tensor(obs[k][np.newaxis], dtype=torch.float32, device=device) @@ -632,57 +632,57 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): real_actions = ( torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() ) - next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated).astype(np.uint8) - - if cfg.metric.log_level > 0 and "final_info" in infos: - for i, agent_ep_info in enumerate(infos["final_info"]): - if agent_ep_info is not None: - ep_rew = agent_ep_info["episode"]["r"] - ep_len = agent_ep_info["episode"]["l"] - if aggregator and not aggregator.disabled: - aggregator.update("Rewards/rew_avg", ep_rew) - aggregator.update("Game/ep_len_avg", ep_len) - fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") - - # Save the real next observation - real_next_obs = copy.deepcopy(next_obs) - if "final_observation" in infos: - 
for idx, final_obs in enumerate(infos["final_observation"]): - if final_obs is not None: - for k, v in final_obs.items(): - real_next_obs[k][idx] = v - - for k in obs_keys: - if k in cfg.algo.cnn_keys.encoder: - next_obs[k] = next_obs[k].reshape(cfg.env.num_envs, -1, *next_obs[k].shape[-2:]) - real_next_obs[k] = real_next_obs[k].reshape(cfg.env.num_envs, -1, *real_next_obs[k].shape[-2:]) - step_data[k] = real_next_obs[k][np.newaxis] - - # next_obs becomes the new obs - obs = next_obs - - step_data["dones"] = dones[np.newaxis] - step_data["actions"] = actions[np.newaxis] - step_data["rewards"] = clip_rewards_fn(rewards)[np.newaxis] - rb.add(step_data, validate_args=cfg.buffer.validate_args) - - # Reset and save the observation coming from the automatic reset - dones_idxes = dones.nonzero()[0].tolist() - reset_envs = len(dones_idxes) - if reset_envs > 0: - reset_data = {} + next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) + dones = np.logical_or(dones, truncated).astype(np.uint8) + + if cfg.metric.log_level > 0 and "final_info" in infos: + for i, agent_ep_info in enumerate(infos["final_info"]): + if agent_ep_info is not None: + ep_rew = agent_ep_info["episode"]["r"] + ep_len = agent_ep_info["episode"]["l"] + if aggregator and not aggregator.disabled: + aggregator.update("Rewards/rew_avg", ep_rew) + aggregator.update("Game/ep_len_avg", ep_len) + fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") + + # Save the real next observation + real_next_obs = copy.deepcopy(next_obs) + if "final_observation" in infos: + for idx, final_obs in enumerate(infos["final_observation"]): + if final_obs is not None: + for k, v in final_obs.items(): + real_next_obs[k][idx] = v + for k in obs_keys: - reset_data[k] = (next_obs[k][dones_idxes])[np.newaxis] - reset_data["dones"] = np.zeros((1, reset_envs, 1)) - reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim))) - reset_data["rewards"] = np.zeros((1, reset_envs, 1)) - rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args) - # Reset dones so that `is_first` is updated - for d in dones_idxes: - step_data["dones"][0, d] = np.zeros_like(step_data["dones"][0, d]) - # Reset internal agent states - player.init_states(reset_envs=dones_idxes) + if k in cfg.algo.cnn_keys.encoder: + next_obs[k] = next_obs[k].reshape(cfg.env.num_envs, -1, *next_obs[k].shape[-2:]) + real_next_obs[k] = real_next_obs[k].reshape(cfg.env.num_envs, -1, *real_next_obs[k].shape[-2:]) + step_data[k] = real_next_obs[k][np.newaxis] + + # next_obs becomes the new obs + obs = next_obs + + step_data["dones"] = dones[np.newaxis] + step_data["actions"] = actions[np.newaxis] + step_data["rewards"] = clip_rewards_fn(rewards)[np.newaxis] + rb.add(step_data, validate_args=cfg.buffer.validate_args) + + # Reset and save the observation coming from the automatic reset + dones_idxes = dones.nonzero()[0].tolist() + reset_envs = len(dones_idxes) + if reset_envs > 0: + reset_data = {} + for k in obs_keys: + reset_data[k] = (next_obs[k][dones_idxes])[np.newaxis] + reset_data["dones"] = np.zeros((1, reset_envs, 1)) + reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim))) + reset_data["rewards"] = np.zeros((1, reset_envs, 1)) + rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args) + # Reset dones so that `is_first` is updated + for d in dones_idxes: + step_data["dones"][0, d] = np.zeros_like(step_data["dones"][0, d]) + # Reset internal agent states + 
player.init_states(reset_envs=dones_idxes) updates_before_training -= 1 From d3f47033273643d7aa89dd6d42283522199ecf36 Mon Sep 17 00:00:00 2001 From: belerico Date: Fri, 22 Mar 2024 15:14:32 +0100 Subject: [PATCH 13/26] Fix Dreamer-V2 to handle fp16 --- sheeprl/algos/dreamer_v2/agent.py | 110 ++++++++++-------- sheeprl/algos/dreamer_v2/dreamer_v2.py | 148 ++++++++++++------------- sheeprl/algos/dreamer_v2/evaluate.py | 2 +- 3 files changed, 140 insertions(+), 120 deletions(-) diff --git a/sheeprl/algos/dreamer_v2/agent.py b/sheeprl/algos/dreamer_v2/agent.py index c83e46b1..7f80eb76 100644 --- a/sheeprl/algos/dreamer_v2/agent.py +++ b/sheeprl/algos/dreamer_v2/agent.py @@ -20,6 +20,7 @@ OneHotCategoricalValidateArgs, TruncatedNormal, ) +from sheeprl.utils.fabric import get_single_device_fabric from sheeprl.utils.model import LayerNormChannelLast, ModuleType, cnn_forward @@ -61,9 +62,9 @@ def __init__( layer_args={"kernel_size": 4, "stride": 2}, activation=activation, norm_layer=[LayerNormChannelLast for _ in range(4)] if layer_norm else None, - norm_args=[{"normalized_shape": (2**i) * channels_multiplier} for i in range(4)] - if layer_norm - else None, + norm_args=( + [{"normalized_shape": (2**i) * channels_multiplier} for i in range(4)] if layer_norm else None + ), ), nn.Flatten(-3, -1), ) @@ -172,12 +173,12 @@ def __init__( ], activation=[activation, activation, activation, None], norm_layer=[LayerNormChannelLast for _ in range(3)] + [None] if layer_norm else None, - norm_args=[ - {"normalized_shape": (2 ** (4 - i - 2)) * channels_multiplier} for i in range(self.output_dim[0]) - ] - + [None] - if layer_norm - else None, + norm_args=( + [{"normalized_shape": (2 ** (4 - i - 2)) * channels_multiplier} for i in range(self.output_dim[0])] + + [None] + if layer_norm + else None + ), ), ) @@ -743,6 +744,7 @@ class PlayerDV2(nn.Module): The model of the Dreamer_v2 player. Args: + fabric: the fabric of the model. encoder (nn.Module): the encoder. recurrent_model (nn.Module): the recurrent model. representation_model (nn.Module): the representation model. @@ -751,7 +753,6 @@ class PlayerDV2(nn.Module): num_envs (int): the number of environments. stochastic_size (int): the size of the stochastic state. recurrent_state_size (int): the size of the recurrent state. - device (torch.device): the device to work on. discrete_size (int): the dimension of a single Categorical variable in the stochastic state (prior or posterior). Defaults to 32. 
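As with Dreamer-V1, the player sub-modules here are re-wrapped through a single-device Fabric, and `setup_module` does not clone anything: the player's `_FabricModule` wraps the very same `nn.Module` used for training, so optimizer updates are immediately visible to the player and no explicit weight syncing is needed (unlike the PPO agents above, where the player is a `copy.deepcopy` whose parameters are tied back by hand). A small sketch of that sharing, assuming `get_single_device_fabric` from `sheeprl/utils/fabric.py`; the accelerator, precision and the `nn.Linear` critic are placeholders for illustration:

# Sketch only: accelerator, precision and the Linear module are assumptions, not sheeprl code.
import torch.nn as nn
from lightning.fabric import Fabric

from sheeprl.utils.fabric import get_single_device_fabric

fabric = Fabric(accelerator="cpu", precision="bf16-mixed")
fabric.launch()

# Training-side wrapper (in sheeprl this would be the critic/world model set up by `fabric`).
critic = fabric.setup_module(nn.Linear(8, 1))

# Player-side wrapper built from the same underlying nn.Module.
player_fabric = get_single_device_fabric(fabric)
player_critic = player_fabric.setup_module(critic.module)

# Both wrappers point at the same parameters: an update on one is seen by the other.
assert player_critic.module is critic.module
assert all(p1 is p2 for p1, p2 in zip(critic.parameters(), player_critic.parameters()))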
@@ -761,6 +762,7 @@ class PlayerDV2(nn.Module): def __init__( self, + fabric: Fabric, encoder: nn.Module, recurrent_model: nn.Module, representation_model: nn.Module, @@ -769,16 +771,16 @@ def __init__( num_envs: int, stochastic_size: int, recurrent_state_size: int, - device: torch.device, discrete_size: int = 32, actor_type: str | None = None, ) -> None: super().__init__() - self.encoder = encoder - self.recurrent_model = recurrent_model - self.representation_model = representation_model - self.actor = actor - self.device = device + fabric_player = get_single_device_fabric(fabric) + self.encoder = fabric_player.setup_module(encoder) + self.recurrent_model = fabric_player.setup_module(recurrent_model) + self.representation_model = fabric_player.setup_module(representation_model) + self.actor = fabric_player.setup_module(actor) + self.device = fabric_player.device self.actions_dim = actions_dim self.stochastic_size = stochastic_size self.discrete_size = discrete_size @@ -871,7 +873,7 @@ def build_agent( actor_state: Optional[Dict[str, Tensor]] = None, critic_state: Optional[Dict[str, Tensor]] = None, target_critic_state: Optional[Dict[str, Tensor]] = None, -) -> Tuple[WorldModel, _FabricModule, _FabricModule, nn.Module]: +) -> Tuple[WorldModel, _FabricModule, _FabricModule, _FabricModule]: """Build the models and wrap them with Fabric. Args: @@ -894,7 +896,7 @@ def build_agent( reward models and the continue model. The actor (_FabricModule). The critic (_FabricModule). - The target critic (nn.Module). + The target critic (_FabricModule). """ world_model_cfg = cfg.algo.world_model actor_cfg = cfg.algo.actor @@ -943,9 +945,11 @@ def build_agent( activation=eval(world_model_cfg.representation_model.dense_act), flatten_dim=None, norm_layer=[nn.LayerNorm] if world_model_cfg.representation_model.layer_norm else None, - norm_args=[{"normalized_shape": world_model_cfg.representation_model.hidden_size}] - if world_model_cfg.representation_model.layer_norm - else None, + norm_args=( + [{"normalized_shape": world_model_cfg.representation_model.hidden_size}] + if world_model_cfg.representation_model.layer_norm + else None + ), ) transition_model = MLP( input_dims=world_model_cfg.recurrent_model.recurrent_state_size, @@ -954,9 +958,11 @@ def build_agent( activation=eval(world_model_cfg.transition_model.dense_act), flatten_dim=None, norm_layer=[nn.LayerNorm] if world_model_cfg.transition_model.layer_norm else None, - norm_args=[{"normalized_shape": world_model_cfg.transition_model.hidden_size}] - if world_model_cfg.transition_model.layer_norm - else None, + norm_args=( + [{"normalized_shape": world_model_cfg.transition_model.hidden_size}] + if world_model_cfg.transition_model.layer_norm + else None + ), ) rssm = RSSM( recurrent_model=recurrent_model.apply(init_weights), @@ -999,15 +1005,19 @@ def build_agent( hidden_sizes=[world_model_cfg.reward_model.dense_units] * world_model_cfg.reward_model.mlp_layers, activation=eval(world_model_cfg.reward_model.dense_act), flatten_dim=None, - norm_layer=[nn.LayerNorm for _ in range(world_model_cfg.reward_model.mlp_layers)] - if world_model_cfg.reward_model.layer_norm - else None, - norm_args=[ - {"normalized_shape": world_model_cfg.reward_model.dense_units} - for _ in range(world_model_cfg.reward_model.mlp_layers) - ] - if world_model_cfg.reward_model.layer_norm - else None, + norm_layer=( + [nn.LayerNorm for _ in range(world_model_cfg.reward_model.mlp_layers)] + if world_model_cfg.reward_model.layer_norm + else None + ), + norm_args=( + [ + {"normalized_shape": 
world_model_cfg.reward_model.dense_units} + for _ in range(world_model_cfg.reward_model.mlp_layers) + ] + if world_model_cfg.reward_model.layer_norm + else None + ), ) if world_model_cfg.use_continues: continue_model = MLP( @@ -1016,15 +1026,19 @@ def build_agent( hidden_sizes=[world_model_cfg.discount_model.dense_units] * world_model_cfg.discount_model.mlp_layers, activation=eval(world_model_cfg.discount_model.dense_act), flatten_dim=None, - norm_layer=[nn.LayerNorm for _ in range(world_model_cfg.discount_model.mlp_layers)] - if world_model_cfg.discount_model.layer_norm - else None, - norm_args=[ - {"normalized_shape": world_model_cfg.discount_model.dense_units} - for _ in range(world_model_cfg.discount_model.mlp_layers) - ] - if world_model_cfg.discount_model.layer_norm - else None, + norm_layer=( + [nn.LayerNorm for _ in range(world_model_cfg.discount_model.mlp_layers)] + if world_model_cfg.discount_model.layer_norm + else None + ), + norm_args=( + [ + {"normalized_shape": world_model_cfg.discount_model.dense_units} + for _ in range(world_model_cfg.discount_model.mlp_layers) + ] + if world_model_cfg.discount_model.layer_norm + else None + ), ) world_model = WorldModel( encoder.apply(init_weights), @@ -1053,9 +1067,11 @@ def build_agent( activation=eval(critic_cfg.dense_act), flatten_dim=None, norm_layer=[nn.LayerNorm for _ in range(critic_cfg.mlp_layers)] if critic_cfg.layer_norm else None, - norm_args=[{"normalized_shape": critic_cfg.dense_units} for _ in range(critic_cfg.mlp_layers)] - if critic_cfg.layer_norm - else None, + norm_args=( + [{"normalized_shape": critic_cfg.dense_units} for _ in range(critic_cfg.mlp_layers)] + if critic_cfg.layer_norm + else None + ), ) actor.apply(init_weights) critic.apply(init_weights) @@ -1079,8 +1095,12 @@ def build_agent( world_model.continue_model = fabric.setup_module(world_model.continue_model) actor = fabric.setup_module(actor) critic = fabric.setup_module(critic) + + # Setup target critic with a SingleDeviceStrategy target_critic = copy.deepcopy(critic.module) if target_critic_state: target_critic.load_state_dict(target_critic_state) + single_device_fabric = get_single_device_fabric(fabric) + target_critic = single_device_fabric.setup_module(target_critic) return world_model, actor, critic, target_critic diff --git a/sheeprl/algos/dreamer_v2/dreamer_v2.py b/sheeprl/algos/dreamer_v2/dreamer_v2.py index 4d4d561f..5c1356ac 100644 --- a/sheeprl/algos/dreamer_v2/dreamer_v2.py +++ b/sheeprl/algos/dreamer_v2/dreamer_v2.py @@ -498,6 +498,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): state["target_critic"] if cfg.checkpoint.resume_from else None, ) player = PlayerDV2( + fabric, world_model.encoder.module, world_model.rssm.recurrent_model.module, world_model.rssm.representation_model.module, @@ -506,7 +507,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): cfg.env.num_envs, cfg.algo.world_model.stochastic_size, cfg.algo.world_model.recurrent_model.recurrent_state_size, - fabric.device, discrete_size=cfg.algo.world_model.discrete_size, ) @@ -629,26 +629,26 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): for update in range(start_step, num_updates + 1): policy_step += cfg.env.num_envs * world_size - # Measure environment interaction time: this considers both the model forward - # to get the action given the observation and the time taken into the environment - with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): - # Sample an action given the observation received by the environment - if ( - update <= learning_starts - and 
cfg.checkpoint.resume_from is None - and "minedojo" not in cfg.env.wrapper._target_.lower() - ): - real_actions = actions = np.array(envs.action_space.sample()) - if not is_continuous: - actions = np.concatenate( - [ - F.one_hot(torch.as_tensor(act), act_dim).numpy() - for act, act_dim in zip(actions.reshape(len(actions_dim), -1), actions_dim) - ], - axis=-1, - ) - else: - with torch.no_grad(): + with torch.inference_mode(): + # Measure environment interaction time: this considers both the model forward + # to get the action given the observation and the time taken into the environment + with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): + # Sample an action given the observation received by the environment + if ( + update <= learning_starts + and cfg.checkpoint.resume_from is None + and "minedojo" not in cfg.env.wrapper._target_.lower() + ): + real_actions = actions = np.array(envs.action_space.sample()) + if not is_continuous: + actions = np.concatenate( + [ + F.one_hot(torch.as_tensor(act), act_dim).numpy() + for act, act_dim in zip(actions.reshape(len(actions_dim), -1), actions_dim) + ], + axis=-1, + ) + else: normalized_obs = {} for k in obs_keys: torch_obs = torch.as_tensor(obs[k][np.newaxis], dtype=torch.float32, device=device) @@ -667,58 +667,58 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() ) - step_data["is_first"] = copy.deepcopy(step_data["dones"]) - next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated).astype(np.uint8) - if cfg.dry_run and buffer_type == "episode": - dones = np.ones_like(dones) - - if cfg.metric.log_level > 0 and "final_info" in infos: - for i, agent_ep_info in enumerate(infos["final_info"]): - if agent_ep_info is not None: - ep_rew = agent_ep_info["episode"]["r"] - ep_len = agent_ep_info["episode"]["l"] - if aggregator and not aggregator.disabled: - aggregator.update("Rewards/rew_avg", ep_rew) - aggregator.update("Game/ep_len_avg", ep_len) - fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") - - # Save the real next observation - real_next_obs = copy.deepcopy(next_obs) - if "final_observation" in infos: - for idx, final_obs in enumerate(infos["final_observation"]): - if final_obs is not None: - for k, v in final_obs.items(): - real_next_obs[k][idx] = v - - for k in obs_keys: # [N_envs, N_obs] - step_data[k] = real_next_obs[k][np.newaxis] - - # Next_obs becomes the new obs - obs = next_obs - - step_data["dones"] = dones.reshape((1, cfg.env.num_envs, -1)) - step_data["actions"] = actions.reshape((1, cfg.env.num_envs, -1)) - step_data["rewards"] = clip_rewards_fn(rewards).reshape((1, cfg.env.num_envs, -1)) - rb.add(step_data, validate_args=cfg.buffer.validate_args) - - # Reset and save the observation coming from the automatic reset - dones_idxes = dones.nonzero()[0].tolist() - reset_envs = len(dones_idxes) - if reset_envs > 0: - reset_data = {} - for k in obs_keys: - reset_data[k] = (next_obs[k][dones_idxes])[np.newaxis] - reset_data["dones"] = np.zeros((1, reset_envs, 1)) - reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim))) - reset_data["rewards"] = np.zeros((1, reset_envs, 1)) - reset_data["is_first"] = np.ones_like(reset_data["dones"]) - rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args) - # Reset dones so that `is_first` is updated - for d in dones_idxes: - step_data["dones"][0, d] = 
np.zeros_like(step_data["dones"][0, d]) - # Reset internal agent states - player.init_states(dones_idxes) + step_data["is_first"] = copy.deepcopy(step_data["dones"]) + next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) + dones = np.logical_or(dones, truncated).astype(np.uint8) + if cfg.dry_run and buffer_type == "episode": + dones = np.ones_like(dones) + + if cfg.metric.log_level > 0 and "final_info" in infos: + for i, agent_ep_info in enumerate(infos["final_info"]): + if agent_ep_info is not None: + ep_rew = agent_ep_info["episode"]["r"] + ep_len = agent_ep_info["episode"]["l"] + if aggregator and not aggregator.disabled: + aggregator.update("Rewards/rew_avg", ep_rew) + aggregator.update("Game/ep_len_avg", ep_len) + fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") + + # Save the real next observation + real_next_obs = copy.deepcopy(next_obs) + if "final_observation" in infos: + for idx, final_obs in enumerate(infos["final_observation"]): + if final_obs is not None: + for k, v in final_obs.items(): + real_next_obs[k][idx] = v + + for k in obs_keys: # [N_envs, N_obs] + step_data[k] = real_next_obs[k][np.newaxis] + + # Next_obs becomes the new obs + obs = next_obs + + step_data["dones"] = dones.reshape((1, cfg.env.num_envs, -1)) + step_data["actions"] = actions.reshape((1, cfg.env.num_envs, -1)) + step_data["rewards"] = clip_rewards_fn(rewards).reshape((1, cfg.env.num_envs, -1)) + rb.add(step_data, validate_args=cfg.buffer.validate_args) + + # Reset and save the observation coming from the automatic reset + dones_idxes = dones.nonzero()[0].tolist() + reset_envs = len(dones_idxes) + if reset_envs > 0: + reset_data = {} + for k in obs_keys: + reset_data[k] = (next_obs[k][dones_idxes])[np.newaxis] + reset_data["dones"] = np.zeros((1, reset_envs, 1)) + reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim))) + reset_data["rewards"] = np.zeros((1, reset_envs, 1)) + reset_data["is_first"] = np.ones_like(reset_data["dones"]) + rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args) + # Reset dones so that `is_first` is updated + for d in dones_idxes: + step_data["dones"][0, d] = np.zeros_like(step_data["dones"][0, d]) + # Reset internal agent states + player.init_states(dones_idxes) updates_before_training -= 1 @@ -738,7 +738,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): for i in range(next(iter(local_data.values())).shape[0]): if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: - for cp, tcp in zip(critic.module.parameters(), target_critic.parameters()): + for cp, tcp in zip(critic.module.parameters(), target_critic.module.parameters()): tcp.data.copy_(cp.data) batch = {k: v[i].float() for k, v in local_data.items()} train( diff --git a/sheeprl/algos/dreamer_v2/evaluate.py b/sheeprl/algos/dreamer_v2/evaluate.py index 2b3b6bab..d03a143c 100644 --- a/sheeprl/algos/dreamer_v2/evaluate.py +++ b/sheeprl/algos/dreamer_v2/evaluate.py @@ -54,6 +54,7 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): state["actor"], ) player = PlayerDV2( + fabric, world_model.encoder.module, world_model.rssm.recurrent_model.module, world_model.rssm.representation_model.module, @@ -62,7 +63,6 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): cfg.env.num_envs, cfg.algo.world_model.stochastic_size, 
cfg.algo.world_model.recurrent_model.recurrent_state_size, - fabric.device, discrete_size=cfg.algo.world_model.discrete_size, ) From 18710b339bf96220a6100504ef0685a5594b8780 Mon Sep 17 00:00:00 2001 From: belerico Date: Fri, 22 Mar 2024 15:15:01 +0100 Subject: [PATCH 14/26] Fix Dreamer-V3 to handle fp16 --- sheeprl/algos/dreamer_v3/agent.py | 47 +++---- sheeprl/algos/dreamer_v3/dreamer_v3.py | 162 +++++++++++++------------ sheeprl/algos/dreamer_v3/evaluate.py | 2 +- 3 files changed, 109 insertions(+), 102 deletions(-) diff --git a/sheeprl/algos/dreamer_v3/agent.py b/sheeprl/algos/dreamer_v3/agent.py index fe92cd9a..7f809f73 100644 --- a/sheeprl/algos/dreamer_v3/agent.py +++ b/sheeprl/algos/dreamer_v3/agent.py @@ -10,10 +10,9 @@ import torch.nn.functional as F from lightning.fabric import Fabric from lightning.fabric.wrappers import _FabricModule -from torch import Tensor, device, nn +from torch import Tensor, nn from torch.distributions import Distribution, Independent, Normal, TanhTransform, TransformedDistribution from torch.distributions.utils import probs_to_logits -from torch.nn.modules import Module from sheeprl.algos.dreamer_v2.agent import WorldModel from sheeprl.algos.dreamer_v2.utils import compute_stochastic_state @@ -24,6 +23,7 @@ OneHotCategoricalValidateArgs, TruncatedNormal, ) +from sheeprl.utils.fabric import get_single_device_fabric from sheeprl.utils.model import LayerNormChannelLast, ModuleType, cnn_forward from sheeprl.utils.utils import symlog @@ -339,9 +339,9 @@ class RSSM(nn.Module): def __init__( self, - recurrent_model: nn.Module, - representation_model: nn.Module, - transition_model: nn.Module, + recurrent_model: nn.Module | _FabricModule, + representation_model: nn.Module | _FabricModule, + transition_model: nn.Module | _FabricModule, distribution_cfg: Dict[str, Any], discrete: int = 32, unimix: float = 0.01, @@ -485,9 +485,9 @@ class DecoupledRSSM(RSSM): def __init__( self, - recurrent_model: Module, - representation_model: Module, - transition_model: Module, + recurrent_model: nn.Module | _FabricModule, + representation_model: nn.Module | _FabricModule, + transition_model: nn.Module | _FabricModule, distribution_cfg: Dict[str, Any], discrete: int = 32, unimix: float = 0.01, @@ -554,53 +554,54 @@ class PlayerDV3(nn.Module): The model of the Dreamer_v3 player. Args: - encoder (_FabricModule): the encoder. - recurrent_model (_FabricModule): the recurrent model. - representation_model (_FabricModule): the representation model. + fabric (_FabricModule): the fabric module. + encoder (MultiEncoder): the encoder. + rssm (RSSM | DecoupledRSSM): the RSSM model. actor (_FabricModule): the actor. actions_dim (Sequence[int]): the dimension of the actions. num_envs (int): the number of environments. stochastic_size (int): the size of the stochastic state. recurrent_state_size (int): the size of the recurrent state. - device (torch.device): the device to work on. transition_model (_FabricModule): the transition model. discrete_size (int): the dimension of a single Categorical variable in the stochastic state (prior or posterior). Defaults to 32. actor_type (str, optional): which actor the player is using ('task' or 'exploration'). Default to None. + decoupled_rssm (bool, optional): whether to use the DecoupledRSSM model. 
""" def __init__( self, - encoder: _FabricModule, + fabric: Fabric, + encoder: MultiEncoder, rssm: RSSM | DecoupledRSSM, - actor: _FabricModule, + actor: Actor | MinedojoActor, actions_dim: Sequence[int], num_envs: int, stochastic_size: int, recurrent_state_size: int, - device: device = "cpu", discrete_size: int = 32, actor_type: str | None = None, decoupled_rssm: bool = False, ) -> None: super().__init__() - self.encoder = encoder + single_device_fabric = get_single_device_fabric(fabric) + self.encoder = single_device_fabric.setup_module(encoder) if decoupled_rssm: rssm_cls = DecoupledRSSM else: rssm_cls = RSSM self.rssm = rssm_cls( - recurrent_model=rssm.recurrent_model.module, - representation_model=rssm.representation_model.module, - transition_model=rssm.transition_model.module, + recurrent_model=single_device_fabric.setup_module(rssm.recurrent_model.module), + representation_model=single_device_fabric.setup_module(rssm.representation_model.module), + transition_model=single_device_fabric.setup_module(rssm.transition_model.module), distribution_cfg=actor.distribution_cfg, discrete=rssm.discrete, unimix=rssm.unimix, ) - self.actor = actor - self.device = device + self.actor = single_device_fabric.setup_module(actor) + self.device = single_device_fabric.device self.actions_dim = actions_dim self.stochastic_size = stochastic_size self.discrete_size = discrete_size @@ -1256,8 +1257,12 @@ def build_agent( world_model.continue_model = fabric.setup_module(world_model.continue_model) actor = fabric.setup_module(actor) critic = fabric.setup_module(critic) + + # Setup target critic with a SingleDeviceStrategy target_critic = copy.deepcopy(critic.module) if target_critic_state: target_critic.load_state_dict(target_critic_state) + single_device_fabric = get_single_device_fabric(fabric) + target_critic = single_device_fabric.setup_module(target_critic) return world_model, actor, critic, target_critic diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py index 49634131..b539a227 100644 --- a/sheeprl/algos/dreamer_v3/dreamer_v3.py +++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py @@ -466,6 +466,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): state["target_critic"] if cfg.checkpoint.resume_from else None, ) player = PlayerDV3( + fabric, world_model.encoder.module, world_model.rssm, actor.module, @@ -473,7 +474,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): cfg.env.num_envs, cfg.algo.world_model.stochastic_size, cfg.algo.world_model.recurrent_model.recurrent_state_size, - fabric.device, discrete_size=cfg.algo.world_model.discrete_size, decoupled_rssm=cfg.algo.decoupled_rssm, ) @@ -586,26 +586,26 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): for update in range(start_step, num_updates + 1): policy_step += cfg.env.num_envs * world_size - # Measure environment interaction time: this considers both the model forward - # to get the action given the observation and the time taken into the environment - with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): - # Sample an action given the observation received by the environment - if ( - update <= learning_starts - and cfg.checkpoint.resume_from is None - and "minedojo" not in cfg.env.wrapper._target_.lower() - ): - real_actions = actions = np.array(envs.action_space.sample()) - if not is_continuous: - actions = np.concatenate( - [ - F.one_hot(torch.as_tensor(act), act_dim).numpy() - for act, act_dim in zip(actions.reshape(len(actions_dim), -1), actions_dim) - ], - axis=-1, - ) - else: - with 
torch.no_grad(): + with torch.inference_mode(): + # Measure environment interaction time: this considers both the model forward + # to get the action given the observation and the time taken into the environment + with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): + # Sample an action given the observation received by the environment + if ( + update <= learning_starts + and cfg.checkpoint.resume_from is None + and "minedojo" not in cfg.env.wrapper._target_.lower() + ): + real_actions = actions = np.array(envs.action_space.sample()) + if not is_continuous: + actions = np.concatenate( + [ + F.one_hot(torch.as_tensor(act), act_dim).numpy() + for act, act_dim in zip(actions.reshape(len(actions_dim), -1), actions_dim) + ], + axis=-1, + ) + else: preprocessed_obs = {} for k, v in obs.items(): preprocessed_obs[k] = torch.as_tensor(v[np.newaxis], dtype=torch.float32, device=device) @@ -623,68 +623,70 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() ) - step_data["actions"] = actions.reshape((1, cfg.env.num_envs, -1)) - rb.add(step_data, validate_args=cfg.buffer.validate_args) + step_data["actions"] = actions.reshape((1, cfg.env.num_envs, -1)) + rb.add(step_data, validate_args=cfg.buffer.validate_args) - next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated).astype(np.uint8) + next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) + dones = np.logical_or(dones, truncated).astype(np.uint8) + + step_data["is_first"] = np.zeros_like(step_data["dones"]) + if "restart_on_exception" in infos: + for i, agent_roe in enumerate(infos["restart_on_exception"]): + if agent_roe and not dones[i]: + last_inserted_idx = (rb.buffer[i]._pos - 1) % rb.buffer[i].buffer_size + rb.buffer[i]["dones"][last_inserted_idx] = np.ones_like( + rb.buffer[i]["dones"][last_inserted_idx] + ) + rb.buffer[i]["is_first"][last_inserted_idx] = np.zeros_like( + rb.buffer[i]["is_first"][last_inserted_idx] + ) + step_data["is_first"][i] = np.ones_like(step_data["is_first"][i]) + + if cfg.metric.log_level > 0 and "final_info" in infos: + for i, agent_ep_info in enumerate(infos["final_info"]): + if agent_ep_info is not None: + ep_rew = agent_ep_info["episode"]["r"] + ep_len = agent_ep_info["episode"]["l"] + if aggregator and not aggregator.disabled: + aggregator.update("Rewards/rew_avg", ep_rew) + aggregator.update("Game/ep_len_avg", ep_len) + fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") + + # Save the real next observation + real_next_obs = copy.deepcopy(next_obs) + if "final_observation" in infos: + for idx, final_obs in enumerate(infos["final_observation"]): + if final_obs is not None: + for k, v in final_obs.items(): + real_next_obs[k][idx] = v - step_data["is_first"] = np.zeros_like(step_data["dones"]) - if "restart_on_exception" in infos: - for i, agent_roe in enumerate(infos["restart_on_exception"]): - if agent_roe and not dones[i]: - last_inserted_idx = (rb.buffer[i]._pos - 1) % rb.buffer[i].buffer_size - rb.buffer[i]["dones"][last_inserted_idx] = np.ones_like(rb.buffer[i]["dones"][last_inserted_idx]) - rb.buffer[i]["is_first"][last_inserted_idx] = np.zeros_like( - rb.buffer[i]["is_first"][last_inserted_idx] - ) - step_data["is_first"][i] = np.ones_like(step_data["is_first"][i]) - - if cfg.metric.log_level > 0 and "final_info" in infos: - for i, 
agent_ep_info in enumerate(infos["final_info"]): - if agent_ep_info is not None: - ep_rew = agent_ep_info["episode"]["r"] - ep_len = agent_ep_info["episode"]["l"] - if aggregator and not aggregator.disabled: - aggregator.update("Rewards/rew_avg", ep_rew) - aggregator.update("Game/ep_len_avg", ep_len) - fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") - - # Save the real next observation - real_next_obs = copy.deepcopy(next_obs) - if "final_observation" in infos: - for idx, final_obs in enumerate(infos["final_observation"]): - if final_obs is not None: - for k, v in final_obs.items(): - real_next_obs[k][idx] = v - - for k in obs_keys: - step_data[k] = next_obs[k][np.newaxis] - - # next_obs becomes the new obs - obs = next_obs - - rewards = rewards.reshape((1, cfg.env.num_envs, -1)) - step_data["dones"] = dones.reshape((1, cfg.env.num_envs, -1)) - step_data["rewards"] = clip_rewards_fn(rewards) - - dones_idxes = dones.nonzero()[0].tolist() - reset_envs = len(dones_idxes) - if reset_envs > 0: - reset_data = {} for k in obs_keys: - reset_data[k] = (real_next_obs[k][dones_idxes])[np.newaxis] - reset_data["dones"] = np.ones((1, reset_envs, 1)) - reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim))) - reset_data["rewards"] = step_data["rewards"][:, dones_idxes] - reset_data["is_first"] = np.zeros_like(reset_data["dones"]) - rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args) - - # Reset already inserted step data - step_data["rewards"][:, dones_idxes] = np.zeros_like(reset_data["rewards"]) - step_data["dones"][:, dones_idxes] = np.zeros_like(step_data["dones"][:, dones_idxes]) - step_data["is_first"][:, dones_idxes] = np.ones_like(step_data["is_first"][:, dones_idxes]) - player.init_states(dones_idxes) + step_data[k] = next_obs[k][np.newaxis] + + # next_obs becomes the new obs + obs = next_obs + + rewards = rewards.reshape((1, cfg.env.num_envs, -1)) + step_data["dones"] = dones.reshape((1, cfg.env.num_envs, -1)) + step_data["rewards"] = clip_rewards_fn(rewards) + + dones_idxes = dones.nonzero()[0].tolist() + reset_envs = len(dones_idxes) + if reset_envs > 0: + reset_data = {} + for k in obs_keys: + reset_data[k] = (real_next_obs[k][dones_idxes])[np.newaxis] + reset_data["dones"] = np.ones((1, reset_envs, 1)) + reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim))) + reset_data["rewards"] = step_data["rewards"][:, dones_idxes] + reset_data["is_first"] = np.zeros_like(reset_data["dones"]) + rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args) + + # Reset already inserted step data + step_data["rewards"][:, dones_idxes] = np.zeros_like(reset_data["rewards"]) + step_data["dones"][:, dones_idxes] = np.zeros_like(step_data["dones"][:, dones_idxes]) + step_data["is_first"][:, dones_idxes] = np.ones_like(step_data["is_first"][:, dones_idxes]) + player.init_states(dones_idxes) updates_before_training -= 1 diff --git a/sheeprl/algos/dreamer_v3/evaluate.py b/sheeprl/algos/dreamer_v3/evaluate.py index 7fa239fc..fc57a6e3 100644 --- a/sheeprl/algos/dreamer_v3/evaluate.py +++ b/sheeprl/algos/dreamer_v3/evaluate.py @@ -54,6 +54,7 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): state["actor"], ) player = PlayerDV3( + fabric, world_model.encoder.module, world_model.rssm, actor.module, @@ -61,7 +62,6 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): cfg.env.num_envs, cfg.algo.world_model.stochastic_size, 
cfg.algo.world_model.recurrent_model.recurrent_state_size, - fabric.device, discrete_size=cfg.algo.world_model.discrete_size, decoupled_rssm=cfg.algo.decoupled_rssm, ) From baba51fc6448039d57ecbf289bc312d5bc3e5978 Mon Sep 17 00:00:00 2001 From: belerico Date: Fri, 22 Mar 2024 15:31:41 +0100 Subject: [PATCH 15/26] Extract module inside the player --- sheeprl/algos/dreamer_v1/agent.py | 32 ++++++++++++++++---------- sheeprl/algos/dreamer_v1/dreamer_v1.py | 8 +++---- sheeprl/algos/dreamer_v1/evaluate.py | 8 +++---- sheeprl/algos/dreamer_v2/agent.py | 32 ++++++++++++++++---------- sheeprl/algos/dreamer_v2/dreamer_v2.py | 8 +++---- sheeprl/algos/dreamer_v2/evaluate.py | 8 +++---- sheeprl/algos/dreamer_v3/agent.py | 20 ++++++++++------ sheeprl/algos/dreamer_v3/dreamer_v3.py | 4 ++-- sheeprl/algos/dreamer_v3/evaluate.py | 4 ++-- 9 files changed, 73 insertions(+), 51 deletions(-) diff --git a/sheeprl/algos/dreamer_v1/agent.py b/sheeprl/algos/dreamer_v1/agent.py index d541a2e4..d9b69d54 100644 --- a/sheeprl/algos/dreamer_v1/agent.py +++ b/sheeprl/algos/dreamer_v1/agent.py @@ -223,10 +223,10 @@ class PlayerDV1(nn.Module): Args: fabric (Fabric): the fabric object. - encoder (nn.Module): the encoder. - recurrent_model (nn.Module): the recurrent model. - representation_model (nn.Module): the representation model. - actor (nn.Module): the actor. + encoder (nn.Module| _FabricModule): the encoder. + recurrent_model (nn.Module| _FabricModule): the recurrent model. + representation_model (nn.Module| _FabricModule): the representation model. + actor (nn.Module| _FabricModule): the actor. actions_dim (Sequence[int]): the dimension of each action. num_envs (int): the number of environments. stochastic_size (int): the size of the stochastic state. @@ -238,10 +238,10 @@ class PlayerDV1(nn.Module): def __init__( self, fabric: Fabric, - encoder: nn.Module, - recurrent_model: nn.Module, - representation_model: nn.Module, - actor: nn.Module, + encoder: nn.Module | _FabricModule, + recurrent_model: nn.Module | _FabricModule, + representation_model: nn.Module | _FabricModule, + actor: nn.Module | _FabricModule, actions_dim: Sequence[int], num_envs: int, stochastic_size: int, @@ -250,10 +250,18 @@ def __init__( ) -> None: super().__init__() single_device_fabric = get_single_device_fabric(fabric) - self.encoder = single_device_fabric.setup_module(encoder) - self.recurrent_model = single_device_fabric.setup_module(recurrent_model) - self.representation_model = single_device_fabric.setup_module(representation_model) - self.actor = single_device_fabric.setup_module(actor) + self.encoder = single_device_fabric.setup_module( + getattr(encoder, "module", encoder), + ) + self.recurrent_model = single_device_fabric.setup_module( + getattr(recurrent_model, "module", recurrent_model), + ) + self.representation_model = single_device_fabric.setup_module( + getattr(representation_model, "module", representation_model) + ) + self.actor = single_device_fabric.setup_module( + getattr(actor, "module", actor), + ) self.device = single_device_fabric.device self.actions_dim = actions_dim self.stochastic_size = stochastic_size diff --git a/sheeprl/algos/dreamer_v1/dreamer_v1.py b/sheeprl/algos/dreamer_v1/dreamer_v1.py index 4c582609..0326a142 100644 --- a/sheeprl/algos/dreamer_v1/dreamer_v1.py +++ b/sheeprl/algos/dreamer_v1/dreamer_v1.py @@ -482,10 +482,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ) player = PlayerDV1( fabric, - world_model.encoder.module, - world_model.rssm.recurrent_model.module, - 
world_model.rssm.representation_model.module, - actor.module, + world_model.encoder, + world_model.rssm.recurrent_model, + world_model.rssm.representation_model, + actor, actions_dim, cfg.env.num_envs, cfg.algo.world_model.stochastic_size, diff --git a/sheeprl/algos/dreamer_v1/evaluate.py b/sheeprl/algos/dreamer_v1/evaluate.py index e33d1f54..4481a501 100644 --- a/sheeprl/algos/dreamer_v1/evaluate.py +++ b/sheeprl/algos/dreamer_v1/evaluate.py @@ -55,10 +55,10 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): ) player = PlayerDV1( fabric, - world_model.encoder.module, - world_model.rssm.recurrent_model.module, - world_model.rssm.representation_model.module, - actor.module, + world_model.encoder, + world_model.rssm.recurrent_model, + world_model.rssm.representation_model, + actor, actions_dim, cfg.env.num_envs, cfg.algo.world_model.stochastic_size, diff --git a/sheeprl/algos/dreamer_v2/agent.py b/sheeprl/algos/dreamer_v2/agent.py index 7f80eb76..6f546f43 100644 --- a/sheeprl/algos/dreamer_v2/agent.py +++ b/sheeprl/algos/dreamer_v2/agent.py @@ -745,10 +745,10 @@ class PlayerDV2(nn.Module): Args: fabric: the fabric of the model. - encoder (nn.Module): the encoder. - recurrent_model (nn.Module): the recurrent model. - representation_model (nn.Module): the representation model. - actor (nn.Module): the actor. + encoder (nn.Module | _FabricModule): the encoder. + recurrent_model (nn.Module | _FabricModule): the recurrent model. + representation_model (nn.Module | _FabricModule): the representation model. + actor (nn.Module | _FabricModule): the actor. actions_dim (Sequence[int]): the dimension of the actions. num_envs (int): the number of environments. stochastic_size (int): the size of the stochastic state. @@ -763,10 +763,10 @@ class PlayerDV2(nn.Module): def __init__( self, fabric: Fabric, - encoder: nn.Module, - recurrent_model: nn.Module, - representation_model: nn.Module, - actor: nn.Module, + encoder: nn.Module | _FabricModule, + recurrent_model: nn.Module | _FabricModule, + representation_model: nn.Module | _FabricModule, + actor: nn.Module | _FabricModule, actions_dim: Sequence[int], num_envs: int, stochastic_size: int, @@ -776,10 +776,18 @@ def __init__( ) -> None: super().__init__() fabric_player = get_single_device_fabric(fabric) - self.encoder = fabric_player.setup_module(encoder) - self.recurrent_model = fabric_player.setup_module(recurrent_model) - self.representation_model = fabric_player.setup_module(representation_model) - self.actor = fabric_player.setup_module(actor) + self.encoder = fabric_player.setup_module( + getattr(encoder, "module", encoder), + ) + self.recurrent_model = fabric_player.setup_module( + getattr(recurrent_model, "module", recurrent_model), + ) + self.representation_model = fabric_player.setup_module( + getattr(representation_model, "module", representation_model), + ) + self.actor = fabric_player.setup_module( + getattr(actor, "module", actor), + ) self.device = fabric_player.device self.actions_dim = actions_dim self.stochastic_size = stochastic_size diff --git a/sheeprl/algos/dreamer_v2/dreamer_v2.py b/sheeprl/algos/dreamer_v2/dreamer_v2.py index 5c1356ac..613ce927 100644 --- a/sheeprl/algos/dreamer_v2/dreamer_v2.py +++ b/sheeprl/algos/dreamer_v2/dreamer_v2.py @@ -499,10 +499,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ) player = PlayerDV2( fabric, - world_model.encoder.module, - world_model.rssm.recurrent_model.module, - world_model.rssm.representation_model.module, - actor.module, + world_model.encoder, + 
world_model.rssm.recurrent_model, + world_model.rssm.representation_model, + actor, actions_dim, cfg.env.num_envs, cfg.algo.world_model.stochastic_size, diff --git a/sheeprl/algos/dreamer_v2/evaluate.py b/sheeprl/algos/dreamer_v2/evaluate.py index d03a143c..29515b7d 100644 --- a/sheeprl/algos/dreamer_v2/evaluate.py +++ b/sheeprl/algos/dreamer_v2/evaluate.py @@ -55,10 +55,10 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): ) player = PlayerDV2( fabric, - world_model.encoder.module, - world_model.rssm.recurrent_model.module, - world_model.rssm.representation_model.module, - actor.module, + world_model.encoder, + world_model.rssm.recurrent_model, + world_model.rssm.representation_model, + actor, actions_dim, cfg.env.num_envs, cfg.algo.world_model.stochastic_size, diff --git a/sheeprl/algos/dreamer_v3/agent.py b/sheeprl/algos/dreamer_v3/agent.py index 7f809f73..f29546dc 100644 --- a/sheeprl/algos/dreamer_v3/agent.py +++ b/sheeprl/algos/dreamer_v3/agent.py @@ -574,9 +574,9 @@ class PlayerDV3(nn.Module): def __init__( self, fabric: Fabric, - encoder: MultiEncoder, + encoder: MultiEncoder | _FabricModule, rssm: RSSM | DecoupledRSSM, - actor: Actor | MinedojoActor, + actor: Actor | MinedojoActor | _FabricModule, actions_dim: Sequence[int], num_envs: int, stochastic_size: int, @@ -587,20 +587,26 @@ def __init__( ) -> None: super().__init__() single_device_fabric = get_single_device_fabric(fabric) - self.encoder = single_device_fabric.setup_module(encoder) + self.encoder = single_device_fabric.setup_module(getattr(encoder, "module", encoder)) if decoupled_rssm: rssm_cls = DecoupledRSSM else: rssm_cls = RSSM self.rssm = rssm_cls( - recurrent_model=single_device_fabric.setup_module(rssm.recurrent_model.module), - representation_model=single_device_fabric.setup_module(rssm.representation_model.module), - transition_model=single_device_fabric.setup_module(rssm.transition_model.module), + recurrent_model=single_device_fabric.setup_module( + getattr(rssm.recurrent_model, "module", rssm.recurrent_model) + ), + representation_model=single_device_fabric.setup_module( + getattr(rssm.representation_model, "module", rssm.representation_model) + ), + transition_model=single_device_fabric.setup_module( + getattr(rssm.transition_model, "module", rssm.transition_model) + ), distribution_cfg=actor.distribution_cfg, discrete=rssm.discrete, unimix=rssm.unimix, ) - self.actor = single_device_fabric.setup_module(actor) + self.actor = single_device_fabric.setup_module(getattr(actor, "module", actor)) self.device = single_device_fabric.device self.actions_dim = actions_dim self.stochastic_size = stochastic_size diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py index b539a227..7984e85c 100644 --- a/sheeprl/algos/dreamer_v3/dreamer_v3.py +++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py @@ -467,9 +467,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ) player = PlayerDV3( fabric, - world_model.encoder.module, + world_model.encoder, world_model.rssm, - actor.module, + actor, actions_dim, cfg.env.num_envs, cfg.algo.world_model.stochastic_size, diff --git a/sheeprl/algos/dreamer_v3/evaluate.py b/sheeprl/algos/dreamer_v3/evaluate.py index fc57a6e3..12f885c6 100644 --- a/sheeprl/algos/dreamer_v3/evaluate.py +++ b/sheeprl/algos/dreamer_v3/evaluate.py @@ -55,9 +55,9 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): ) player = PlayerDV3( fabric, - world_model.encoder.module, + world_model.encoder, world_model.rssm, - actor.module, + 
actor, actions_dim, cfg.env.num_envs, cfg.algo.world_model.stochastic_size, From 747d9fc8283f93c3c7ca60077e55a9d028171650 Mon Sep 17 00:00:00 2001 From: belerico Date: Fri, 22 Mar 2024 15:33:14 +0100 Subject: [PATCH 16/26] Fix P2E-DV1 to handle fp16 --- sheeprl/algos/p2e_dv1/evaluate.py | 10 +++++----- sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py | 10 +++++----- sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py | 10 +++++----- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/sheeprl/algos/p2e_dv1/evaluate.py b/sheeprl/algos/p2e_dv1/evaluate.py index d170e436..2d2d9bf9 100644 --- a/sheeprl/algos/p2e_dv1/evaluate.py +++ b/sheeprl/algos/p2e_dv1/evaluate.py @@ -55,15 +55,15 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): state["actor_task"], ) player = PlayerDV1( - world_model.encoder.module, - world_model.rssm.recurrent_model.module, - world_model.rssm.representation_model.module, - actor_task.module, + fabric, + world_model.encoder, + world_model.rssm.recurrent_model, + world_model.rssm.representation_model, + actor_task, actions_dim, cfg.env.num_envs, cfg.algo.world_model.stochastic_size, cfg.algo.world_model.recurrent_model.recurrent_state_size, - fabric.device, ) test(player, fabric, cfg, log_dir, sample_actions=False) diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py index 3af2c83c..b50c7cd5 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py @@ -475,15 +475,15 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ) player = PlayerDV1( - world_model.encoder.module, - world_model.rssm.recurrent_model.module, - world_model.rssm.representation_model.module, - actor_exploration.module, + fabric, + world_model.encoder, + world_model.rssm.recurrent_model, + world_model.rssm.representation_model, + actor_exploration, actions_dim, cfg.env.num_envs, cfg.algo.world_model.stochastic_size, cfg.algo.world_model.recurrent_model.recurrent_state_size, - fabric.device, actor_type=cfg.algo.player.actor_type, ) diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py index a9596ddf..06b6ade0 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py @@ -143,15 +143,15 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): ) player = PlayerDV1( - world_model.encoder.module, - world_model.rssm.recurrent_model.module, - world_model.rssm.representation_model.module, - actor_exploration.module if cfg.algo.player.actor_type == "exploration" else actor_task.module, + fabric, + world_model.encoder, + world_model.rssm.recurrent_model, + world_model.rssm.representation_model, + actor_exploration if cfg.algo.player.actor_type == "exploration" else actor_task, actions_dim, cfg.env.num_envs, cfg.algo.world_model.stochastic_size, cfg.algo.world_model.recurrent_model.recurrent_state_size, - fabric.device, actor_type=cfg.algo.player.actor_type, ) From a3e44714046d5aea84309bde3b38a3914754ed2f Mon Sep 17 00:00:00 2001 From: belerico Date: Fri, 22 Mar 2024 15:33:30 +0100 Subject: [PATCH 17/26] Fix p2e-dv2 to handle fp16 --- sheeprl/algos/p2e_dv2/evaluate.py | 10 +++++----- sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py | 10 +++++----- sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py | 10 +++++----- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/sheeprl/algos/p2e_dv2/evaluate.py b/sheeprl/algos/p2e_dv2/evaluate.py index 66c1173e..26ef629d 100644 --- 
a/sheeprl/algos/p2e_dv2/evaluate.py +++ b/sheeprl/algos/p2e_dv2/evaluate.py @@ -56,15 +56,15 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): state["actor_task"], ) player = PlayerDV2( - world_model.encoder.module, - world_model.rssm.recurrent_model.module, - world_model.rssm.representation_model.module, - actor_task.module, + fabric, + world_model.encoder, + world_model.rssm.recurrent_model, + world_model.rssm.representation_model, + actor_task, actions_dim, cfg.env.num_envs, cfg.algo.world_model.stochastic_size, cfg.algo.world_model.recurrent_model.recurrent_state_size, - fabric.device, discrete_size=cfg.algo.world_model.discrete_size, ) diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py index 3d6c073c..94f3dc7e 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py @@ -607,15 +607,15 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ) player = PlayerDV2( - world_model.encoder.module, - world_model.rssm.recurrent_model.module, - world_model.rssm.representation_model.module, - actor_exploration.module, + fabric, + world_model.encoder, + world_model.rssm.recurrent_model, + world_model.rssm.representation_model, + actor_exploration, actions_dim, cfg.env.num_envs, cfg.algo.world_model.stochastic_size, cfg.algo.world_model.recurrent_model.recurrent_state_size, - fabric.device, discrete_size=cfg.algo.world_model.discrete_size, actor_type=cfg.algo.player.actor_type, ) diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py index c7909c47..143b557e 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py @@ -147,15 +147,15 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): ) player = PlayerDV2( - world_model.encoder.module, - world_model.rssm.recurrent_model.module, - world_model.rssm.representation_model.module, - actor_exploration.module if cfg.algo.player.actor_type == "exploration" else actor_task.module, + fabric, + world_model.encoder, + world_model.rssm.recurrent_model, + world_model.rssm.representation_model, + actor_exploration if cfg.algo.player.actor_type == "exploration" else actor_task, actions_dim, cfg.env.num_envs, cfg.algo.world_model.stochastic_size, cfg.algo.world_model.recurrent_model.recurrent_state_size, - fabric.device, discrete_size=cfg.algo.world_model.discrete_size, actor_type=cfg.algo.player.actor_type, ) From e64bac1bd97388e8fc9c7428fc40c922b092f7d3 Mon Sep 17 00:00:00 2001 From: belerico Date: Fri, 22 Mar 2024 15:34:42 +0100 Subject: [PATCH 18/26] Fix p2e-dv3 to handle fp16 --- sheeprl/algos/p2e_dv3/evaluate.py | 6 +++--- sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py | 6 +++--- sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py | 6 +++--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/sheeprl/algos/p2e_dv3/evaluate.py b/sheeprl/algos/p2e_dv3/evaluate.py index 7aadb93f..9dbf8f7a 100644 --- a/sheeprl/algos/p2e_dv3/evaluate.py +++ b/sheeprl/algos/p2e_dv3/evaluate.py @@ -56,15 +56,15 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): state["actor_task"], ) player = PlayerDV3( - world_model.encoder.module, + fabric, + world_model.encoder, world_model.rssm, - actor.module, + actor, actions_dim, cfg.algo.player.expl_amount, cfg.env.num_envs, cfg.algo.world_model.stochastic_size, cfg.algo.world_model.recurrent_model.recurrent_state_size, - fabric.device, 
discrete_size=cfg.algo.world_model.discrete_size, actor_type="task", ) diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py index d9751131..918f42aa 100644 --- a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py +++ b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py @@ -640,14 +640,14 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ) player = PlayerDV3( - world_model.encoder.module, + fabric, + world_model.encoder, world_model.rssm, - actor_exploration.module, + actor_exploration, actions_dim, cfg.env.num_envs, cfg.algo.world_model.stochastic_size, cfg.algo.world_model.recurrent_model.recurrent_state_size, - fabric.device, discrete_size=cfg.algo.world_model.discrete_size, actor_type=cfg.algo.player.actor_type, ) diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py index f44405e8..bd20456c 100644 --- a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py +++ b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py @@ -150,14 +150,14 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): # initialize the ensembles with different seeds to be sure they have different weights player = PlayerDV3( - world_model.encoder.module, + fabric, + world_model.encoder, world_model.rssm, - actor_exploration.module if cfg.algo.player.actor_type == "exploration" else actor_task.module, + actor_exploration if cfg.algo.player.actor_type == "exploration" else actor_task, actions_dim, cfg.env.num_envs, cfg.algo.world_model.stochastic_size, cfg.algo.world_model.recurrent_model.recurrent_state_size, - fabric.device, discrete_size=cfg.algo.world_model.discrete_size, actor_type=cfg.algo.player.actor_type, ) From 61d288171945ea48e9685b277bd64fba08c29b38 Mon Sep 17 00:00:00 2001 From: belerico Date: Fri, 22 Mar 2024 15:39:58 +0100 Subject: [PATCH 19/26] Fix A2C to handle fp16 --- sheeprl/algos/a2c/a2c.py | 92 +++++++++++++++++------------------ sheeprl/algos/a2c/agent.py | 16 +++++- sheeprl/algos/a2c/evaluate.py | 2 +- 3 files changed, 61 insertions(+), 49 deletions(-) diff --git a/sheeprl/algos/a2c/a2c.py b/sheeprl/algos/a2c/a2c.py index b9dc0424..13d6109f 100644 --- a/sheeprl/algos/a2c/a2c.py +++ b/sheeprl/algos/a2c/a2c.py @@ -163,7 +163,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Given that the environment has been created with the `make_env` method, the agent # forward method must accept as input a dictionary like {"obs1_name": obs1, "obs2_name": obs2, ...}. # The agent should be able to process both image and vector-like observations. 
- agent = build_agent( + agent, player = build_agent( fabric, actions_dim, is_continuous, @@ -224,68 +224,68 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): step_data[k] = next_obs[k][np.newaxis] for update in range(1, num_updates + 1): - for _ in range(0, cfg.algo.rollout_steps): - policy_step += cfg.env.num_envs * world_size + with torch.inference_mode(): + for _ in range(0, cfg.algo.rollout_steps): + policy_step += cfg.env.num_envs * world_size - # Measure environment interaction time: this considers both the model forward - # to get the action given the observation and the time taken into the environment - with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): - with torch.no_grad(): + # Measure environment interaction time: this considers both the model forward + # to get the action given the observation and the time taken into the environment + with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): # Sample an action given the observation received by the environment # This calls the `forward` method of the PyTorch module, escaping from Fabric # because we don't want this to be a synchronization point torch_obs = {k: torch.as_tensor(next_obs[k], dtype=torch.float32, device=device) for k in obs_keys} - actions, _, values = agent.module(torch_obs) + actions, _, values = player(torch_obs) if is_continuous: real_actions = torch.cat(actions, -1).cpu().numpy() else: real_actions = torch.cat([act.argmax(dim=-1) for act in actions], axis=-1).cpu().numpy() actions = torch.cat(actions, -1).cpu().numpy() - # Single environment step - obs, rewards, done, truncated, info = envs.step(real_actions.reshape(envs.action_space.shape)) - - dones = np.logical_or(done, truncated) - dones = dones.reshape(cfg.env.num_envs, -1) - rewards = rewards.reshape(cfg.env.num_envs, -1) - - # Update the step data - step_data["dones"] = dones[np.newaxis] - step_data["values"] = values.cpu().numpy()[np.newaxis] - step_data["actions"] = actions[np.newaxis] - step_data["rewards"] = rewards[np.newaxis] - if cfg.buffer.memmap: - step_data["returns"] = np.zeros_like(rewards, shape=(1, *rewards.shape)) - step_data["advantages"] = np.zeros_like(rewards, shape=(1, *rewards.shape)) - - # Append data to buffer - rb.add(step_data, validate_args=cfg.buffer.validate_args) - - # Update the observation and dones - next_obs = {} - for k in obs_keys: - _obs = obs[k] - step_data[k] = _obs[np.newaxis] - next_obs[k] = _obs - - if cfg.metric.log_level > 0 and "final_info" in info: - for i, agent_ep_info in enumerate(info["final_info"]): - if agent_ep_info is not None: - ep_rew = agent_ep_info["episode"]["r"] - ep_len = agent_ep_info["episode"]["l"] - if aggregator and "Rewards/rew_avg" in aggregator: - aggregator.update("Rewards/rew_avg", ep_rew) - if aggregator and "Game/ep_len_avg" in aggregator: - aggregator.update("Game/ep_len_avg", ep_len) - fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") + # Single environment step + obs, rewards, done, truncated, info = envs.step(real_actions.reshape(envs.action_space.shape)) + + dones = np.logical_or(done, truncated) + dones = dones.reshape(cfg.env.num_envs, -1) + rewards = rewards.reshape(cfg.env.num_envs, -1) + + # Update the step data + step_data["dones"] = dones[np.newaxis] + step_data["values"] = values.cpu().numpy()[np.newaxis] + step_data["actions"] = actions[np.newaxis] + step_data["rewards"] = rewards[np.newaxis] + if cfg.buffer.memmap: + step_data["returns"] = np.zeros_like(rewards, shape=(1, *rewards.shape)) + 
step_data["advantages"] = np.zeros_like(rewards, shape=(1, *rewards.shape)) + + # Append data to buffer + rb.add(step_data, validate_args=cfg.buffer.validate_args) + + # Update the observation and dones + next_obs = {} + for k in obs_keys: + _obs = obs[k] + step_data[k] = _obs[np.newaxis] + next_obs[k] = _obs + + if cfg.metric.log_level > 0 and "final_info" in info: + for i, agent_ep_info in enumerate(info["final_info"]): + if agent_ep_info is not None: + ep_rew = agent_ep_info["episode"]["r"] + ep_len = agent_ep_info["episode"]["l"] + if aggregator and "Rewards/rew_avg" in aggregator: + aggregator.update("Rewards/rew_avg", ep_rew) + if aggregator and "Game/ep_len_avg" in aggregator: + aggregator.update("Game/ep_len_avg", ep_len) + fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") # Transform the data into PyTorch Tensors local_data = rb.to_tensor(dtype=None, device=device, from_numpy=cfg.buffer.from_numpy) # Estimate returns with GAE (https://arxiv.org/abs/1506.02438) - with torch.no_grad(): + with torch.inference_mode(): torch_obs = {k: torch.as_tensor(next_obs[k], dtype=torch.float32, device=device) for k in obs_keys} - next_values = agent.module.get_value(torch_obs) + _, _, next_values = player(torch_obs) returns, advantages = gae( local_data["rewards"].to(torch.float64), local_data["values"], diff --git a/sheeprl/algos/a2c/agent.py b/sheeprl/algos/a2c/agent.py index 100b7d41..73a83f06 100644 --- a/sheeprl/algos/a2c/agent.py +++ b/sheeprl/algos/a2c/agent.py @@ -1,3 +1,4 @@ +import copy from typing import Any, Dict, List, Optional, Sequence, Tuple import gymnasium @@ -10,6 +11,7 @@ from sheeprl.models.models import MLP from sheeprl.utils.distribution import OneHotCategoricalValidateArgs +from sheeprl.utils.fabric import get_single_device_fabric class MLPEncoder(nn.Module): @@ -181,7 +183,7 @@ def build_agent( cfg: Dict[str, Any], obs_space: gymnasium.spaces.Dict, agent_state: Optional[Dict[str, Tensor]] = None, -) -> _FabricModule: +) -> Tuple[_FabricModule, _FabricModule]: agent = A2CAgent( actions_dim=actions_dim, obs_space=obs_space, @@ -194,6 +196,16 @@ def build_agent( ) if agent_state: agent.load_state_dict(agent_state) + player = copy.deepcopy(agent) + + # Setup training agent agent = fabric.setup_module(agent) - return agent + # Setup player agent + fabric_player = get_single_device_fabric(fabric) + player = fabric_player.setup_module(player) + + # Tie weights between the agent and the player + for agent_p, player_p in zip(agent.parameters(), player.parameters()): + player_p.data = agent_p.data + return agent, player diff --git a/sheeprl/algos/a2c/evaluate.py b/sheeprl/algos/a2c/evaluate.py index 612fdb98..c3ee6a3b 100644 --- a/sheeprl/algos/a2c/evaluate.py +++ b/sheeprl/algos/a2c/evaluate.py @@ -54,5 +54,5 @@ def evaluate_a2c(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): else (env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n]) ) # Create the actor and critic models - agent = build_agent(fabric, actions_dim, is_continuous, cfg, observation_space, state["agent"]) + agent, _ = build_agent(fabric, actions_dim, is_continuous, cfg, observation_space, state["agent"]) test(agent, fabric, cfg, log_dir) From 7dc8cc4444e6ee70e39456fa1c85a136ebfe8009 Mon Sep 17 00:00:00 2001 From: belerico Date: Wed, 27 Mar 2024 11:13:43 +0100 Subject: [PATCH 20/26] Add `greedy` flag to actor forward method --- sheeprl/algos/a2c/a2c.py | 2 +- sheeprl/algos/a2c/agent.py | 35 +++++++------------------- sheeprl/algos/a2c/utils.py | 8 
+++--- sheeprl/algos/ppo/agent.py | 13 ---------- sheeprl/algos/ppo/utils.py | 2 +- sheeprl/algos/ppo_recurrent/agent.py | 37 ++++------------------------ sheeprl/algos/ppo_recurrent/utils.py | 2 +- sheeprl/algos/sac/agent.py | 25 ++++++------------- sheeprl/algos/sac/utils.py | 2 +- sheeprl/algos/sac_ae/agent.py | 32 +++++++++--------------- sheeprl/algos/sac_ae/utils.py | 2 +- tests/test_algos/test_algos.py | 9 ++++--- 12 files changed, 47 insertions(+), 122 deletions(-) diff --git a/sheeprl/algos/a2c/a2c.py b/sheeprl/algos/a2c/a2c.py index 13d6109f..bca58a05 100644 --- a/sheeprl/algos/a2c/a2c.py +++ b/sheeprl/algos/a2c/a2c.py @@ -351,7 +351,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): envs.close() if fabric.is_global_zero: - test(agent.module, fabric, cfg, log_dir) + test(player, fabric, cfg, log_dir) if not cfg.model_manager.disabled and fabric.is_global_zero: from sheeprl.algos.ppo.utils import log_models diff --git a/sheeprl/algos/a2c/agent.py b/sheeprl/algos/a2c/agent.py index 73a83f06..91d6412e 100644 --- a/sheeprl/algos/a2c/agent.py +++ b/sheeprl/algos/a2c/agent.py @@ -9,6 +9,7 @@ from torch import Tensor from torch.distributions import Distribution, Independent, Normal +from sheeprl.algos.ppo.agent import PPOActor from sheeprl.models.models import MLP from sheeprl.utils.distribution import OneHotCategoricalValidateArgs from sheeprl.utils.fabric import get_single_device_fabric @@ -94,7 +95,7 @@ def __init__( ) # Actor - self.actor_backbone = MLP( + actor_backbone = MLP( input_dims=features_dim, output_dim=None, hidden_sizes=[actor_cfg.dense_units] * actor_cfg.mlp_layers, @@ -109,19 +110,17 @@ def __init__( ) if is_continuous: # Output is a tuple of two elements: mean and log_std, one for every action - self.actor_heads = nn.ModuleList([nn.Linear(actor_cfg.dense_units, sum(actions_dim) * 2)]) + actor_heads = nn.ModuleList([nn.Linear(actor_cfg.dense_units, sum(actions_dim) * 2)]) else: # Output is a tuple of one element: logits, one for every action - self.actor_heads = nn.ModuleList( - [nn.Linear(actor_cfg.dense_units, action_dim) for action_dim in actions_dim] - ) + actor_heads = nn.ModuleList([nn.Linear(actor_cfg.dense_units, action_dim) for action_dim in actions_dim]) + self.actor = PPOActor(actor_backbone, actor_heads, is_continuous=is_continuous) def forward( - self, obs: Dict[str, Tensor], actions: Optional[List[Tensor]] = None + self, obs: Dict[str, Tensor], actions: Optional[List[Tensor]] = None, greedy: bool = False ) -> Tuple[Sequence[Tensor], Tensor, Tensor]: feat = self.feature_extractor(obs) - out: Tensor = self.actor_backbone(feat) - pre_dist: List[Tensor] = [head(out) for head in self.actor_heads] + pre_dist: List[Tensor] = self.actor(feat) values = self.critic(feat) if self.is_continuous: mean, log_std = torch.chunk(pre_dist[0], chunks=2, dim=-1) @@ -132,7 +131,7 @@ def forward( validate_args=self.distribution_cfg.validate_args, ) if actions is None: - actions = normal.sample() + actions = normal.mode if greedy else normal.sample() else: # always composed by a tuple of one element containing all the # continuous actions @@ -151,7 +150,7 @@ def forward( OneHotCategoricalValidateArgs(logits=logits, validate_args=self.distribution_cfg.validate_args) ) if should_append: - actions.append(actions_dist[-1].sample()) + actions.append(actions_dist[-1].mode if greedy else actions_dist[-1].sample()) actions_logprobs.append(actions_dist[-1].log_prob(actions[i])) return tuple(actions), torch.stack(actions_logprobs, dim=-1).sum(dim=-1, keepdim=True), values @@ -159,22 
+158,6 @@ def get_value(self, obs: Dict[str, Tensor]) -> Tensor: feat = self.feature_extractor(obs) return self.critic(feat) - def get_greedy_actions(self, obs: Dict[str, Tensor]) -> Sequence[Tensor]: - feat = self.feature_extractor(obs) - out = self.actor_backbone(feat) - pre_dist: List[Tensor] = [head(out) for head in self.actor_heads] - if self.is_continuous: - # Just take the mean of the distribution - return [torch.chunk(pre_dist[0], 2, -1)[0]] - else: - # Take the mode of the distribution - return tuple( - [ - OneHotCategoricalValidateArgs(logits=logits, validate_args=self.distribution_cfg.validate_args).mode - for logits in pre_dist - ] - ) - def build_agent( fabric: Fabric, diff --git a/sheeprl/algos/a2c/utils.py b/sheeprl/algos/a2c/utils.py index e38a6414..430cdc77 100644 --- a/sheeprl/algos/a2c/utils.py +++ b/sheeprl/algos/a2c/utils.py @@ -2,6 +2,7 @@ import torch from lightning import Fabric +from lightning.fabric.wrappers import _FabricModule from sheeprl.algos.a2c.agent import A2CAgent from sheeprl.utils.env import make_env @@ -10,7 +11,7 @@ @torch.no_grad() -def test(agent: A2CAgent, fabric: Fabric, cfg: Dict[str, Any], log_dir: str): +def test(agent: A2CAgent | _FabricModule, fabric: Fabric, cfg: Dict[str, Any], log_dir: str): env = make_env(cfg, None, 0, log_dir, "test", vector_env_idx=0)() agent.eval() done = False @@ -25,10 +26,11 @@ def test(agent: A2CAgent, fabric: Fabric, cfg: Dict[str, Any], log_dir: str): while not done: # Act greedly through the environment + actions, _, _ = agent(obs, greedy=True) if agent.is_continuous: - actions = torch.cat(agent.get_greedy_actions(obs), dim=-1) + actions = torch.cat(actions, dim=-1) else: - actions = torch.cat([act.argmax(dim=-1) for act in agent.get_greedy_actions(obs)], dim=-1) + actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1) # Single environment step o, reward, done, truncated, _ = env.step(actions.cpu().numpy().reshape(env.action_space.shape)) diff --git a/sheeprl/algos/ppo/agent.py b/sheeprl/algos/ppo/agent.py index dac3e758..82d3bfff 100644 --- a/sheeprl/algos/ppo/agent.py +++ b/sheeprl/algos/ppo/agent.py @@ -206,19 +206,6 @@ def get_value(self, obs: Dict[str, Tensor]) -> Tensor: feat = self.feature_extractor(obs) return self.critic(feat) - def get_greedy_actions(self, obs: Dict[str, Tensor]) -> Sequence[Tensor]: - feat = self.feature_extractor(obs) - actor_out: List[Tensor] = self.actor(feat) - if self.actor.is_continuous: - return [torch.chunk(actor_out[0], 2, -1)[0]] - else: - return tuple( - [ - OneHotCategoricalValidateArgs(logits=logits, validate_args=self.distribution_cfg.validate_args).mode - for logits in actor_out - ] - ) - def build_agent( fabric: Fabric, diff --git a/sheeprl/algos/ppo/utils.py b/sheeprl/algos/ppo/utils.py index 06ac2b8d..aee52933 100644 --- a/sheeprl/algos/ppo/utils.py +++ b/sheeprl/algos/ppo/utils.py @@ -41,7 +41,7 @@ def test(agent: PPOAgent | _FabricModule, fabric: Fabric, cfg: Dict[str, Any], l while not done: # Act greedly through the environment - actions = agent.get_greedy_actions(obs) + actions, _, _, _ = agent(obs, greedy=True) if agent.is_continuous: actions = torch.cat(actions, dim=-1) else: diff --git a/sheeprl/algos/ppo_recurrent/agent.py b/sheeprl/algos/ppo_recurrent/agent.py index 84a64c5d..814d3a0f 100644 --- a/sheeprl/algos/ppo_recurrent/agent.py +++ b/sheeprl/algos/ppo_recurrent/agent.py @@ -191,32 +191,8 @@ def reset_hidden_states(self) -> Tuple[Tensor, Tensor]: ) return states - def get_greedy_actions( - self, - obs: Dict[str, Tensor], - prev_states: 
Tuple[Tensor, Tensor], - prev_actions: Tensor, - mask: Optional[Tensor] = None, - ) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, Tensor]]: - embedded_obs = self.feature_extractor(obs) - out, states = self.rnn(torch.cat((embedded_obs, prev_actions), dim=-1), prev_states, mask) - pre_dist = self.get_pre_dist(out) - actions = [] - if self.is_continuous: - dist = Independent( - Normal(*pre_dist, validate_args=self.distribution_cfg.validate_args), - 1, - validate_args=self.distribution_cfg.validate_args, - ) - actions.append(dist.mode) - else: - for logits in pre_dist: - dist = OneHotCategoricalValidateArgs(logits=logits, validate_args=self.distribution_cfg.validate_args) - actions.append(dist.mode) - return tuple(actions), states - - def get_sampled_actions( - self, pre_dist: Tuple[Tensor, ...], actions: Optional[List[Tensor]] = None + def get_actions( + self, pre_dist: Tuple[Tensor, ...], actions: Optional[List[Tensor]] = None, greedy: bool = False ) -> Tuple[Tuple[Tensor, ...], Tensor, Tensor]: logprobs = [] entropies = [] @@ -228,19 +204,16 @@ def get_sampled_actions( validate_args=self.distribution_cfg.validate_args, ) if actions is None: - actions = dist.sample() + sampled_actions.append(dist.mode if greedy else dist.sample()) else: - # always composed by a tuple of one element containing all the - # continuous actions - actions = actions[0] - sampled_actions.append(actions) + sampled_actions.append(actions[0]) entropies.append(dist.entropy()) logprobs.append(dist.log_prob(actions)) else: for i, logits in enumerate(pre_dist): dist = OneHotCategoricalValidateArgs(logits=logits, validate_args=self.distribution_cfg.validate_args) if actions is None: - sampled_actions.append(dist.sample()) + sampled_actions.append(dist.mode if greedy else dist.sample()) else: sampled_actions.append(actions[i]) entropies.append(dist.entropy()) diff --git a/sheeprl/algos/ppo_recurrent/utils.py b/sheeprl/algos/ppo_recurrent/utils.py index 34125f76..735b6b0b 100644 --- a/sheeprl/algos/ppo_recurrent/utils.py +++ b/sheeprl/algos/ppo_recurrent/utils.py @@ -47,7 +47,7 @@ def test(agent: "RecurrentPPOAgent", fabric: Fabric, cfg: Dict[str, Any], log_di actions = torch.zeros(1, 1, sum(agent.actions_dim), device=fabric.device) while not done: # Act greedly through the environment - actions, state = agent.get_greedy_actions(next_obs, state, actions) + actions, _, _, _, state = agent(next_obs, state, actions, greedy=True) if agent.is_continuous: real_actions = torch.cat(actions, -1) actions = torch.cat(actions, dim=-1).view(1, 1, -1) diff --git a/sheeprl/algos/sac/agent.py b/sheeprl/algos/sac/agent.py index 6d877365..bb2cecbe 100644 --- a/sheeprl/algos/sac/agent.py +++ b/sheeprl/algos/sac/agent.py @@ -88,7 +88,7 @@ def __init__( self.register_buffer("action_scale", torch.tensor((action_high - action_low) / 2.0, dtype=torch.float32)) self.register_buffer("action_bias", torch.tensor((action_high + action_low) / 2.0, dtype=torch.float32)) - def forward(self, obs: Tensor) -> Tuple[Tensor, Tensor]: + def forward(self, obs: Tensor, greedy: bool = False) -> Union[Tensor, Tuple[Tensor, Tensor]]: """Given an observation, it returns a tanh-squashed sampled action (correctly rescaled to the environment action bounds) and its log-prob (as defined in Eq. 
26 of https://arxiv.org/abs/1812.05905) @@ -102,9 +102,12 @@ def forward(self, obs: Tensor) -> Tuple[Tensor, Tensor]: """ x = self.model(obs) mean = self.fc_mean(x) - log_std = self.fc_logstd(x) - std = torch.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX).exp() - return self.get_actions_and_log_probs(mean, std) + if greedy: + return torch.tanh(mean) * self.action_scale + self.action_bias + else: + log_std = self.fc_logstd(x) + std = torch.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX).exp() + return self.get_actions_and_log_probs(mean, std) def get_actions_and_log_probs(self, mean: Tensor, std: Tensor): """Given the mean and the std of a Normal distribution, it returns a tanh-squashed @@ -140,20 +143,6 @@ def get_actions_and_log_probs(self, mean: Tensor, std: Tensor): return action, log_prob - def get_greedy_actions(self, obs: Tensor) -> Tensor: - """Get the action given the input observation greedily - - Args: - obs (Tensor): input observation - - Returns: - action - """ - x = self.model(obs) - mean = self.fc_mean(x) - mean = torch.tanh(mean) * self.action_scale + self.action_bias - return mean - class SACAgent(nn.Module): def __init__( diff --git a/sheeprl/algos/sac/utils.py b/sheeprl/algos/sac/utils.py index 5d16e45f..3fe14d31 100644 --- a/sheeprl/algos/sac/utils.py +++ b/sheeprl/algos/sac/utils.py @@ -41,7 +41,7 @@ def test(actor: SACActor, fabric: Fabric, cfg: Dict[str, Any], log_dir: str): ) # [N_envs, N_obs] while not done: # Act greedly through the environment - action = actor.get_greedy_actions(next_obs) + action = actor(next_obs, greedy=True) # Single environment step next_obs, reward, done, truncated, info = env.step(action.cpu().numpy().reshape(env.action_space.shape)) diff --git a/sheeprl/algos/sac_ae/agent.py b/sheeprl/algos/sac_ae/agent.py index 46ac1fc4..d51768af 100644 --- a/sheeprl/algos/sac_ae/agent.py +++ b/sheeprl/algos/sac_ae/agent.py @@ -258,7 +258,9 @@ def __init__( # Orthogonal init self.apply(weight_init) - def forward(self, obs: Tensor, detach_encoder_features: bool = False) -> Tuple[Tensor, Tensor]: + def forward( + self, obs: Tensor, detach_encoder_features: bool = False, greedy: bool = False + ) -> Union[Tensor, Tuple[Tensor, Tensor]]: """Given an observation, it returns a tanh-squashed sampled action (correctly rescaled to the environment action bounds) and its log-prob (as defined in Eq. 
26 of https://arxiv.org/abs/1812.05905) @@ -273,11 +275,14 @@ def forward(self, obs: Tensor, detach_encoder_features: bool = False) -> Tuple[T features = self.encoder(obs, detach_encoder_features=detach_encoder_features) x = self.model(features) mean = self.fc_mean(x) - log_std = self.fc_logstd(x) - # log_std = torch.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX) - log_std = torch.tanh(log_std) - log_std = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (log_std + 1) - return self.get_actions_and_log_probs(mean, log_std.exp()) + if greedy: + return torch.tanh(mean) * self.action_scale + self.action_bias + else: + log_std = self.fc_logstd(x) + # log_std = torch.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX) + log_std = torch.tanh(log_std) + log_std = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (log_std + 1) + return self.get_actions_and_log_probs(mean, log_std.exp()) def get_actions_and_log_probs(self, mean: Tensor, std: Tensor): """Given the mean and the std of a Normal distribution, it returns a tanh-squashed @@ -313,21 +318,6 @@ def get_actions_and_log_probs(self, mean: Tensor, std: Tensor): return action, log_prob - def get_greedy_actions(self, obs: Tensor) -> Tensor: - """Get the action given the input observation greedily - - Args: - obs (Tensor): input observation - - Returns: - action - """ - features = self.encoder(obs) - x = self.model(features) - mean = self.fc_mean(x) - mean = torch.tanh(mean) * self.action_scale + self.action_bias - return mean - class SACAEAgent(nn.Module): def __init__( diff --git a/sheeprl/algos/sac_ae/utils.py b/sheeprl/algos/sac_ae/utils.py index c25cd22c..2a9e542b 100644 --- a/sheeprl/algos/sac_ae/utils.py +++ b/sheeprl/algos/sac_ae/utils.py @@ -45,7 +45,7 @@ def test(actor: "SACAEContinuousActor", fabric: Fabric, cfg: Dict[str, Any], log while not done: # Act greedly through the environment - action = actor.get_greedy_actions(next_obs) + action = actor(next_obs, greedy=True) # Single environment step o, reward, done, truncated, _ = env.step(action.cpu().numpy().reshape(env.action_space.shape)) diff --git a/tests/test_algos/test_algos.py b/tests/test_algos/test_algos.py index a0c2d844..c078ed2c 100644 --- a/tests/test_algos/test_algos.py +++ b/tests/test_algos/test_algos.py @@ -26,12 +26,14 @@ def standard_args(): "hydra/hydra_logging=disabled", "dry_run=True", "checkpoint.save_last=False", - "metric.log_level=0", - "metric.disable_timer=True", "env.num_envs=1", - "fabric.devices=auto", f"env.sync_env={_IS_WINDOWS}", "env.capture_video=False", + "fabric.devices=auto", + "fabric.accelerator=cpu", + "fabric.precision=bf16-true", + "metric.log_level=0", + "metric.disable_timer=True", ] if os.environ.get("MLFLOW_TRACKING_URI", None) is not None: args.extend(["logger@metric.logger=mlflow", "model_manager.disabled=False", "metric.log_level=1"]) @@ -45,7 +47,6 @@ def start_time(): @pytest.fixture(autouse=True) def mock_env_and_destroy(devices): - os.environ["LT_ACCELERATOR"] = "cpu" os.environ["LT_DEVICES"] = str(devices) if _IS_WINDOWS and devices != "1": pytest.skip() From 33c9d84cbff8a26b146f087c644193891af3215e Mon Sep 17 00:00:00 2001 From: belerico Date: Wed, 27 Mar 2024 11:32:41 +0100 Subject: [PATCH 21/26] Fix PPORecurrent arg postion --- sheeprl/algos/ppo_recurrent/agent.py | 21 ++++++++++++++------- sheeprl/algos/ppo_recurrent/utils.py | 2 +- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/sheeprl/algos/ppo_recurrent/agent.py b/sheeprl/algos/ppo_recurrent/agent.py index 814d3a0f..e76902f4 100644 --- a/sheeprl/algos/ppo_recurrent/agent.py +++ 
b/sheeprl/algos/ppo_recurrent/agent.py @@ -203,19 +203,25 @@ def get_actions( 1, validate_args=self.distribution_cfg.validate_args, ) - if actions is None: - sampled_actions.append(dist.mode if greedy else dist.sample()) + if greedy: + sampled_actions.append(dist.mode) else: - sampled_actions.append(actions[0]) + if actions is None: + sampled_actions.append(dist.sample()) + else: + sampled_actions.append(actions[0]) entropies.append(dist.entropy()) logprobs.append(dist.log_prob(actions)) else: for i, logits in enumerate(pre_dist): dist = OneHotCategoricalValidateArgs(logits=logits, validate_args=self.distribution_cfg.validate_args) - if actions is None: - sampled_actions.append(dist.mode if greedy else dist.sample()) + if greedy: + sampled_actions.append(dist.mode) else: - sampled_actions.append(actions[i]) + if actions is None: + sampled_actions.append(dist.sample()) + else: + sampled_actions.append(actions[i]) entropies.append(dist.entropy()) logprobs.append(dist.log_prob(sampled_actions[-1])) return ( @@ -243,6 +249,7 @@ def forward( prev_states: Tuple[Tensor, Tensor], actions: Optional[List[Tensor]] = None, mask: Optional[Tensor] = None, + greedy: bool = False, ) -> Tuple[Tuple[Tensor, ...], Tensor, Tensor, Tensor, Tuple[Tensor, Tensor]]: """Compute actor logits and critic values. @@ -264,7 +271,7 @@ def forward( out, states = self.rnn(torch.cat((embedded_obs, prev_actions), dim=-1), prev_states, mask) values = self.get_values(out) pre_dist = self.get_pre_dist(out) - actions, logprobs, entropies = self.get_sampled_actions(pre_dist, actions) + actions, logprobs, entropies = self.get_actions(pre_dist, actions, greedy=greedy) return actions, logprobs, entropies, values, states diff --git a/sheeprl/algos/ppo_recurrent/utils.py b/sheeprl/algos/ppo_recurrent/utils.py index 735b6b0b..64f388f8 100644 --- a/sheeprl/algos/ppo_recurrent/utils.py +++ b/sheeprl/algos/ppo_recurrent/utils.py @@ -47,7 +47,7 @@ def test(agent: "RecurrentPPOAgent", fabric: Fabric, cfg: Dict[str, Any], log_di actions = torch.zeros(1, 1, sum(agent.actions_dim), device=fabric.device) while not done: # Act greedly through the environment - actions, _, _, _, state = agent(next_obs, state, actions, greedy=True) + actions, _, _, _, state = agent(next_obs, actions, state, greedy=True) if agent.is_continuous: real_actions = torch.cat(actions, -1) actions = torch.cat(actions, dim=-1).view(1, 1, -1) From 9f4f26e7e23636abc846673994f9490d9335b35c Mon Sep 17 00:00:00 2001 From: belerico Date: Wed, 27 Mar 2024 12:49:32 +0100 Subject: [PATCH 22/26] Let P2E algos handle fp16 --- sheeprl/__init__.py | 2 +- sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py | 141 ++++++++-------- sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py | 111 +++++++------ sheeprl/algos/p2e_dv2/agent.py | 11 +- sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py | 145 +++++++++-------- sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py | 114 +++++++------ sheeprl/algos/p2e_dv3/agent.py | 13 +- sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py | 161 ++++++++++--------- sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py | 132 +++++++-------- tests/test_algos/test_algos.py | 1 + 10 files changed, 419 insertions(+), 412 deletions(-) diff --git a/sheeprl/__init__.py b/sheeprl/__init__.py index 46d30d43..5ec12dc8 100644 --- a/sheeprl/__init__.py +++ b/sheeprl/__init__.py @@ -52,7 +52,7 @@ np.int = np.int64 np.bool = bool -__version__ = "0.5.4" +__version__ = "0.5.5.dev0" # Replace `moviepy.decorators.use_clip_fps_by_default` method to work with python 3.8, 3.9, and 3.10 diff --git 
a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py index b50c7cd5..3e0d7bb3 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py @@ -626,26 +626,26 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): for update in range(start_step, num_updates + 1): policy_step += cfg.env.num_envs * world_size - # Measure environment interaction time: this considers both the model forward - # to get the action given the observation and the time taken into the environment - with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): - # Sample an action given the observation received by the environment - if ( - update <= learning_starts - and cfg.checkpoint.resume_from is None - and "minedojo" not in cfg.env.wrapper._target_.lower() - ): - real_actions = actions = np.array(envs.action_space.sample()) - if not is_continuous: - actions = np.concatenate( - [ - F.one_hot(torch.as_tensor(act), act_dim).numpy() - for act, act_dim in zip(actions.reshape(len(actions_dim), -1), actions_dim) - ], - axis=-1, - ) - else: - with torch.no_grad(): + with torch.inference_mode(): + # Measure environment interaction time: this considers both the model forward + # to get the action given the observation and the time taken into the environment + with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): + # Sample an action given the observation received by the environment + if ( + update <= learning_starts + and cfg.checkpoint.resume_from is None + and "minedojo" not in cfg.env.wrapper._target_.lower() + ): + real_actions = actions = np.array(envs.action_space.sample()) + if not is_continuous: + actions = np.concatenate( + [ + F.one_hot(torch.as_tensor(act), act_dim).numpy() + for act, act_dim in zip(actions.reshape(len(actions_dim), -1), actions_dim) + ], + axis=-1, + ) + else: normalized_obs = {} for k in obs_keys: torch_obs = torch.as_tensor(obs[k][np.newaxis], dtype=torch.float32, device=device) @@ -663,57 +663,57 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): real_actions = ( torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() ) - next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated).astype(np.uint8) - - if cfg.metric.log_level > 0 and "final_info" in infos: - for i, agent_ep_info in enumerate(infos["final_info"]): - if agent_ep_info is not None: - ep_rew = agent_ep_info["episode"]["r"] - ep_len = agent_ep_info["episode"]["l"] - if aggregator and not aggregator.disabled: - aggregator.update("Rewards/rew_avg", ep_rew) - aggregator.update("Game/ep_len_avg", ep_len) - fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") - - # Save the real next observation - real_next_obs = copy.deepcopy(next_obs) - if "final_observation" in infos: - for idx, final_obs in enumerate(infos["final_observation"]): - if final_obs is not None: - for k, v in final_obs.items(): - real_next_obs[k][idx] = v - - for k in obs_keys: - if k in cfg.algo.cnn_keys.encoder: - next_obs[k] = next_obs[k].reshape(cfg.env.num_envs, -1, *next_obs[k].shape[-2:]) - real_next_obs[k] = real_next_obs[k].reshape(cfg.env.num_envs, -1, *real_next_obs[k].shape[-2:]) - step_data[k] = real_next_obs[k][np.newaxis] - - # next_obs becomes the new obs - obs = next_obs - - step_data["dones"] = dones[np.newaxis] - step_data["actions"] = actions[np.newaxis] - step_data["rewards"] = 
clip_rewards_fn(rewards)[np.newaxis] - rb.add(step_data, validate_args=cfg.buffer.validate_args) - - # Reset and save the observation coming from the automatic reset - dones_idxes = dones.nonzero()[0].tolist() - reset_envs = len(dones_idxes) - if reset_envs > 0: - reset_data = {} + next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) + dones = np.logical_or(dones, truncated).astype(np.uint8) + + if cfg.metric.log_level > 0 and "final_info" in infos: + for i, agent_ep_info in enumerate(infos["final_info"]): + if agent_ep_info is not None: + ep_rew = agent_ep_info["episode"]["r"] + ep_len = agent_ep_info["episode"]["l"] + if aggregator and not aggregator.disabled: + aggregator.update("Rewards/rew_avg", ep_rew) + aggregator.update("Game/ep_len_avg", ep_len) + fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") + + # Save the real next observation + real_next_obs = copy.deepcopy(next_obs) + if "final_observation" in infos: + for idx, final_obs in enumerate(infos["final_observation"]): + if final_obs is not None: + for k, v in final_obs.items(): + real_next_obs[k][idx] = v + for k in obs_keys: - reset_data[k] = (next_obs[k][dones_idxes])[np.newaxis] - reset_data["dones"] = np.zeros((1, reset_envs, 1)) - reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim))) - reset_data["rewards"] = np.zeros((1, reset_envs, 1)) - rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args) - # Reset dones so that `is_first` is updated - for d in dones_idxes: - step_data["dones"][0, d] = np.zeros_like(step_data["dones"][0, d]) - # Reset internal agent states - player.init_states(reset_envs=dones_idxes) + if k in cfg.algo.cnn_keys.encoder: + next_obs[k] = next_obs[k].reshape(cfg.env.num_envs, -1, *next_obs[k].shape[-2:]) + real_next_obs[k] = real_next_obs[k].reshape(cfg.env.num_envs, -1, *real_next_obs[k].shape[-2:]) + step_data[k] = real_next_obs[k][np.newaxis] + + # next_obs becomes the new obs + obs = next_obs + + step_data["dones"] = dones[np.newaxis] + step_data["actions"] = actions[np.newaxis] + step_data["rewards"] = clip_rewards_fn(rewards)[np.newaxis] + rb.add(step_data, validate_args=cfg.buffer.validate_args) + + # Reset and save the observation coming from the automatic reset + dones_idxes = dones.nonzero()[0].tolist() + reset_envs = len(dones_idxes) + if reset_envs > 0: + reset_data = {} + for k in obs_keys: + reset_data[k] = (next_obs[k][dones_idxes])[np.newaxis] + reset_data["dones"] = np.zeros((1, reset_envs, 1)) + reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim))) + reset_data["rewards"] = np.zeros((1, reset_envs, 1)) + rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args) + # Reset dones so that `is_first` is updated + for d in dones_idxes: + step_data["dones"][0, d] = np.zeros_like(step_data["dones"][0, d]) + # Reset internal agent states + player.init_states(reset_envs=dones_idxes) updates_before_training -= 1 @@ -835,7 +835,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): envs.close() # task test zero-shot if fabric.is_global_zero and cfg.algo.run_test: - player.actor = actor_task.module player.actor_type = "task" test(player, fabric, cfg, log_dir, "zero-shot") diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py index 06b6ade0..754fda8c 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py @@ -263,15 +263,14 @@ def main(fabric: Fabric, cfg: 
Dict[str, Any], exploration_cfg: Dict[str, Any]): step_data["rewards"] = np.zeros((1, cfg.env.num_envs, 1)) rb.add(step_data, validate_args=cfg.buffer.validate_args) player.init_states() - player.init_states() for update in range(start_step, num_updates + 1): policy_step += cfg.env.num_envs * world_size - # Measure environment interaction time: this considers both the model forward - # to get the action given the observation and the time taken into the environment - with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): - with torch.no_grad(): + with torch.inference_mode(): + # Measure environment interaction time: this considers both the model forward + # to get the action given the observation and the time taken into the environment + with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): normalized_obs = {} for k in obs_keys: torch_obs = torch.as_tensor(obs[k][np.newaxis], dtype=torch.float32, device=device) @@ -289,64 +288,63 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): real_actions = ( torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() ) - next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated).astype(np.uint8) - - if cfg.metric.log_level > 0 and "final_info" in infos: - for i, agent_ep_info in enumerate(infos["final_info"]): - if agent_ep_info is not None: - ep_rew = agent_ep_info["episode"]["r"] - ep_len = agent_ep_info["episode"]["l"] - if aggregator and not aggregator.disabled: - aggregator.update("Rewards/rew_avg", ep_rew) - aggregator.update("Game/ep_len_avg", ep_len) - fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") - - # Save the real next observation - real_next_obs = copy.deepcopy(next_obs) - if "final_observation" in infos: - for idx, final_obs in enumerate(infos["final_observation"]): - if final_obs is not None: - for k, v in final_obs.items(): - real_next_obs[k][idx] = v - - for k in obs_keys: - if k in cfg.algo.cnn_keys.encoder: - next_obs[k] = next_obs[k].reshape(cfg.env.num_envs, -1, *next_obs[k].shape[-2:]) - real_next_obs[k] = real_next_obs[k].reshape(cfg.env.num_envs, -1, *real_next_obs[k].shape[-2:]) - step_data[k] = real_next_obs[k][np.newaxis] - - # next_obs becomes the new obs - obs = next_obs - - step_data["dones"] = dones[np.newaxis] - step_data["actions"] = actions[np.newaxis] - step_data["rewards"] = clip_rewards_fn(rewards)[np.newaxis] - rb.add(step_data, validate_args=cfg.buffer.validate_args) - - # Reset and save the observation coming from the automatic reset - dones_idxes = dones.nonzero()[0].tolist() - reset_envs = len(dones_idxes) - if reset_envs > 0: - reset_data = {} + next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) + dones = np.logical_or(dones, truncated).astype(np.uint8) + + if cfg.metric.log_level > 0 and "final_info" in infos: + for i, agent_ep_info in enumerate(infos["final_info"]): + if agent_ep_info is not None: + ep_rew = agent_ep_info["episode"]["r"] + ep_len = agent_ep_info["episode"]["l"] + if aggregator and not aggregator.disabled: + aggregator.update("Rewards/rew_avg", ep_rew) + aggregator.update("Game/ep_len_avg", ep_len) + fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") + + # Save the real next observation + real_next_obs = copy.deepcopy(next_obs) + if "final_observation" in infos: + for idx, final_obs in 
enumerate(infos["final_observation"]): + if final_obs is not None: + for k, v in final_obs.items(): + real_next_obs[k][idx] = v + for k in obs_keys: - reset_data[k] = (next_obs[k][dones_idxes])[np.newaxis] - reset_data["dones"] = np.zeros((1, reset_envs, 1)) - reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim))) - reset_data["rewards"] = np.zeros((1, reset_envs, 1)) - rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args) - # Reset dones so that `is_first` is updated - for d in dones_idxes: - step_data["dones"][0, d] = np.zeros_like(step_data["dones"][0, d]) - # Reset internal agent states - player.init_states(reset_envs=dones_idxes) + if k in cfg.algo.cnn_keys.encoder: + next_obs[k] = next_obs[k].reshape(cfg.env.num_envs, -1, *next_obs[k].shape[-2:]) + real_next_obs[k] = real_next_obs[k].reshape(cfg.env.num_envs, -1, *real_next_obs[k].shape[-2:]) + step_data[k] = real_next_obs[k][np.newaxis] + + # next_obs becomes the new obs + obs = next_obs + + step_data["dones"] = dones[np.newaxis] + step_data["actions"] = actions[np.newaxis] + step_data["rewards"] = clip_rewards_fn(rewards)[np.newaxis] + rb.add(step_data, validate_args=cfg.buffer.validate_args) + + # Reset and save the observation coming from the automatic reset + dones_idxes = dones.nonzero()[0].tolist() + reset_envs = len(dones_idxes) + if reset_envs > 0: + reset_data = {} + for k in obs_keys: + reset_data[k] = (next_obs[k][dones_idxes])[np.newaxis] + reset_data["dones"] = np.zeros((1, reset_envs, 1)) + reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim))) + reset_data["rewards"] = np.zeros((1, reset_envs, 1)) + rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args) + # Reset dones so that `is_first` is updated + for d in dones_idxes: + step_data["dones"][0, d] = np.zeros_like(step_data["dones"][0, d]) + # Reset internal agent states + player.init_states(reset_envs=dones_idxes) updates_before_training -= 1 # Train the agent if update >= learning_starts and updates_before_training <= 0: if player.actor_type == "exploration": - player.actor = actor_task.module player.actor_type = "task" with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): for i in range(cfg.algo.per_rank_gradient_steps): @@ -452,7 +450,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): envs.close() # task test few-shot if fabric.is_global_zero and cfg.algo.run_test: - player.actor = actor_task.module player.actor_type = "task" test(player, fabric, cfg, log_dir, "few-shot") diff --git a/sheeprl/algos/p2e_dv2/agent.py b/sheeprl/algos/p2e_dv2/agent.py index ffb146cd..f89243ba 100644 --- a/sheeprl/algos/p2e_dv2/agent.py +++ b/sheeprl/algos/p2e_dv2/agent.py @@ -14,6 +14,7 @@ from sheeprl.algos.dreamer_v2.agent import WorldModel from sheeprl.algos.dreamer_v2.agent import build_agent as dv2_build_agent from sheeprl.models.models import MLP +from sheeprl.utils.fabric import get_single_device_fabric from sheeprl.utils.utils import init_weights # In order to use the hydra.utils.get_class method, in this way the user can @@ -116,9 +117,11 @@ def build_agent( activation=eval(critic_cfg.dense_act), flatten_dim=None, norm_layer=[nn.LayerNorm for _ in range(critic_cfg.mlp_layers)] if critic_cfg.layer_norm else None, - norm_args=[{"normalized_shape": critic_cfg.dense_units} for _ in range(critic_cfg.mlp_layers)] - if critic_cfg.layer_norm - else None, + norm_args=( + [{"normalized_shape": critic_cfg.dense_units} for _ in range(critic_cfg.mlp_layers)] 
+ if critic_cfg.layer_norm + else None + ), ) actor_task.apply(init_weights) critic_task.apply(init_weights) @@ -135,6 +138,8 @@ def build_agent( target_critic_task = copy.deepcopy(critic_task.module) if target_critic_task_state: target_critic_task.load_state_dict(target_critic_task_state) + single_device_fabric = get_single_device_fabric(fabric) + target_critic_task = single_device_fabric.setup_module(target_critic_task) # initialize the ensembles with different seeds to be sure they have different weights ens_list = [] diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py index 94f3dc7e..d17c1ec1 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py @@ -773,26 +773,26 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): for update in range(start_step, num_updates + 1): policy_step += cfg.env.num_envs * world_size - # Measure environment interaction time: this considers both the model forward - # to get the action given the observation and the time taken into the environment - with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): - # Sample an action given the observation received by the environment - if ( - update <= learning_starts - and cfg.checkpoint.resume_from is None - and "minedojo" not in cfg.env.wrapper._target_.lower() - ): - real_actions = actions = np.array(envs.action_space.sample()) - if not is_continuous: - actions = np.concatenate( - [ - F.one_hot(torch.as_tensor(act), act_dim).numpy() - for act, act_dim in zip(actions.reshape(len(actions_dim), -1), actions_dim) - ], - axis=-1, - ) - else: - with torch.no_grad(): + with torch.inference_mode(): + # Measure environment interaction time: this considers both the model forward + # to get the action given the observation and the time taken into the environment + with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): + # Sample an action given the observation received by the environment + if ( + update <= learning_starts + and cfg.checkpoint.resume_from is None + and "minedojo" not in cfg.env.wrapper._target_.lower() + ): + real_actions = actions = np.array(envs.action_space.sample()) + if not is_continuous: + actions = np.concatenate( + [ + F.one_hot(torch.as_tensor(act), act_dim).numpy() + for act, act_dim in zip(actions.reshape(len(actions_dim), -1), actions_dim) + ], + axis=-1, + ) + else: normalized_obs = {} for k in obs_keys: torch_obs = torch.as_tensor(obs[k][np.newaxis], dtype=torch.float32, device=device) @@ -811,58 +811,58 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() ) - step_data["is_first"] = copy.deepcopy(step_data["dones"]) - next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated).astype(np.uint8) - if cfg.dry_run and buffer_type == "episode": - dones = np.ones_like(dones) - - if cfg.metric.log_level > 0 and "final_info" in infos: - for i, agent_ep_info in enumerate(infos["final_info"]): - if agent_ep_info is not None: - ep_rew = agent_ep_info["episode"]["r"] - ep_len = agent_ep_info["episode"]["l"] - if aggregator and not aggregator.disabled: - aggregator.update("Rewards/rew_avg", ep_rew) - aggregator.update("Game/ep_len_avg", ep_len) - fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") - - # Save the real next observation - real_next_obs = copy.deepcopy(next_obs) - 
if "final_observation" in infos: - for idx, final_obs in enumerate(infos["final_observation"]): - if final_obs is not None: - for k, v in final_obs.items(): - real_next_obs[k][idx] = v - - for k in obs_keys: # [N_envs, N_obs] - step_data[k] = real_next_obs[k][np.newaxis] - - # Next_obs becomes the new obs - obs = next_obs - - step_data["dones"] = dones.reshape((1, cfg.env.num_envs, -1)) - step_data["actions"] = actions.reshape((1, cfg.env.num_envs, -1)) - step_data["rewards"] = clip_rewards_fn(rewards).reshape((1, cfg.env.num_envs, -1)) - rb.add(step_data, validate_args=cfg.buffer.validate_args) - - # Reset and save the observation coming from the automatic reset - dones_idxes = dones.nonzero()[0].tolist() - reset_envs = len(dones_idxes) - if reset_envs > 0: - reset_data = {} - for k in obs_keys: - reset_data[k] = (next_obs[k][dones_idxes])[np.newaxis] - reset_data["dones"] = np.zeros((1, reset_envs, 1)) - reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim))) - reset_data["rewards"] = np.zeros((1, reset_envs, 1)) - reset_data["is_first"] = np.ones_like(reset_data["dones"]) - rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args) - # Reset dones so that `is_first` is updated - for d in dones_idxes: - step_data["dones"][0, d] = np.zeros_like(step_data["dones"][0, d]) - # Reset internal agent states - player.init_states(dones_idxes) + step_data["is_first"] = copy.deepcopy(step_data["dones"]) + next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) + dones = np.logical_or(dones, truncated).astype(np.uint8) + if cfg.dry_run and buffer_type == "episode": + dones = np.ones_like(dones) + + if cfg.metric.log_level > 0 and "final_info" in infos: + for i, agent_ep_info in enumerate(infos["final_info"]): + if agent_ep_info is not None: + ep_rew = agent_ep_info["episode"]["r"] + ep_len = agent_ep_info["episode"]["l"] + if aggregator and not aggregator.disabled: + aggregator.update("Rewards/rew_avg", ep_rew) + aggregator.update("Game/ep_len_avg", ep_len) + fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") + + # Save the real next observation + real_next_obs = copy.deepcopy(next_obs) + if "final_observation" in infos: + for idx, final_obs in enumerate(infos["final_observation"]): + if final_obs is not None: + for k, v in final_obs.items(): + real_next_obs[k][idx] = v + + for k in obs_keys: # [N_envs, N_obs] + step_data[k] = real_next_obs[k][np.newaxis] + + # Next_obs becomes the new obs + obs = next_obs + + step_data["dones"] = dones.reshape((1, cfg.env.num_envs, -1)) + step_data["actions"] = actions.reshape((1, cfg.env.num_envs, -1)) + step_data["rewards"] = clip_rewards_fn(rewards).reshape((1, cfg.env.num_envs, -1)) + rb.add(step_data, validate_args=cfg.buffer.validate_args) + + # Reset and save the observation coming from the automatic reset + dones_idxes = dones.nonzero()[0].tolist() + reset_envs = len(dones_idxes) + if reset_envs > 0: + reset_data = {} + for k in obs_keys: + reset_data[k] = (next_obs[k][dones_idxes])[np.newaxis] + reset_data["dones"] = np.zeros((1, reset_envs, 1)) + reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim))) + reset_data["rewards"] = np.zeros((1, reset_envs, 1)) + reset_data["is_first"] = np.ones_like(reset_data["dones"]) + rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args) + # Reset dones so that `is_first` is updated + for d in dones_idxes: + step_data["dones"][0, d] = np.zeros_like(step_data["dones"][0, d]) + # Reset 
internal agent states + player.init_states(dones_idxes) updates_before_training -= 1 @@ -1000,7 +1000,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): envs.close() # task test zero-shot if fabric.is_global_zero and cfg.algo.run_test: - player.actor = actor_task.module player.actor_type = "task" test(player, fabric, cfg, log_dir, "zero-shot") diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py index 143b557e..8bfa7d86 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py @@ -288,10 +288,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): for update in range(start_step, num_updates + 1): policy_step += cfg.env.num_envs * world_size - # Measure environment interaction time: this considers both the model forward - # to get the action given the observation and the time taken into the environment - with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): - with torch.no_grad(): + with torch.inference_mode(): + # Measure environment interaction time: this considers both the model forward + # to get the action given the observation and the time taken into the environment + with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): normalized_obs = {} for k in obs_keys: torch_obs = torch.as_tensor(obs[k][np.newaxis], dtype=torch.float32, device=device) @@ -310,65 +310,64 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() ) - step_data["is_first"] = copy.deepcopy(step_data["dones"]) - next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated).astype(np.uint8) - if cfg.dry_run and buffer_type == "episode": - dones = np.ones_like(dones) - - if cfg.metric.log_level > 0 and "final_info" in infos: - for i, agent_ep_info in enumerate(infos["final_info"]): - if agent_ep_info is not None: - ep_rew = agent_ep_info["episode"]["r"] - ep_len = agent_ep_info["episode"]["l"] - if aggregator and not aggregator.disabled: - aggregator.update("Rewards/rew_avg", ep_rew) - aggregator.update("Game/ep_len_avg", ep_len) - fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") - - # Save the real next observation - real_next_obs = copy.deepcopy(next_obs) - if "final_observation" in infos: - for idx, final_obs in enumerate(infos["final_observation"]): - if final_obs is not None: - for k, v in final_obs.items(): - real_next_obs[k][idx] = v - - for k in obs_keys: # [N_envs, N_obs] - step_data[k] = real_next_obs[k][np.newaxis] - - # Next_obs becomes the new obs - obs = next_obs - - step_data["dones"] = dones.reshape((1, cfg.env.num_envs, -1)) - step_data["actions"] = actions.reshape((1, cfg.env.num_envs, -1)) - step_data["rewards"] = clip_rewards_fn(rewards).reshape((1, cfg.env.num_envs, -1)) - rb.add(step_data, validate_args=cfg.buffer.validate_args) - - # Reset and save the observation coming from the automatic reset - dones_idxes = dones.nonzero()[0].tolist() - reset_envs = len(dones_idxes) - if reset_envs > 0: - reset_data = {} - for k in obs_keys: - reset_data[k] = (next_obs[k][dones_idxes])[np.newaxis] - reset_data["dones"] = np.zeros((1, reset_envs, 1)) - reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim))) - reset_data["rewards"] = np.zeros((1, reset_envs, 1)) - reset_data["is_first"] = 
np.ones_like(reset_data["dones"]) - rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args) - # Reset dones so that `is_first` is updated - for d in dones_idxes: - step_data["dones"][0, d] = np.zeros_like(step_data["dones"][0, d]) - # Reset internal agent states - player.init_states(dones_idxes) + step_data["is_first"] = copy.deepcopy(step_data["dones"]) + next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) + dones = np.logical_or(dones, truncated).astype(np.uint8) + if cfg.dry_run and buffer_type == "episode": + dones = np.ones_like(dones) + + if cfg.metric.log_level > 0 and "final_info" in infos: + for i, agent_ep_info in enumerate(infos["final_info"]): + if agent_ep_info is not None: + ep_rew = agent_ep_info["episode"]["r"] + ep_len = agent_ep_info["episode"]["l"] + if aggregator and not aggregator.disabled: + aggregator.update("Rewards/rew_avg", ep_rew) + aggregator.update("Game/ep_len_avg", ep_len) + fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") + + # Save the real next observation + real_next_obs = copy.deepcopy(next_obs) + if "final_observation" in infos: + for idx, final_obs in enumerate(infos["final_observation"]): + if final_obs is not None: + for k, v in final_obs.items(): + real_next_obs[k][idx] = v + + for k in obs_keys: # [N_envs, N_obs] + step_data[k] = real_next_obs[k][np.newaxis] + + # Next_obs becomes the new obs + obs = next_obs + + step_data["dones"] = dones.reshape((1, cfg.env.num_envs, -1)) + step_data["actions"] = actions.reshape((1, cfg.env.num_envs, -1)) + step_data["rewards"] = clip_rewards_fn(rewards).reshape((1, cfg.env.num_envs, -1)) + rb.add(step_data, validate_args=cfg.buffer.validate_args) + + # Reset and save the observation coming from the automatic reset + dones_idxes = dones.nonzero()[0].tolist() + reset_envs = len(dones_idxes) + if reset_envs > 0: + reset_data = {} + for k in obs_keys: + reset_data[k] = (next_obs[k][dones_idxes])[np.newaxis] + reset_data["dones"] = np.zeros((1, reset_envs, 1)) + reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim))) + reset_data["rewards"] = np.zeros((1, reset_envs, 1)) + reset_data["is_first"] = np.ones_like(reset_data["dones"]) + rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args) + # Reset dones so that `is_first` is updated + for d in dones_idxes: + step_data["dones"][0, d] = np.zeros_like(step_data["dones"][0, d]) + # Reset internal agent states + player.init_states(dones_idxes) updates_before_training -= 1 # Train the agent if update >= learning_starts and updates_before_training <= 0: if player.actor_type == "exploration": - player.actor = actor_task.module player.actor_type = "task" n_samples = ( cfg.algo.per_rank_pretrain_steps if update == learning_starts else cfg.algo.per_rank_gradient_steps @@ -484,7 +483,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): envs.close() # task test few-shot if fabric.is_global_zero and cfg.algo.run_test: - player.actor = actor_task.module player.actor_type = "task" test(player, fabric, cfg, log_dir, "few-shot") diff --git a/sheeprl/algos/p2e_dv3/agent.py b/sheeprl/algos/p2e_dv3/agent.py index 66547eaa..d19b329e 100644 --- a/sheeprl/algos/p2e_dv3/agent.py +++ b/sheeprl/algos/p2e_dv3/agent.py @@ -14,6 +14,7 @@ from sheeprl.algos.dreamer_v3.agent import build_agent as dv3_build_agent from sheeprl.algos.dreamer_v3.utils import init_weights, uniform_init_weights from sheeprl.models.models import MLP +from 
sheeprl.utils.fabric import get_single_device_fabric # In order to use the hydra.utils.get_class method, in this way the user can # specify in the configs the name of the class without having to know where @@ -109,6 +110,7 @@ def build_agent( unimix=cfg.algo.unimix, ) + single_device_fabric = get_single_device_fabric(fabric) critics_exploration = {} intrinsic_critics = 0 for k, v in cfg.algo.critics_exploration.items(): @@ -126,9 +128,11 @@ def build_agent( flatten_dim=None, layer_args={"bias": not critic_cfg.layer_norm}, norm_layer=[nn.LayerNorm for _ in range(critic_cfg.mlp_layers)] if critic_cfg.layer_norm else None, - norm_args=[{"normalized_shape": critic_cfg.dense_units} for _ in range(critic_cfg.mlp_layers)] - if critic_cfg.layer_norm - else None, + norm_args=( + [{"normalized_shape": critic_cfg.dense_units} for _ in range(critic_cfg.mlp_layers)] + if critic_cfg.layer_norm + else None + ), ), } critics_exploration[k]["module"].apply(init_weights) @@ -140,6 +144,9 @@ def build_agent( critics_exploration[k]["target_module"] = copy.deepcopy(critics_exploration[k]["module"].module) if critics_exploration_state: critics_exploration[k]["target_module"].load_state_dict(critics_exploration_state[k]["target_module"]) + critics_exploration[k]["target_module"] = single_device_fabric.setup_module( + critics_exploration[k]["target_module"] + ) if intrinsic_critics == 0: raise RuntimeError("You must specify at least one intrinsic critic (`reward_type='intrinsic'`)") diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py index 918f42aa..93723c00 100644 --- a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py +++ b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py @@ -835,26 +835,26 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): for update in range(start_step, num_updates + 1): policy_step += cfg.env.num_envs * world_size - # Measure environment interaction time: this considers both the model forward - # to get the action given the observation and the time taken into the environment - with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): - # Sample an action given the observation received by the environment - if ( - update <= learning_starts - and cfg.checkpoint.resume_from is None - and "minedojo" not in cfg.algo.actor.cls.lower() - ): - real_actions = actions = np.array(envs.action_space.sample()) - if not is_continuous: - actions = np.concatenate( - [ - F.one_hot(torch.as_tensor(act), act_dim).numpy() - for act, act_dim in zip(actions.reshape(len(actions_dim), -1), actions_dim) - ], - axis=-1, - ) - else: - with torch.no_grad(): + with torch.inference_mode(): + # Measure environment interaction time: this considers both the model forward + # to get the action given the observation and the time taken into the environment + with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): + # Sample an action given the observation received by the environment + if ( + update <= learning_starts + and cfg.checkpoint.resume_from is None + and "minedojo" not in cfg.algo.actor.cls.lower() + ): + real_actions = actions = np.array(envs.action_space.sample()) + if not is_continuous: + actions = np.concatenate( + [ + F.one_hot(torch.as_tensor(act), act_dim).numpy() + for act, act_dim in zip(actions.reshape(len(actions_dim), -1), actions_dim) + ], + axis=-1, + ) + else: preprocessed_obs = {} for k, v in obs.items(): preprocessed_obs[k] = torch.as_tensor(v[np.newaxis], dtype=torch.float32, device=device) @@ -872,68 +872,70 @@ def 
main(fabric: Fabric, cfg: Dict[str, Any]): torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() ) - step_data["actions"] = actions.reshape((1, cfg.env.num_envs, -1)) - rb.add(step_data, validate_args=cfg.buffer.validate_args) + step_data["actions"] = actions.reshape((1, cfg.env.num_envs, -1)) + rb.add(step_data, validate_args=cfg.buffer.validate_args) - next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated).astype(np.uint8) + next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) + dones = np.logical_or(dones, truncated).astype(np.uint8) + + step_data["is_first"] = np.zeros_like(step_data["dones"]) + if "restart_on_exception" in infos: + for i, agent_roe in enumerate(infos["restart_on_exception"]): + if agent_roe and not dones[i]: + last_inserted_idx = (rb.buffer[i]._pos - 1) % rb.buffer[i].buffer_size + rb.buffer[i]["dones"][last_inserted_idx] = np.ones_like( + rb.buffer[i]["dones"][last_inserted_idx] + ) + rb.buffer[i]["is_first"][last_inserted_idx] = np.zeros_like( + rb.buffer[i]["is_first"][last_inserted_idx] + ) + step_data["is_first"][i] = np.ones_like(step_data["is_first"][i]) + + if cfg.metric.log_level > 0 and "final_info" in infos: + for i, agent_ep_info in enumerate(infos["final_info"]): + if agent_ep_info is not None: + ep_rew = agent_ep_info["episode"]["r"] + ep_len = agent_ep_info["episode"]["l"] + if aggregator and not aggregator.disabled: + aggregator.update("Rewards/rew_avg", ep_rew) + aggregator.update("Game/ep_len_avg", ep_len) + fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") + + # Save the real next observation + real_next_obs = copy.deepcopy(next_obs) + if "final_observation" in infos: + for idx, final_obs in enumerate(infos["final_observation"]): + if final_obs is not None: + for k, v in final_obs.items(): + real_next_obs[k][idx] = v - step_data["is_first"] = np.zeros_like(step_data["dones"]) - if "restart_on_exception" in infos: - for i, agent_roe in enumerate(infos["restart_on_exception"]): - if agent_roe and not dones[i]: - last_inserted_idx = (rb.buffer[i]._pos - 1) % rb.buffer[i].buffer_size - rb.buffer[i]["dones"][last_inserted_idx] = np.ones_like(rb.buffer[i]["dones"][last_inserted_idx]) - rb.buffer[i]["is_first"][last_inserted_idx] = np.zeros_like( - rb.buffer[i]["is_first"][last_inserted_idx] - ) - step_data["is_first"][i] = np.ones_like(step_data["is_first"][i]) - - if cfg.metric.log_level > 0 and "final_info" in infos: - for i, agent_ep_info in enumerate(infos["final_info"]): - if agent_ep_info is not None: - ep_rew = agent_ep_info["episode"]["r"] - ep_len = agent_ep_info["episode"]["l"] - if aggregator and not aggregator.disabled: - aggregator.update("Rewards/rew_avg", ep_rew) - aggregator.update("Game/ep_len_avg", ep_len) - fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") - - # Save the real next observation - real_next_obs = copy.deepcopy(next_obs) - if "final_observation" in infos: - for idx, final_obs in enumerate(infos["final_observation"]): - if final_obs is not None: - for k, v in final_obs.items(): - real_next_obs[k][idx] = v - - for k in obs_keys: - step_data[k] = next_obs[k][np.newaxis] - - # next_obs becomes the new obs - obs = next_obs - - rewards = rewards.reshape((1, cfg.env.num_envs, -1)) - step_data["dones"] = dones.reshape((1, cfg.env.num_envs, -1)) - step_data["rewards"] = clip_rewards_fn(rewards) - - 
dones_idxes = dones.nonzero()[0].tolist() - reset_envs = len(dones_idxes) - if reset_envs > 0: - reset_data = {} for k in obs_keys: - reset_data[k] = (real_next_obs[k][dones_idxes])[np.newaxis] - reset_data["dones"] = np.ones((1, reset_envs, 1)) - reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim))) - reset_data["rewards"] = step_data["rewards"][:, dones_idxes] - reset_data["is_first"] = np.zeros_like(reset_data["dones"]) - rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args) - - # Reset already inserted step data - step_data["rewards"][:, dones_idxes] = np.zeros_like(reset_data["rewards"]) - step_data["dones"][:, dones_idxes] = np.zeros_like(step_data["dones"][:, dones_idxes]) - step_data["is_first"][:, dones_idxes] = np.ones_like(step_data["is_first"][:, dones_idxes]) - player.init_states(dones_idxes) + step_data[k] = next_obs[k][np.newaxis] + + # next_obs becomes the new obs + obs = next_obs + + rewards = rewards.reshape((1, cfg.env.num_envs, -1)) + step_data["dones"] = dones.reshape((1, cfg.env.num_envs, -1)) + step_data["rewards"] = clip_rewards_fn(rewards) + + dones_idxes = dones.nonzero()[0].tolist() + reset_envs = len(dones_idxes) + if reset_envs > 0: + reset_data = {} + for k in obs_keys: + reset_data[k] = (real_next_obs[k][dones_idxes])[np.newaxis] + reset_data["dones"] = np.ones((1, reset_envs, 1)) + reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim))) + reset_data["rewards"] = step_data["rewards"][:, dones_idxes] + reset_data["is_first"] = np.zeros_like(reset_data["dones"]) + rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args) + + # Reset already inserted step data + step_data["rewards"][:, dones_idxes] = np.zeros_like(reset_data["rewards"]) + step_data["dones"][:, dones_idxes] = np.zeros_like(step_data["dones"][:, dones_idxes]) + step_data["is_first"][:, dones_idxes] = np.ones_like(step_data["is_first"][:, dones_idxes]) + player.init_states(dones_idxes) updates_before_training -= 1 @@ -1080,7 +1082,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): envs.close() # task test zero-shot if fabric.is_global_zero and cfg.algo.run_test: - player.actor = actor_task.module player.actor_type = "task" test(player, fabric, cfg, log_dir, "zero-shot") diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py index bd20456c..36e9691c 100644 --- a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py +++ b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py @@ -279,10 +279,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): for update in range(start_step, num_updates + 1): policy_step += cfg.env.num_envs * world_size - # Measure environment interaction time: this considers both the model forward - # to get the action given the observation and the time taken into the environment - with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): - with torch.no_grad(): + with torch.inference_mode(): + # Measure environment interaction time: this considers both the model forward + # to get the action given the observation and the time taken into the environment + with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): preprocessed_obs = {} for k, v in obs.items(): preprocessed_obs[k] = torch.as_tensor(v[np.newaxis], dtype=torch.float32, device=device) @@ -300,75 +300,76 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], 
dim=-1).cpu().numpy() ) - step_data["actions"] = actions.reshape((1, cfg.env.num_envs, -1)) - rb.add(step_data, validate_args=cfg.buffer.validate_args) + step_data["actions"] = actions.reshape((1, cfg.env.num_envs, -1)) + rb.add(step_data, validate_args=cfg.buffer.validate_args) + + next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) + dones = np.logical_or(dones, truncated).astype(np.uint8) + + step_data["is_first"] = np.zeros_like(step_data["dones"]) + if "restart_on_exception" in infos: + for i, agent_roe in enumerate(infos["restart_on_exception"]): + if agent_roe and not dones[i]: + last_inserted_idx = (rb.buffer[i]._pos - 1) % rb.buffer[i].buffer_size + rb.buffer[i]["dones"][last_inserted_idx] = np.ones_like( + rb.buffer[i]["dones"][last_inserted_idx] + ) + rb.buffer[i]["is_first"][last_inserted_idx] = np.zeros_like( + rb.buffer[i]["is_first"][last_inserted_idx] + ) + step_data["is_first"][i] = np.ones_like(step_data["is_first"][i]) + + if cfg.metric.log_level > 0 and "final_info" in infos: + for i, agent_ep_info in enumerate(infos["final_info"]): + if agent_ep_info is not None: + ep_rew = agent_ep_info["episode"]["r"] + ep_len = agent_ep_info["episode"]["l"] + if aggregator and not aggregator.disabled: + aggregator.update("Rewards/rew_avg", ep_rew) + aggregator.update("Game/ep_len_avg", ep_len) + fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") + + # Save the real next observation + real_next_obs = copy.deepcopy(next_obs) + if "final_observation" in infos: + for idx, final_obs in enumerate(infos["final_observation"]): + if final_obs is not None: + for k, v in final_obs.items(): + real_next_obs[k][idx] = v - next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated).astype(np.uint8) - - step_data["is_first"] = np.zeros_like(step_data["dones"]) - if "restart_on_exception" in infos: - for i, agent_roe in enumerate(infos["restart_on_exception"]): - if agent_roe and not dones[i]: - last_inserted_idx = (rb.buffer[i]._pos - 1) % rb.buffer[i].buffer_size - rb.buffer[i]["dones"][last_inserted_idx] = np.ones_like(rb.buffer[i]["dones"][last_inserted_idx]) - rb.buffer[i]["is_first"][last_inserted_idx] = np.zeros_like( - rb.buffer[i]["is_first"][last_inserted_idx] - ) - step_data["is_first"][i] = np.ones_like(step_data["is_first"][i]) - - if cfg.metric.log_level > 0 and "final_info" in infos: - for i, agent_ep_info in enumerate(infos["final_info"]): - if agent_ep_info is not None: - ep_rew = agent_ep_info["episode"]["r"] - ep_len = agent_ep_info["episode"]["l"] - if aggregator and not aggregator.disabled: - aggregator.update("Rewards/rew_avg", ep_rew) - aggregator.update("Game/ep_len_avg", ep_len) - fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") - - # Save the real next observation - real_next_obs = copy.deepcopy(next_obs) - if "final_observation" in infos: - for idx, final_obs in enumerate(infos["final_observation"]): - if final_obs is not None: - for k, v in final_obs.items(): - real_next_obs[k][idx] = v - - for k in obs_keys: - step_data[k] = next_obs[k][np.newaxis] - - # next_obs becomes the new obs - obs = next_obs - - rewards = rewards.reshape((1, cfg.env.num_envs, -1)) - step_data["dones"] = dones.reshape((1, cfg.env.num_envs, -1)) - step_data["rewards"] = clip_rewards_fn(rewards) - - dones_idxes = dones.nonzero()[0].tolist() - reset_envs = len(dones_idxes) - if reset_envs > 0: - reset_data = 
{} for k in obs_keys: - reset_data[k] = (real_next_obs[k][dones_idxes])[np.newaxis] - reset_data["dones"] = np.ones((1, reset_envs, 1)) - reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim))) - reset_data["rewards"] = step_data["rewards"][:, dones_idxes] - reset_data["is_first"] = np.zeros_like(reset_data["dones"]) - rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args) - - # Reset already inserted step data - step_data["rewards"][:, dones_idxes] = np.zeros_like(reset_data["rewards"]) - step_data["dones"][:, dones_idxes] = np.zeros_like(step_data["dones"][:, dones_idxes]) - step_data["is_first"][:, dones_idxes] = np.ones_like(step_data["is_first"][:, dones_idxes]) - player.init_states(dones_idxes) + step_data[k] = next_obs[k][np.newaxis] + + # next_obs becomes the new obs + obs = next_obs + + rewards = rewards.reshape((1, cfg.env.num_envs, -1)) + step_data["dones"] = dones.reshape((1, cfg.env.num_envs, -1)) + step_data["rewards"] = clip_rewards_fn(rewards) + + dones_idxes = dones.nonzero()[0].tolist() + reset_envs = len(dones_idxes) + if reset_envs > 0: + reset_data = {} + for k in obs_keys: + reset_data[k] = (real_next_obs[k][dones_idxes])[np.newaxis] + reset_data["dones"] = np.ones((1, reset_envs, 1)) + reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim))) + reset_data["rewards"] = step_data["rewards"][:, dones_idxes] + reset_data["is_first"] = np.zeros_like(reset_data["dones"]) + rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args) + + # Reset already inserted step data + step_data["rewards"][:, dones_idxes] = np.zeros_like(reset_data["rewards"]) + step_data["dones"][:, dones_idxes] = np.zeros_like(step_data["dones"][:, dones_idxes]) + step_data["is_first"][:, dones_idxes] = np.ones_like(step_data["is_first"][:, dones_idxes]) + player.init_states(dones_idxes) updates_before_training -= 1 # Train the agent if update >= learning_starts and updates_before_training <= 0: if player.actor_type == "exploration": - player.actor = actor_task.module player.actor_type = "task" local_data = rb.sample_tensors( cfg.algo.per_rank_batch_size, @@ -487,7 +488,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): envs.close() # task test few-shot if fabric.is_global_zero and cfg.algo.run_test: - player.actor = actor_task.module player.actor_type = "task" test(player, fabric, cfg, log_dir, "few-shot") diff --git a/tests/test_algos/test_algos.py b/tests/test_algos/test_algos.py index c078ed2c..3d6fc985 100644 --- a/tests/test_algos/test_algos.py +++ b/tests/test_algos/test_algos.py @@ -218,6 +218,7 @@ def test_ppo_recurrent(standard_args, start_time): "algo.per_rank_batch_size=1", "algo.per_rank_sequence_length=2", "algo.update_epochs=2", + "fabric.precision=32", f"root_dir={root_dir}", f"run_name={run_name}", ] From ade81ce28f5b37e7cfcbd751a072653cc5705f4e Mon Sep 17 00:00:00 2001 From: belerico Date: Wed, 27 Mar 2024 13:58:42 +0100 Subject: [PATCH 23/26] Wrap target critics with a single-device fabric --- sheeprl/algos/droq/agent.py | 9 ++++++--- sheeprl/algos/sac/agent.py | 9 ++++++--- sheeprl/algos/sac_ae/agent.py | 7 ++++++- 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/sheeprl/algos/droq/agent.py b/sheeprl/algos/droq/agent.py index 1c89554f..959a56c3 100644 --- a/sheeprl/algos/droq/agent.py +++ b/sheeprl/algos/droq/agent.py @@ -11,6 +11,7 @@ from sheeprl.algos.sac.agent import SACActor from sheeprl.models.models import MLP +from sheeprl.utils.fabric import get_single_device_fabric 
LOG_STD_MAX = 2 LOG_STD_MIN = -5 @@ -242,8 +243,10 @@ def build_agent( agent.load_state_dict(agent_state) agent.actor = fabric.setup_module(agent.actor) agent.critics = [fabric.setup_module(critic) for critic in agent.critics] - agent.qfs_target = nn.ModuleList( - [_FabricModule(target, precision=fabric._precision) for target in agent._qfs_target] - ) + + # Wrap the target q-functions with a single-device fabric. This let the target q-functions + # to be on the same device as the agent and to run with the same precision + fabric_player = get_single_device_fabric(fabric) + agent.qfs_target = nn.ModuleList([fabric_player.setup_module(target) for target in agent.qfs_target]) return agent diff --git a/sheeprl/algos/sac/agent.py b/sheeprl/algos/sac/agent.py index bb2cecbe..327565e6 100644 --- a/sheeprl/algos/sac/agent.py +++ b/sheeprl/algos/sac/agent.py @@ -11,6 +11,7 @@ from torch import Tensor from sheeprl.models.models import MLP +from sheeprl.utils.fabric import get_single_device_fabric LOG_STD_MAX = 2 LOG_STD_MIN = -5 @@ -299,8 +300,10 @@ def build_agent( agent.load_state_dict(agent_state) agent.actor = fabric.setup_module(agent.actor) agent.critics = [fabric.setup_module(critic) for critic in agent.critics] - agent.qfs_target = nn.ModuleList( - [_FabricModule(target, precision=fabric._precision) for target in agent._qfs_target] - ) + + # Wrap the target q-functions with a single-device fabric. This let the target q-functions + # to be on the same device as the agent and to run with the same precision + fabric_player = get_single_device_fabric(fabric) + agent.qfs_target = nn.ModuleList([fabric_player.setup_module(target) for target in agent.qfs_target]) return agent diff --git a/sheeprl/algos/sac_ae/agent.py b/sheeprl/algos/sac_ae/agent.py index d51768af..f66115f1 100644 --- a/sheeprl/algos/sac_ae/agent.py +++ b/sheeprl/algos/sac_ae/agent.py @@ -13,6 +13,7 @@ from sheeprl.algos.sac_ae.utils import weight_init from sheeprl.models.models import CNN, MLP, DeCNN, MultiDecoder, MultiEncoder +from sheeprl.utils.fabric import get_single_device_fabric from sheeprl.utils.model import cnn_forward LOG_STD_MAX = 2 @@ -562,6 +563,10 @@ def build_agent( decoder = fabric.setup_module(decoder) agent.actor = fabric.setup_module(agent.actor) agent.critic = fabric.setup_module(agent.critic) - agent.critic_target = _FabricModule(agent.critic_target, precision=fabric._precision) + + # Wrap the target critic with a single-device fabric. 
This lets the target critic + # to be on the same device as the agent and to run with the same precision + fabric_player = get_single_device_fabric(fabric) + agent.critic_target = fabric_player.setup_module(agent.critic_target) return agent, encoder, decoder From 825ed67ea2afa5d60fe795e275b3ff02f1c1a9b6 Mon Sep 17 00:00:00 2001 From: belerico Date: Wed, 27 Mar 2024 17:20:02 +0100 Subject: [PATCH 24/26] import annotations from future --- sheeprl/algos/a2c/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sheeprl/algos/a2c/utils.py b/sheeprl/algos/a2c/utils.py index 430cdc77..23dd4bf6 100644 --- a/sheeprl/algos/a2c/utils.py +++ b/sheeprl/algos/a2c/utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Any, Dict import torch From 2c7286ff53f07ffe299908691e55fcd0ff8d33c7 Mon Sep 17 00:00:00 2001 From: belerico Date: Wed, 27 Mar 2024 18:33:40 +0100 Subject: [PATCH 25/26] Fix P2E actor task during test --- sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py | 1 + sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py | 2 ++ sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py | 1 + sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py | 2 ++ sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py | 1 + sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py | 2 ++ 6 files changed, 9 insertions(+) diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py index 3e0d7bb3..99086504 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py @@ -835,6 +835,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): envs.close() # task test zero-shot if fabric.is_global_zero and cfg.algo.run_test: + player.actor = actor_task player.actor_type = "task" test(player, fabric, cfg, log_dir, "zero-shot") diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py index 754fda8c..63947b7d 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py @@ -345,6 +345,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): # Train the agent if update >= learning_starts and updates_before_training <= 0: if player.actor_type == "exploration": + player.actor = actor_task player.actor_type = "task" with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): for i in range(cfg.algo.per_rank_gradient_steps): @@ -450,6 +451,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): envs.close() # task test few-shot if fabric.is_global_zero and cfg.algo.run_test: + player.actor = actor_task player.actor_type = "task" test(player, fabric, cfg, log_dir, "few-shot") diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py index d17c1ec1..11c13fe2 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py @@ -1000,6 +1000,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): envs.close() # task test zero-shot if fabric.is_global_zero and cfg.algo.run_test: + player.actor = actor_task player.actor_type = "task" test(player, fabric, cfg, log_dir, "zero-shot") diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py index 8bfa7d86..b0a18af7 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py @@ -368,6 +368,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): # Train the 
agent if update >= learning_starts and updates_before_training <= 0: if player.actor_type == "exploration": + player.actor = actor_task player.actor_type = "task" n_samples = ( cfg.algo.per_rank_pretrain_steps if update == learning_starts else cfg.algo.per_rank_gradient_steps @@ -483,6 +484,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): envs.close() # task test few-shot if fabric.is_global_zero and cfg.algo.run_test: + player.actor = actor_task player.actor_type = "task" test(player, fabric, cfg, log_dir, "few-shot") diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py index 93723c00..a33397dc 100644 --- a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py +++ b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py @@ -1082,6 +1082,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): envs.close() # task test zero-shot if fabric.is_global_zero and cfg.algo.run_test: + player.actor = actor_task player.actor_type = "task" test(player, fabric, cfg, log_dir, "zero-shot") diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py index 36e9691c..aae93178 100644 --- a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py +++ b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py @@ -370,6 +370,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): # Train the agent if update >= learning_starts and updates_before_training <= 0: if player.actor_type == "exploration": + player.actor = actor_task player.actor_type = "task" local_data = rb.sample_tensors( cfg.algo.per_rank_batch_size, @@ -488,6 +489,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): envs.close() # task test few-shot if fabric.is_global_zero and cfg.algo.run_test: + player.actor = actor_task player.actor_type = "task" test(player, fabric, cfg, log_dir, "few-shot") From 1b2dd2a6116f20684afd486425340defdbcbfd67 Mon Sep 17 00:00:00 2001 From: belerico Date: Wed, 27 Mar 2024 18:56:37 +0100 Subject: [PATCH 26/26] Wrap actor with single-device fabric --- sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py | 4 +++- sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py | 8 ++++++-- sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py | 4 +++- sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py | 8 ++++++-- sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py | 4 +++- sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py | 9 +++++++-- 6 files changed, 28 insertions(+), 9 deletions(-) diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py index 99086504..9ad04faa 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py @@ -24,6 +24,7 @@ from sheeprl.algos.p2e_dv1.agent import build_agent from sheeprl.data.buffers import EnvIndependentReplayBuffer, SequentialReplayBuffer from sheeprl.utils.env import make_env +from sheeprl.utils.fabric import get_single_device_fabric from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm @@ -835,8 +836,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): envs.close() # task test zero-shot if fabric.is_global_zero and cfg.algo.run_test: - player.actor = actor_task player.actor_type = "task" + fabric_player = get_single_device_fabric(fabric) + player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task)) test(player, fabric, cfg, log_dir, "zero-shot") if not 
cfg.model_manager.disabled and fabric.is_global_zero: diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py index 63947b7d..6da54a91 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py @@ -19,6 +19,7 @@ from sheeprl.algos.p2e_dv1.agent import build_agent from sheeprl.data.buffers import EnvIndependentReplayBuffer, SequentialReplayBuffer from sheeprl.utils.env import make_env +from sheeprl.utils.fabric import get_single_device_fabric from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm @@ -31,6 +32,9 @@ @register_algorithm() def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): + # Single-device fabric object + fabric_player = get_single_device_fabric(fabric) + device = fabric.device rank = fabric.global_rank world_size = fabric.world_size @@ -345,8 +349,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): # Train the agent if update >= learning_starts and updates_before_training <= 0: if player.actor_type == "exploration": - player.actor = actor_task player.actor_type = "task" + player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task)) with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): for i in range(cfg.algo.per_rank_gradient_steps): sample = rb.sample_tensors( @@ -451,8 +455,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): envs.close() # task test few-shot if fabric.is_global_zero and cfg.algo.run_test: - player.actor = actor_task player.actor_type = "task" + player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task)) test(player, fabric, cfg, log_dir, "few-shot") if not cfg.model_manager.disabled and fabric.is_global_zero: diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py index 11c13fe2..54fad041 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py @@ -24,6 +24,7 @@ from sheeprl.data.buffers import EnvIndependentReplayBuffer, EpisodeBuffer, SequentialReplayBuffer from sheeprl.utils.distribution import OneHotCategoricalValidateArgs from sheeprl.utils.env import make_env +from sheeprl.utils.fabric import get_single_device_fabric from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm @@ -1000,8 +1001,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): envs.close() # task test zero-shot if fabric.is_global_zero and cfg.algo.run_test: - player.actor = actor_task player.actor_type = "task" + fabric_player = get_single_device_fabric(fabric) + player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task)) test(player, fabric, cfg, log_dir, "zero-shot") if not cfg.model_manager.disabled and fabric.is_global_zero: diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py index b0a18af7..d00f8c69 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py @@ -19,6 +19,7 @@ from sheeprl.algos.p2e_dv2.agent import build_agent from sheeprl.data.buffers import EnvIndependentReplayBuffer, EpisodeBuffer, SequentialReplayBuffer from sheeprl.utils.env import make_env +from sheeprl.utils.fabric 
import get_single_device_fabric from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm @@ -31,6 +32,9 @@ @register_algorithm() def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): + # Single-device fabric object + fabric_player = get_single_device_fabric(fabric) + device = fabric.device rank = fabric.global_rank world_size = fabric.world_size @@ -368,8 +372,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): # Train the agent if update >= learning_starts and updates_before_training <= 0: if player.actor_type == "exploration": - player.actor = actor_task player.actor_type = "task" + player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task)) n_samples = ( cfg.algo.per_rank_pretrain_steps if update == learning_starts else cfg.algo.per_rank_gradient_steps ) @@ -484,8 +488,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): envs.close() # task test few-shot if fabric.is_global_zero and cfg.algo.run_test: - player.actor = actor_task player.actor_type = "task" + player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task)) test(player, fabric, cfg, log_dir, "few-shot") if not cfg.model_manager.disabled and fabric.is_global_zero: diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py index a33397dc..106cb655 100644 --- a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py +++ b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py @@ -28,6 +28,7 @@ TwoHotEncodingDistribution, ) from sheeprl.utils.env import make_env +from sheeprl.utils.fabric import get_single_device_fabric from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm @@ -1082,8 +1083,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): envs.close() # task test zero-shot if fabric.is_global_zero and cfg.algo.run_test: - player.actor = actor_task player.actor_type = "task" + fabric_player = get_single_device_fabric(fabric) + player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task)) test(player, fabric, cfg, log_dir, "zero-shot") if not cfg.model_manager.disabled and fabric.is_global_zero: diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py index aae93178..b5b482f8 100644 --- a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py +++ b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py @@ -17,6 +17,7 @@ from sheeprl.algos.p2e_dv3.agent import build_agent from sheeprl.data.buffers import EnvIndependentReplayBuffer, SequentialReplayBuffer from sheeprl.utils.env import make_env +from sheeprl.utils.fabric import get_single_device_fabric from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm @@ -26,6 +27,9 @@ @register_algorithm() def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): + # Single-device fabric object + fabric_player = get_single_device_fabric(fabric) + device = fabric.device rank = fabric.global_rank world_size = fabric.world_size @@ -370,8 +374,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): # Train the agent if update >= learning_starts and updates_before_training <= 0: if player.actor_type == "exploration": - 
player.actor = actor_task player.actor_type = "task" + player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task)) local_data = rb.sample_tensors( cfg.algo.per_rank_batch_size, sequence_length=cfg.algo.per_rank_sequence_length, @@ -487,10 +491,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): ) envs.close() + # task test few-shot if fabric.is_global_zero and cfg.algo.run_test: - player.actor = actor_task player.actor_type = "task" + player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task)) test(player, fabric, cfg, log_dir, "few-shot") if not cfg.model_manager.disabled and fabric.is_global_zero:
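
The common thread of the later patches in this series (23, 25, and 26) is that modules used only for inference (target critics and the player's task actor) are no longer wrapped by constructing a `_FabricModule` by hand, but are passed through a dedicated single-device Fabric so they sit on the same device and run under the same precision plugin as the trained modules without picking up any DDP machinery. The sketch below illustrates that pattern; this `get_single_device_fabric` is an approximation built only from the Lightning Fabric calls and the `SingleDeviceStrategy` arguments already used in the patches, not the repository's actual implementation, and the `__main__` usage is purely hypothetical.

import copy

import torch
import torch.nn as nn
from lightning.fabric import Fabric
from lightning.fabric.strategies.single_device import SingleDeviceStrategy


def get_single_device_fabric(fabric: Fabric) -> Fabric:
    """Sketch: a Fabric bound to a single device that reuses the accelerator and
    precision plugin of the (possibly distributed) training Fabric."""
    strategy = SingleDeviceStrategy(
        device=fabric.device,
        accelerator=fabric.accelerator,
        precision=fabric._precision,  # same precision plugin as the training Fabric (private attr, as used in the patches)
    )
    single_device_fabric = Fabric(strategy=strategy)
    single_device_fabric.launch()
    return single_device_fabric


# Hypothetical usage mirroring the diffs: the target network is a deepcopy of the
# trained critic and gets wrapped by the single-device Fabric, so its forward pass
# runs on the expected device and inside the expected autocast context.
if __name__ == "__main__":
    fabric = Fabric(accelerator="cpu", devices=1, precision="32-true")
    fabric.launch()
    critic = fabric.setup_module(nn.Linear(8, 1))
    target_critic = copy.deepcopy(critic.module)
    target_critic = get_single_device_fabric(fabric).setup_module(target_critic)
    with torch.inference_mode():
        print(target_critic(torch.randn(2, 8, device=fabric.device)).shape)

Because these modules never receive gradients, there is nothing to synchronize across ranks; a plain `setup_module` on a single-device Fabric is enough to get device placement and mixed-precision behaviour, which is why the patches replace both the manual `_FabricModule(...)` construction and the bare `player.actor = actor_task` assignments with `fabric_player.setup_module(...)`.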