Fix/player precision plugin #244

Merged
merged 27 commits on Mar 27, 2024

Commits (27)
0b9a265
Fix PPO player with precision
belerico Mar 16, 2024
8e5376f
Fix PPO agent to run with the correct precision plugin
belerico Mar 19, 2024
ad6e000
Add detach_actor + fix creation of testing agent
belerico Mar 19, 2024
b9df6a9
Fix SAC to use the correct precision plugin
belerico Mar 19, 2024
4712fcb
Add get_single_device_fabric method
belerico Mar 21, 2024
db2f021
Fix DroQ agent to handle fp16
belerico Mar 22, 2024
2caf03b
Fix SACAE agent to handle fp16
belerico Mar 22, 2024
8191e8e
Fix PPO recurrent to handle fp16
belerico Mar 22, 2024
68f6805
Fix Join context
belerico Mar 22, 2024
39513db
Fix PPO evaluate
belerico Mar 22, 2024
491bc56
Fix Dreamer-V1 player to handle fp16
belerico Mar 22, 2024
3294fbd
Dreamer-V1 inference mode
belerico Mar 22, 2024
d3f4703
Fix Dreamer-V2 to handle fp16
belerico Mar 22, 2024
18710b3
Fix Dreamer-V3 to handle fp16
belerico Mar 22, 2024
baba51f
Extract module inside the player
belerico Mar 22, 2024
747d9fc
Fix P2E-DV1 to handle fp16
belerico Mar 22, 2024
a3e4471
Fix p2e-dv2 to handle fp16
belerico Mar 22, 2024
e64bac1
Fix p2e-dv3 to handle fp16
belerico Mar 22, 2024
61d2881
Fix A2C to handle fp16
belerico Mar 22, 2024
7dc8cc4
Add `greedy` flag to actor forward method
belerico Mar 27, 2024
33c9d84
Fix PPORecurrent arg position
belerico Mar 27, 2024
9f4f26e
Let P2E algos handle fp16
belerico Mar 27, 2024
ade81ce
Wrap target critics with a single-device fabric
belerico Mar 27, 2024
fa8c41e
Merge branch 'main' of https://github.com/Eclectic-Sheep/sheeprl into…
belerico Mar 27, 2024
825ed67
import annotations from future
belerico Mar 27, 2024
2c7286f
Fix P2E actor task during test
belerico Mar 27, 2024
1b2dd2a
Wrap actor with single-device fabric
belerico Mar 27, 2024
94 changes: 47 additions & 47 deletions sheeprl/algos/a2c/a2c.py
@@ -163,7 +163,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
# Given that the environment has been created with the `make_env` method, the agent
# forward method must accept as input a dictionary like {"obs1_name": obs1, "obs2_name": obs2, ...}.
# The agent should be able to process both image and vector-like observations.
agent = build_agent(
agent, player = build_agent(
fabric,
actions_dim,
is_continuous,
@@ -224,68 +224,68 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
step_data[k] = next_obs[k][np.newaxis]

for update in range(1, num_updates + 1):
for _ in range(0, cfg.algo.rollout_steps):
policy_step += cfg.env.num_envs * world_size
with torch.inference_mode():
for _ in range(0, cfg.algo.rollout_steps):
policy_step += cfg.env.num_envs * world_size

# Measure environment interaction time: this considers both the model forward
# to get the action given the observation and the time taken into the environment
with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False):
with torch.no_grad():
# Measure environment interaction time: this considers both the model forward
# to get the action given the observation and the time taken into the environment
with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False):
# Sample an action given the observation received by the environment
# This calls the `forward` method of the PyTorch module, escaping from Fabric
# because we don't want this to be a synchronization point
torch_obs = {k: torch.as_tensor(next_obs[k], dtype=torch.float32, device=device) for k in obs_keys}
actions, _, values = agent.module(torch_obs)
actions, _, values = player(torch_obs)
if is_continuous:
real_actions = torch.cat(actions, -1).cpu().numpy()
else:
real_actions = torch.cat([act.argmax(dim=-1) for act in actions], axis=-1).cpu().numpy()
actions = torch.cat(actions, -1).cpu().numpy()

# Single environment step
obs, rewards, done, truncated, info = envs.step(real_actions.reshape(envs.action_space.shape))

dones = np.logical_or(done, truncated)
dones = dones.reshape(cfg.env.num_envs, -1)
rewards = rewards.reshape(cfg.env.num_envs, -1)

# Update the step data
step_data["dones"] = dones[np.newaxis]
step_data["values"] = values.cpu().numpy()[np.newaxis]
step_data["actions"] = actions[np.newaxis]
step_data["rewards"] = rewards[np.newaxis]
if cfg.buffer.memmap:
step_data["returns"] = np.zeros_like(rewards, shape=(1, *rewards.shape))
step_data["advantages"] = np.zeros_like(rewards, shape=(1, *rewards.shape))

# Append data to buffer
rb.add(step_data, validate_args=cfg.buffer.validate_args)

# Update the observation and dones
next_obs = {}
for k in obs_keys:
_obs = obs[k]
step_data[k] = _obs[np.newaxis]
next_obs[k] = _obs

if cfg.metric.log_level > 0 and "final_info" in info:
for i, agent_ep_info in enumerate(info["final_info"]):
if agent_ep_info is not None:
ep_rew = agent_ep_info["episode"]["r"]
ep_len = agent_ep_info["episode"]["l"]
if aggregator and "Rewards/rew_avg" in aggregator:
aggregator.update("Rewards/rew_avg", ep_rew)
if aggregator and "Game/ep_len_avg" in aggregator:
aggregator.update("Game/ep_len_avg", ep_len)
fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}")
# Single environment step
obs, rewards, done, truncated, info = envs.step(real_actions.reshape(envs.action_space.shape))

dones = np.logical_or(done, truncated)
dones = dones.reshape(cfg.env.num_envs, -1)
rewards = rewards.reshape(cfg.env.num_envs, -1)

# Update the step data
step_data["dones"] = dones[np.newaxis]
step_data["values"] = values.cpu().numpy()[np.newaxis]
step_data["actions"] = actions[np.newaxis]
step_data["rewards"] = rewards[np.newaxis]
if cfg.buffer.memmap:
step_data["returns"] = np.zeros_like(rewards, shape=(1, *rewards.shape))
step_data["advantages"] = np.zeros_like(rewards, shape=(1, *rewards.shape))

# Append data to buffer
rb.add(step_data, validate_args=cfg.buffer.validate_args)

# Update the observation and dones
next_obs = {}
for k in obs_keys:
_obs = obs[k]
step_data[k] = _obs[np.newaxis]
next_obs[k] = _obs

if cfg.metric.log_level > 0 and "final_info" in info:
for i, agent_ep_info in enumerate(info["final_info"]):
if agent_ep_info is not None:
ep_rew = agent_ep_info["episode"]["r"]
ep_len = agent_ep_info["episode"]["l"]
if aggregator and "Rewards/rew_avg" in aggregator:
aggregator.update("Rewards/rew_avg", ep_rew)
if aggregator and "Game/ep_len_avg" in aggregator:
aggregator.update("Game/ep_len_avg", ep_len)
fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}")

# Transform the data into PyTorch Tensors
local_data = rb.to_tensor(dtype=None, device=device, from_numpy=cfg.buffer.from_numpy)

# Estimate returns with GAE (https://arxiv.org/abs/1506.02438)
with torch.no_grad():
with torch.inference_mode():
torch_obs = {k: torch.as_tensor(next_obs[k], dtype=torch.float32, device=device) for k in obs_keys}
next_values = agent.module.get_value(torch_obs)
_, _, next_values = player(torch_obs)
returns, advantages = gae(
local_data["rewards"].to(torch.float64),
local_data["values"],
@@ -351,7 +351,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

envs.close()
if fabric.is_global_zero:
test(agent.module, fabric, cfg, log_dir)
test(player, fabric, cfg, log_dir)

if not cfg.model_manager.disabled and fabric.is_global_zero:
from sheeprl.algos.ppo.utils import log_models
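For context on the hunks above: the rollout loop now runs under `torch.inference_mode()` instead of a per-step `torch.no_grad()`, and environment interaction goes through the single-device `player` returned by `build_agent` rather than `agent.module`. A minimal sketch of the `no_grad` vs `inference_mode` difference, using a plain `nn.Linear` rather than SheepRL code:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
x = torch.randn(1, 4)

with torch.no_grad():
    y1 = model(x)  # gradient tracking is off, but y1 is still an ordinary tensor

with torch.inference_mode():
    y2 = model(x)  # stricter mode: y2 is an inference tensor, can never re-enter
                   # an autograd graph, and skips extra autograd bookkeeping
```

Because the rollout tensors are converted to NumPy before being written to the buffer, the stricter inference mode is safe in this loop.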
51 changes: 23 additions & 28 deletions sheeprl/algos/a2c/agent.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import copy
from typing import Any, Dict, List, Optional, Sequence, Tuple

import gymnasium
@@ -10,8 +11,10 @@
from torch import Tensor
from torch.distributions import Distribution, Independent, Normal

from sheeprl.algos.ppo.agent import PPOActor
from sheeprl.models.models import MLP
from sheeprl.utils.distribution import OneHotCategoricalValidateArgs
from sheeprl.utils.fabric import get_single_device_fabric


class MLPEncoder(nn.Module):
@@ -94,7 +97,7 @@ def __init__(
)

# Actor
self.actor_backbone = MLP(
actor_backbone = MLP(
input_dims=features_dim,
output_dim=None,
hidden_sizes=[actor_cfg.dense_units] * actor_cfg.mlp_layers,
@@ -109,19 +112,17 @@
)
if is_continuous:
# Output is a tuple of two elements: mean and log_std, one for every action
self.actor_heads = nn.ModuleList([nn.Linear(actor_cfg.dense_units, sum(actions_dim) * 2)])
actor_heads = nn.ModuleList([nn.Linear(actor_cfg.dense_units, sum(actions_dim) * 2)])
else:
# Output is a tuple of one element: logits, one for every action
self.actor_heads = nn.ModuleList(
[nn.Linear(actor_cfg.dense_units, action_dim) for action_dim in actions_dim]
)
actor_heads = nn.ModuleList([nn.Linear(actor_cfg.dense_units, action_dim) for action_dim in actions_dim])
self.actor = PPOActor(actor_backbone, actor_heads, is_continuous=is_continuous)

def forward(
self, obs: Dict[str, Tensor], actions: Optional[List[Tensor]] = None
self, obs: Dict[str, Tensor], actions: Optional[List[Tensor]] = None, greedy: bool = False
) -> Tuple[Sequence[Tensor], Tensor, Tensor]:
feat = self.feature_extractor(obs)
out: Tensor = self.actor_backbone(feat)
pre_dist: List[Tensor] = [head(out) for head in self.actor_heads]
pre_dist: List[Tensor] = self.actor(feat)
values = self.critic(feat)
if self.is_continuous:
mean, log_std = torch.chunk(pre_dist[0], chunks=2, dim=-1)
@@ -132,7 +133,7 @@ def forward(
validate_args=self.distribution_cfg.validate_args,
)
if actions is None:
actions = normal.sample()
actions = normal.mode if greedy else normal.sample()
else:
# always composed by a tuple of one element containing all the
# continuous actions
@@ -151,30 +152,14 @@ def forward(
OneHotCategoricalValidateArgs(logits=logits, validate_args=self.distribution_cfg.validate_args)
)
if should_append:
actions.append(actions_dist[-1].sample())
actions.append(actions_dist[-1].mode if greedy else actions_dist[-1].sample())
actions_logprobs.append(actions_dist[-1].log_prob(actions[i]))
return tuple(actions), torch.stack(actions_logprobs, dim=-1).sum(dim=-1, keepdim=True), values

def get_value(self, obs: Dict[str, Tensor]) -> Tensor:
feat = self.feature_extractor(obs)
return self.critic(feat)

def get_greedy_actions(self, obs: Dict[str, Tensor]) -> Sequence[Tensor]:
feat = self.feature_extractor(obs)
out = self.actor_backbone(feat)
pre_dist: List[Tensor] = [head(out) for head in self.actor_heads]
if self.is_continuous:
# Just take the mean of the distribution
return [torch.chunk(pre_dist[0], 2, -1)[0]]
else:
# Take the mode of the distribution
return tuple(
[
OneHotCategoricalValidateArgs(logits=logits, validate_args=self.distribution_cfg.validate_args).mode
for logits in pre_dist
]
)


def build_agent(
fabric: Fabric,
@@ -183,7 +168,7 @@ def build_agent(
cfg: Dict[str, Any],
obs_space: gymnasium.spaces.Dict,
agent_state: Optional[Dict[str, Tensor]] = None,
) -> _FabricModule:
) -> Tuple[_FabricModule, _FabricModule]:
agent = A2CAgent(
actions_dim=actions_dim,
obs_space=obs_space,
@@ -196,6 +181,16 @@
)
if agent_state:
agent.load_state_dict(agent_state)
player = copy.deepcopy(agent)

# Setup training agent
agent = fabric.setup_module(agent)

return agent
# Setup player agent
fabric_player = get_single_device_fabric(fabric)
player = fabric_player.setup_module(player)

# Tie weights between the agent and the player
for agent_p, player_p in zip(agent.parameters(), player.parameters()):
player_p.data = agent_p.data
return agent, player
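The reworked `build_agent` above returns two views of the same network: the Fabric-wrapped training `agent` and a single-device `player` whose parameters share storage with it, so every optimizer step on the agent is immediately visible to the player. A minimal sketch of that weight-tying idea with plain PyTorch modules (names are illustrative, not SheepRL API):

```python
import copy

import torch
import torch.nn as nn

train_model = nn.Linear(8, 2)
player_model = copy.deepcopy(train_model)

# Re-point the player's parameters at the training parameters' storage,
# mirroring the `player_p.data = agent_p.data` loop in build_agent above.
for train_p, player_p in zip(train_model.parameters(), player_model.parameters()):
    player_p.data = train_p.data

opt = torch.optim.SGD(train_model.parameters(), lr=0.1)
loss = train_model(torch.randn(1, 8)).sum()
loss.backward()
opt.step()  # the in-place update is seen by player_model too, no explicit copy needed
assert torch.equal(train_model.weight, player_model.weight)
```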
2 changes: 1 addition & 1 deletion sheeprl/algos/a2c/evaluate.py
@@ -54,5 +54,5 @@ def evaluate_a2c(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]):
else (env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n])
)
# Create the actor and critic models
agent = build_agent(fabric, actions_dim, is_continuous, cfg, observation_space, state["agent"])
agent, _ = build_agent(fabric, actions_dim, is_continuous, cfg, observation_space, state["agent"])
test(agent, fabric, cfg, log_dir)
10 changes: 7 additions & 3 deletions sheeprl/algos/a2c/utils.py
@@ -1,7 +1,10 @@
from __future__ import annotations

from typing import Any, Dict

import torch
from lightning import Fabric
from lightning.fabric.wrappers import _FabricModule

from sheeprl.algos.a2c.agent import A2CAgent
from sheeprl.utils.env import make_env
@@ -10,7 +13,7 @@


@torch.no_grad()
def test(agent: A2CAgent, fabric: Fabric, cfg: Dict[str, Any], log_dir: str):
def test(agent: A2CAgent | _FabricModule, fabric: Fabric, cfg: Dict[str, Any], log_dir: str):
env = make_env(cfg, None, 0, log_dir, "test", vector_env_idx=0)()
agent.eval()
done = False
@@ -25,10 +28,11 @@ def test(agent: A2CAgent, fabric: Fabric, cfg: Dict[str, Any], log_dir: str):

while not done:
# Act greedly through the environment
actions, _, _ = agent(obs, greedy=True)
if agent.is_continuous:
actions = torch.cat(agent.get_greedy_actions(obs), dim=-1)
actions = torch.cat(actions, dim=-1)
else:
actions = torch.cat([act.argmax(dim=-1) for act in agent.get_greedy_actions(obs)], dim=-1)
actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1)

# Single environment step
o, reward, done, truncated, _ = env.step(actions.cpu().numpy().reshape(env.action_space.shape))
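The test helper above no longer calls a dedicated `get_greedy_actions` method; the agent's `forward` now takes a `greedy` flag and returns the distribution mode instead of a sample. A small sketch of what mode versus sample means for a categorical head (assuming a recent PyTorch where `Distribution.mode` is available; this is not SheepRL code):

```python
import torch
from torch.distributions import Categorical

logits = torch.tensor([[2.0, 0.5, -1.0]])
dist = Categorical(logits=logits)

greedy = True
action = dist.mode if greedy else dist.sample()
print(action)  # tensor([0]): the mode is the argmax of the logits, i.e. deterministic
```

In the diff the distribution is `OneHotCategoricalValidateArgs`, so its mode is a one-hot vector and the test loop recovers the index with `act.argmax(dim=-1)`.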
42 changes: 26 additions & 16 deletions sheeprl/algos/dreamer_v1/agent.py
@@ -17,6 +17,7 @@
from sheeprl.algos.dreamer_v2.agent import MinedojoActor as DV2MinedojoActor
from sheeprl.algos.dreamer_v2.agent import MLPDecoder, MLPEncoder
from sheeprl.models.models import MLP, MultiDecoder, MultiEncoder
from sheeprl.utils.fabric import get_single_device_fabric
from sheeprl.utils.utils import init_weights

# In order to use the hydra.utils.get_class method, in this way the user can
@@ -221,45 +222,54 @@ class PlayerDV1(nn.Module):
"""The model of the DreamerV1 player.

Args:
encoder (nn.Module): the encoder.
recurrent_model (nn.Module): the recurrent model.
representation_model (nn.Module): the representation model.
actor (nn.Module): the actor.
fabric (Fabric): the fabric object.
encoder (nn.Module| _FabricModule): the encoder.
recurrent_model (nn.Module| _FabricModule): the recurrent model.
representation_model (nn.Module| _FabricModule): the representation model.
actor (nn.Module| _FabricModule): the actor.
actions_dim (Sequence[int]): the dimension of each action.
num_envs (int): the number of environments.
stochastic_size (int): the size of the stochastic state.
recurrent_state_size (int): the size of the recurrent state.
device (torch.device): the device to work on.
actor_type (str, optional): which actor the player is using ('task' or 'exploration').
Default to None.
"""

def __init__(
self,
encoder: nn.Module,
recurrent_model: nn.Module,
representation_model: nn.Module,
actor: nn.Module,
fabric: Fabric,
encoder: nn.Module | _FabricModule,
recurrent_model: nn.Module | _FabricModule,
representation_model: nn.Module | _FabricModule,
actor: nn.Module | _FabricModule,
actions_dim: Sequence[int],
num_envs: int,
stochastic_size: int,
recurrent_state_size: int,
device: torch.device,
actor_type: str | None = None,
) -> None:
super().__init__()
self.encoder = encoder
self.recurrent_model = recurrent_model
self.representation_model = representation_model
self.actor = actor
self.device = device
single_device_fabric = get_single_device_fabric(fabric)
self.encoder = single_device_fabric.setup_module(
getattr(encoder, "module", encoder),
)
self.recurrent_model = single_device_fabric.setup_module(
getattr(recurrent_model, "module", recurrent_model),
)
self.representation_model = single_device_fabric.setup_module(
getattr(representation_model, "module", representation_model)
)
self.actor = single_device_fabric.setup_module(
getattr(actor, "module", actor),
)
self.device = single_device_fabric.device
self.actions_dim = actions_dim
self.stochastic_size = stochastic_size
self.recurrent_state_size = recurrent_state_size
self.num_envs = num_envs
self.validate_args = self.actor.distribution_cfg.validate_args
self.init_states()
self.actor_type = actor_type
self.init_states()

def init_states(self, reset_envs: Optional[Sequence[int]] = None) -> None:
"""Initialize the states and the actions for the ended environments.
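The `PlayerDV1` constructor above accepts either raw `nn.Module`s or Fabric-wrapped `_FabricModule`s: it unwraps them via `getattr(x, "module", x)` and re-wraps them with a single-device Fabric, so the player keeps the training precision plugin without the distributed strategy. A condensed sketch of that pattern (the `wrap_for_player` helper is hypothetical; `get_single_device_fabric` is the utility this PR adds in `sheeprl.utils.fabric`):

```python
from __future__ import annotations

from lightning import Fabric
from lightning.fabric.wrappers import _FabricModule
from torch import nn

from sheeprl.utils.fabric import get_single_device_fabric  # helper added in this PR


def wrap_for_player(module: nn.Module | _FabricModule, fabric: Fabric) -> _FabricModule:
    """Hypothetical helper mirroring what the PlayerDV1 constructor does per sub-module."""
    # A training module may arrive already Fabric-wrapped; grab the raw nn.Module
    # if it is there, otherwise use the module as given.
    raw = getattr(module, "module", module)
    # Re-wrap on a single device so inference in the environment loop keeps the
    # configured precision but never touches the distributed strategy.
    return get_single_device_fabric(fabric).setup_module(raw)
```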