feat: added prepare obs to all the algorithms
michele-milesi committed Apr 16, 2024
1 parent 31a8e9b commit c9608d4
Showing 7 changed files with 97 additions and 114 deletions.
22 changes: 10 additions & 12 deletions sheeprl/algos/a2c/utils.py
@@ -2,30 +2,34 @@

from typing import Any, Dict

import numpy as np
import torch
from lightning import Fabric
from torch import Tensor

from sheeprl.algos.ppo.agent import PPOPlayer
from sheeprl.utils.env import make_env

AGGREGATOR_KEYS = {"Rewards/rew_avg", "Game/ep_len_avg", "Loss/value_loss", "Loss/policy_loss"}


def prepare_obs(fabric: Fabric, obs: Dict[str, np.ndarray], *args, **kwargs) -> Dict[str, Tensor]:
torch_obs = {k: torch.from_numpy(v[np.newaxis]).to(fabric.device).float() for k, v in obs.items()}
return torch_obs


@torch.no_grad()
def test(agent: PPOPlayer, fabric: Fabric, cfg: Dict[str, Any], log_dir: str):
env = make_env(cfg, None, 0, log_dir, "test", vector_env_idx=0)()
agent.eval()
done = False
cumulative_rew = 0
o = env.reset(seed=cfg.seed)[0]
obs = {}
for k in o.keys():
if k in cfg.algo.mlp_keys.encoder:
torch_obs = torch.from_numpy(o[k]).to(fabric.device).unsqueeze(0)
torch_obs = torch_obs.float()
obs[k] = torch_obs

while not done:
# Convert observations to tensors
obs = prepare_obs(fabric, o)

# Act greedly through the environment
actions = agent.get_actions(obs, greedy=True)
if agent.actor.is_continuous:
@@ -37,12 +41,6 @@ def test(agent: PPOPlayer, fabric: Fabric, cfg: Dict[str, Any], log_dir: str):
o, reward, done, truncated, _ = env.step(actions.cpu().numpy().reshape(env.action_space.shape))
done = done or truncated
cumulative_rew += reward
obs = {}
for k in o.keys():
if k in cfg.algo.mlp_keys.encoder:
torch_obs = torch.from_numpy(o[k]).to(fabric.device).unsqueeze(0)
torch_obs = torch_obs.float()
obs[k] = torch_obs

if cfg.dry_run:
done = True
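
A minimal usage sketch of the new A2C helper, assuming the sheeprl package is importable and that the environment returns a dict of NumPy arrays; the "state" key and its shape are invented for illustration:

import numpy as np
from lightning import Fabric

from sheeprl.algos.a2c.utils import prepare_obs

fabric = Fabric(accelerator="cpu", devices=1)
fabric.launch()

# Fake single-environment observation with one vector key (illustrative).
obs = {"state": np.zeros(8, dtype=np.float32)}

torch_obs = prepare_obs(fabric, obs)
print(torch_obs["state"].shape, torch_obs["state"].dtype)  # torch.Size([1, 8]) torch.float32

The helper prepends a batch dimension and casts to float32; note that, unlike the removed inline loop, it converts every key in obs rather than only those in cfg.algo.mlp_keys.encoder.
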
31 changes: 17 additions & 14 deletions sheeprl/algos/dreamer_v2/utils.py
@@ -3,6 +3,7 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
from lightning import Fabric
@@ -101,6 +102,18 @@ def compute_lambda_values(
return torch.cat(list(reversed(lv)), dim=0)


def prepare_obs(fabric: Fabric, obs: Dict[str, np.ndarray], cnn_keys: Sequence[str] = []) -> Dict[str, Tensor]:
torch_obs = {}
for k, v in obs.items():
torch_obs[k] = torch.from_numpy(v.copy()).to(fabric.device).view(1, *v.shape).float()
if k in cnn_keys:
torch_obs[k] = torch_obs[k][None, ...] / 255 - 0.5
else:
torch_obs[k] = torch_obs[k][None, ...]

return torch_obs


@torch.no_grad()
def test(
player: "PlayerDV2" | "PlayerDV1",
@@ -125,32 +138,22 @@ def test(
env: gym.Env = make_env(cfg, cfg.seed, 0, log_dir, "test" + (f"_{test_name}" if test_name != "" else ""))()
done = False
cumulative_rew = 0
device = fabric.device
next_obs = env.reset(seed=cfg.seed)[0]
for k in next_obs.keys():
next_obs[k] = torch.from_numpy(next_obs[k]).view(1, *next_obs[k].shape).float()
o = env.reset(seed=cfg.seed)[0]
player.num_envs = 1
player.init_states()
while not done:
# Act greedly through the environment
preprocessed_obs = {}
for k, v in next_obs.items():
if k in cfg.algo.cnn_keys.encoder:
preprocessed_obs[k] = v[None, ...].to(device) / 255 - 0.5
elif k in cfg.algo.mlp_keys.encoder:
preprocessed_obs[k] = v[None, ...].to(device)
torch_obs = prepare_obs(fabric, o, cfg.algo.cnn_keys.encoder)
real_actions = player.get_actions(
preprocessed_obs, greedy, {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")}
torch_obs, greedy, {k: v for k, v in torch_obs.items() if k.startswith("mask")}
)
if player.actor.is_continuous:
real_actions = torch.cat(real_actions, -1).cpu().numpy()
else:
real_actions = torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()

# Single environment step
next_obs, reward, done, truncated, _ = env.step(real_actions.reshape(env.action_space.shape))
for k in next_obs.keys():
next_obs[k] = torch.from_numpy(next_obs[k]).view(1, *next_obs[k].shape).float()
o, reward, done, truncated, _ = env.step(real_actions.reshape(env.action_space.shape))
done = done or truncated or cfg.dry_run
cumulative_rew += reward
fabric.print("Test - Reward:", cumulative_rew)
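
A quick shape/scale check for the new Dreamer-V2 helper, which prepends two singleton leading dimensions and rescales image keys from [0, 255] to [-0.5, 0.5]; the key names and shapes below are invented for the sketch:

import numpy as np
from lightning import Fabric

from sheeprl.algos.dreamer_v2.utils import prepare_obs

fabric = Fabric(accelerator="cpu", devices=1)
fabric.launch()

obs = {
    "rgb": np.random.randint(0, 256, size=(3, 64, 64), dtype=np.uint8),  # image key
    "state": np.zeros(4, dtype=np.float32),  # vector key
}
torch_obs = prepare_obs(fabric, obs, cnn_keys=["rgb"])

print(torch_obs["rgb"].shape)    # torch.Size([1, 1, 3, 64, 64]), values in [-0.5, 0.5]
print(torch_obs["state"].shape)  # torch.Size([1, 1, 4])
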
30 changes: 16 additions & 14 deletions sheeprl/algos/dreamer_v3/utils.py
@@ -77,6 +77,18 @@ def compute_lambda_values(
return ret


def prepare_obs(fabric: Fabric, obs: Dict[str, np.ndarray], cnn_keys: Sequence[str] = []) -> Dict[str, Tensor]:
torch_obs = {}
for k, v in obs.items():
torch_obs[k] = torch.from_numpy(v.copy()).to(fabric.device).view(1, *v.shape).float()
if k in cnn_keys:
torch_obs[k] = torch_obs[k][None, ...] / 255 - 0.5
else:
torch_obs[k] = torch_obs[k][None, ...]

return torch_obs


@torch.no_grad()
def test(
player: "PlayerDV3",
@@ -101,32 +113,22 @@ def test(
env: gym.Env = make_env(cfg, cfg.seed, 0, log_dir, "test" + (f"_{test_name}" if test_name != "" else ""))()
done = False
cumulative_rew = 0
device = fabric.device
next_obs = env.reset(seed=cfg.seed)[0]
for k in next_obs.keys():
next_obs[k] = torch.from_numpy(next_obs[k]).view(1, *next_obs[k].shape).float()
o = env.reset(seed=cfg.seed)[0]
player.num_envs = 1
player.init_states()
while not done:
# Act greedly through the environment
preprocessed_obs = {}
for k, v in next_obs.items():
if k in cfg.algo.cnn_keys.encoder:
preprocessed_obs[k] = v[None, ...].to(device) / 255 - 0.5
elif k in cfg.algo.mlp_keys.encoder:
preprocessed_obs[k] = v[None, ...].to(device)
torch_obs = prepare_obs(fabric, o, cfg.algo.cnn_keys.encoder)
real_actions = player.get_actions(
preprocessed_obs, greedy, {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")}
torch_obs, greedy, {k: v for k, v in torch_obs.items() if k.startswith("mask")}
)
if player.actor.is_continuous:
real_actions = torch.cat(real_actions, -1).cpu().numpy()
else:
real_actions = torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()

# Single environment step
next_obs, reward, done, truncated, _ = env.step(real_actions.reshape(env.action_space.shape))
for k in next_obs.keys():
next_obs[k] = torch.from_numpy(next_obs[k]).view(1, *next_obs[k].shape).float()
o, reward, done, truncated, _ = env.step(real_actions.reshape(env.action_space.shape))
done = done or truncated or cfg.dry_run
cumulative_rew += reward
fabric.print("Test - Reward:", cumulative_rew)
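
The Dreamer-V3 helper is identical to the Dreamer-V2 one, so the same shapes apply; the sketch below instead checks the mask handling used by the updated test loop, which forwards only the keys starting with "mask" as the action-mask argument of player.get_actions (the "mask_action_type" key is invented for illustration):

import numpy as np
from lightning import Fabric

from sheeprl.algos.dreamer_v3.utils import prepare_obs

fabric = Fabric(accelerator="cpu", devices=1)
fabric.launch()

obs = {
    "rgb": np.zeros((3, 64, 64), dtype=np.uint8),
    "mask_action_type": np.ones(5, dtype=np.float32),
}
torch_obs = prepare_obs(fabric, obs, cnn_keys=["rgb"])

# Same filtering as the test loop above: only "mask*" entries are passed as masks.
mask = {k: v for k, v in torch_obs.items() if k.startswith("mask")}
print(list(mask.keys()))  # ['mask_action_type']
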
31 changes: 12 additions & 19 deletions sheeprl/algos/ppo/utils.py
@@ -22,26 +22,28 @@
MODELS_TO_REGISTER = {"agent"}


def prepare_obs(fabric: Fabric, obs: Dict[str, np.ndarray], cnn_keys: Sequence[str]) -> Dict[str, Tensor]:
torch_obs = {}
for k in obs.keys():
torch_obs[k] = torch.from_numpy(obs[k].copy()).to(fabric.device).unsqueeze(0).float()
if k in cnn_keys:
torch_obs[k] = torch_obs[k].reshape(1, -1, *torch_obs[k].shape[-2:]) / 255 - 0.5
return torch_obs


@torch.no_grad()
def test(agent: PPOPlayer, fabric: Fabric, cfg: Dict[str, Any], log_dir: str):
env = make_env(cfg, None, 0, log_dir, "test", vector_env_idx=0)()
agent.eval()
done = False
cumulative_rew = 0
o = env.reset(seed=cfg.seed)[0]
obs = {}
for k in o.keys():
if k in cfg.algo.mlp_keys.encoder + cfg.algo.cnn_keys.encoder:
torch_obs = torch.from_numpy(o[k]).to(fabric.device).unsqueeze(0)
if k in cfg.algo.cnn_keys.encoder:
torch_obs = torch_obs.reshape(1, -1, *torch_obs.shape[-2:]) / 255 - 0.5
if k in cfg.algo.mlp_keys.encoder:
torch_obs = torch_obs.float()
obs[k] = torch_obs

while not done:
torch_obs = prepare_obs(fabric, o, cfg.algo.cnn_keys.encoder)

# Act greedly through the environment
actions = agent.get_actions(obs, greedy=True)
actions = agent.get_actions(torch_obs, greedy=True)
if agent.actor.is_continuous:
actions = torch.cat(actions, dim=-1)
else:
@@ -51,15 +53,6 @@ def test(agent: PPOPlayer, fabric: Fabric, cfg: Dict[str, Any], log_dir: str):
o, reward, done, truncated, _ = env.step(actions.cpu().numpy().reshape(env.action_space.shape))
done = done or truncated
cumulative_rew += reward
obs = {}
for k in o.keys():
if k in cfg.algo.mlp_keys.encoder + cfg.algo.cnn_keys.encoder:
torch_obs = torch.from_numpy(o[k]).to(fabric.device).unsqueeze(0)
if k in cfg.algo.cnn_keys.encoder:
torch_obs = torch_obs.reshape(1, -1, *torch_obs.shape[-2:]) / 255 - 0.5
if k in cfg.algo.mlp_keys.encoder:
torch_obs = torch_obs.float()
obs[k] = torch_obs

if cfg.dry_run:
done = True
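
One behavior of the new PPO helper worth noting: for image keys it reshapes to (1, -1, H, W), so any leading frame-stack dimension is folded into the channel axis. A small sketch with invented keys and shapes:

import numpy as np
from lightning import Fabric

from sheeprl.algos.ppo.utils import prepare_obs

fabric = Fabric(accelerator="cpu", devices=1)
fabric.launch()

obs = {
    "rgb": np.zeros((4, 3, 64, 64), dtype=np.uint8),  # e.g. 4 stacked RGB frames
    "state": np.zeros(6, dtype=np.float32),
}
torch_obs = prepare_obs(fabric, obs, cnn_keys=["rgb"])

print(torch_obs["rgb"].shape)    # torch.Size([1, 12, 64, 64]), pixels mapped from [0, 255] to [-0.5, 0.5]
print(torch_obs["state"].shape)  # torch.Size([1, 6])
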
37 changes: 17 additions & 20 deletions sheeprl/algos/ppo_recurrent/utils.py
@@ -3,8 +3,10 @@
from typing import TYPE_CHECKING, Any, Dict, Sequence

import gymnasium as gym
import numpy as np
import torch
from lightning import Fabric
from torch import Tensor

from sheeprl.algos.ppo.utils import AGGREGATOR_KEYS as ppo_aggregator_keys
from sheeprl.algos.ppo.utils import MODELS_TO_REGISTER as ppo_models_to_register
@@ -21,33 +23,36 @@
MODELS_TO_REGISTER = ppo_models_to_register


def prepare_obs(fabric: Fabric, obs: Dict[str, np.ndarray], cnn_keys: Sequence[str]) -> Dict[str, Tensor]:
torch_obs = {}
with fabric.device:
for k, v in obs.items():
torch_obs[k] = torch.as_tensor(v.copy(), dtype=torch.float32, device=fabric.device)
if k in cnn_keys:
torch_obs[k] = torch_obs[k].view(1, 1, -1, *v.shape[-2:]) / 255 - 0.5
else:
torch_obs[k] = torch_obs[k].view(1, 1, -1)
return torch_obs


@torch.no_grad()
def test(agent: "RecurrentPPOPlayer", fabric: Fabric, cfg: Dict[str, Any], log_dir: str):
env = make_env(cfg, None, 0, log_dir, "test", vector_env_idx=0)()
agent.eval()
done = False
cumulative_rew = 0
agent.num_envs = 1
o = env.reset(seed=cfg.seed)[0]
with fabric.device:
o = env.reset(seed=cfg.seed)[0]
next_obs = {
k: torch.as_tensor(o[k], dtype=torch.float32, device=fabric.device).view(1, 1, -1, *o[k].shape[-2:]) / 255
for k in cfg.algo.cnn_keys.encoder
}
next_obs.update(
{
k: torch.as_tensor(o[k], dtype=torch.float32, device=fabric.device).view(1, 1, -1)
for k in cfg.algo.mlp_keys.encoder
}
)
state = (
torch.zeros(1, 1, agent.rnn_hidden_size, device=fabric.device),
torch.zeros(1, 1, agent.rnn_hidden_size, device=fabric.device),
)
actions = torch.zeros(1, 1, sum(agent.actions_dim), device=fabric.device)
while not done:
torch_obs = prepare_obs(fabric, o, cfg.algo.cnn_keys.encoder)
# Act greedly through the environment
actions, state = agent.get_actions(next_obs, actions, state, greedy=True)
actions, state = agent.get_actions(torch_obs, actions, state, greedy=True)
if agent.actor.is_continuous:
real_actions = torch.cat(actions, -1)
actions = torch.cat(actions, dim=-1).view(1, 1, -1)
@@ -59,14 +64,6 @@ def test(agent: "RecurrentPPOPlayer", fabric: Fabric, cfg: Dict[str, Any], log_dir: str):
o, reward, done, truncated, info = env.step(real_actions.cpu().numpy().reshape(env.action_space.shape))
done = done or truncated
cumulative_rew += reward
with fabric.device:
next_obs = {
k: torch.as_tensor(o[k], dtype=torch.float32).view(1, 1, -1, *o[k].shape[-2:]) / 255
for k in cfg.algo.cnn_keys.encoder
}
next_obs.update(
{k: torch.as_tensor(o[k], dtype=torch.float32).view(1, 1, -1) for k in cfg.algo.mlp_keys.encoder}
)

if cfg.dry_run:
done = True
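
The recurrent-PPO helper adds two leading singleton dimensions and flattens every non-image key to a single feature axis via view(1, 1, -1). A sketch with invented keys, where a (2, 3) vector observation ends up as (1, 1, 6):

import numpy as np
from lightning import Fabric

from sheeprl.algos.ppo_recurrent.utils import prepare_obs

fabric = Fabric(accelerator="cpu", devices=1)
fabric.launch()

obs = {
    "rgb": np.zeros((3, 64, 64), dtype=np.uint8),
    "state": np.zeros((2, 3), dtype=np.float32),
}
torch_obs = prepare_obs(fabric, obs, cnn_keys=["rgb"])

print(torch_obs["rgb"].shape)    # torch.Size([1, 1, 3, 64, 64])
print(torch_obs["state"].shape)  # torch.Size([1, 1, 6]) -- flattened by view(1, 1, -1)
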
25 changes: 12 additions & 13 deletions sheeprl/algos/sac/utils.py
@@ -4,9 +4,11 @@
from typing import TYPE_CHECKING, Any, Dict, Sequence

import gymnasium as gym
import numpy as np
import torch
from lightning import Fabric
from lightning.fabric.wrappers import _FabricModule
from torch import Tensor

from sheeprl.algos.sac.agent import SACPlayer, build_agent
from sheeprl.utils.env import make_env
@@ -26,31 +28,28 @@
MODELS_TO_REGISTER = {"agent"}


def prepare_obs(fabric: Fabric, obs: Dict[str, np.ndarray], *args, **kwargs) -> Tensor:
with fabric.device:
torch_obs = torch.cat([torch.as_tensor(obs[k].copy(), dtype=torch.float32) for k in obs.keys()], dim=-1)
return torch_obs.unsqueeze(0)


@torch.no_grad()
def test(actor: SACPlayer, fabric: Fabric, cfg: Dict[str, Any], log_dir: str):
env = make_env(cfg, None, 0, log_dir, "test", vector_env_idx=0)()
actor.eval()
done = False
cumulative_rew = 0
with fabric.device:
o = env.reset(seed=cfg.seed)[0]
next_obs = torch.cat(
[torch.as_tensor(o[k], dtype=torch.float32) for k in cfg.algo.mlp_keys.encoder], dim=-1
).unsqueeze(
0
) # [N_envs, N_obs]
o = env.reset(seed=cfg.seed)[0]
while not done:
# Act greedly through the environment
action = actor.get_actions(next_obs, greedy=True)
torch_obs = prepare_obs(fabric, o)
action = actor.get_actions(torch_obs, greedy=True)

# Single environment step
next_obs, reward, done, truncated, info = env.step(action.cpu().numpy().reshape(env.action_space.shape))
o, reward, done, truncated, info = env.step(action.cpu().numpy().reshape(env.action_space.shape))
done = done or truncated
cumulative_rew += reward
with fabric.device:
next_obs = torch.cat(
[torch.as_tensor(next_obs[k], dtype=torch.float32) for k in cfg.algo.mlp_keys.encoder], dim=-1
)

if cfg.dry_run:
done = True
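
Unlike the other helpers, the SAC one returns a single Tensor rather than a dict: it concatenates the values of every key in obs along the last axis and prepends a batch dimension. A sketch with invented keys:

import numpy as np
from lightning import Fabric

from sheeprl.algos.sac.utils import prepare_obs

fabric = Fabric(accelerator="cpu", devices=1)
fabric.launch()

obs = {
    "position": np.zeros(3, dtype=np.float32),
    "velocity": np.zeros(3, dtype=np.float32),
}
torch_obs = prepare_obs(fabric, obs)

print(torch_obs.shape)  # torch.Size([1, 6]) -- a single concatenated Tensor

Note that, as written, the helper concatenates every key in obs, whereas the inline code it replaces selected only the keys in cfg.algo.mlp_keys.encoder.
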
