Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/actions as obs #291

Merged
merged 14 commits into from
May 27, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions sheeprl/algos/a2c/a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
# Sample an action given the observation received by the environment
# This calls the `forward` method of the PyTorch module, escaping from Fabric
# because we don't want this to be a synchronization point
torch_obs = prepare_obs(fabric, next_obs, num_envs=cfg.env.num_envs)
torch_obs = prepare_obs(
fabric, next_obs, mlp_keys=cfg.algo.mlp_keys.encoder, num_envs=cfg.env.num_envs
)
actions, _, values = player(torch_obs)
if is_continuous:
real_actions = torch.stack(actions, -1).cpu().numpy()
Expand Down Expand Up @@ -304,7 +306,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

# Estimate returns with GAE (https://arxiv.org/abs/1506.02438)
with torch.inference_mode():
torch_obs = prepare_obs(fabric, next_obs, num_envs=cfg.env.num_envs)
torch_obs = prepare_obs(fabric, next_obs, mlp_keys=cfg.algo.mlp_keys.encoder, num_envs=cfg.env.num_envs)
next_values = player.get_values(torch_obs)
returns, advantages = gae(
local_data["rewards"].to(torch.float64),
Expand Down
10 changes: 6 additions & 4 deletions sheeprl/algos/a2c/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, Dict
from typing import Any, Dict, Sequence

import numpy as np
import torch
Expand All @@ -13,8 +13,10 @@
AGGREGATOR_KEYS = {"Rewards/rew_avg", "Game/ep_len_avg", "Loss/value_loss", "Loss/policy_loss"}


def prepare_obs(fabric: Fabric, obs: Dict[str, np.ndarray], *, num_envs: int = 1, **kwargs) -> Dict[str, Tensor]:
torch_obs = {k: torch.from_numpy(v.copy()).to(fabric.device).float().reshape(num_envs, -1) for k, v in obs.items()}
def prepare_obs(
fabric: Fabric, obs: Dict[str, np.ndarray], *, mlp_keys: Sequence[str] = [], num_envs: int = 1, **kwargs
) -> Dict[str, Tensor]:
torch_obs = {k: torch.from_numpy(obs[k].copy()).to(fabric.device).float().reshape(num_envs, -1) for k in mlp_keys}
return torch_obs


Expand All @@ -28,7 +30,7 @@ def test(agent: PPOPlayer, fabric: Fabric, cfg: Dict[str, Any], log_dir: str):

while not done:
# Convert observations to tensors
torch_obs = prepare_obs(fabric, obs)
torch_obs = prepare_obs(fabric, obs, mlp_keys=cfg.algo.mlp_keys.encoder)

# Act greedly through the environment
actions = agent.get_actions(torch_obs, greedy=True)
Expand Down
2 changes: 1 addition & 1 deletion sheeprl/algos/droq/droq.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
else:
with torch.inference_mode():
# Sample an action given the observation received by the environment
torch_obs = prepare_obs(fabric, obs, num_envs=cfg.env.num_envs)
torch_obs = prepare_obs(fabric, obs, mlp_keys=cfg.algo.mlp_keys.encoder, num_envs=cfg.env.num_envs)
actions = player(torch_obs)
actions = actions.cpu().numpy()
next_obs, rewards, terminated, truncated, infos = envs.step(actions.reshape(envs.action_space.shape))
Expand Down
2 changes: 1 addition & 1 deletion sheeprl/algos/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
else:
# Sample an action given the observation received by the environment
with torch.inference_mode():
torch_obs = prepare_obs(fabric, obs, num_envs=cfg.env.num_envs)
torch_obs = prepare_obs(fabric, obs, mlp_keys=cfg.algo.mlp_keys.encoder, num_envs=cfg.env.num_envs)
actions = player(torch_obs)
actions = actions.cpu().numpy()
next_obs, rewards, terminated, truncated, infos = envs.step(actions.reshape(envs.action_space.shape))
Expand Down
2 changes: 1 addition & 1 deletion sheeprl/algos/sac/sac_decoupled.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def player(
actions = envs.action_space.sample()
else:
# Sample an action given the observation received by the environment
torch_obs = prepare_obs(fabric, obs, num_envs=cfg.env.num_envs)
torch_obs = prepare_obs(fabric, obs, mlp_keys=cfg.algo.mlp_keys.encoder, num_envs=cfg.env.num_envs)
actions = actor(torch_obs)
actions = actions.cpu().numpy()
next_obs, rewards, terminated, truncated, infos = envs.step(actions.reshape(envs.action_space.shape))
Expand Down
8 changes: 5 additions & 3 deletions sheeprl/algos/sac/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@
MODELS_TO_REGISTER = {"agent"}


def prepare_obs(fabric: Fabric, obs: Dict[str, np.ndarray], *, num_envs: int = 1, **kwargs) -> Tensor:
def prepare_obs(
fabric: Fabric, obs: Dict[str, np.ndarray], *, mlp_keys: Sequence[str] = [], num_envs: int = 1, **kwargs
) -> Tensor:
with fabric.device:
torch_obs = torch.cat([torch.as_tensor(obs[k].copy(), dtype=torch.float32) for k in obs.keys()], dim=-1)
torch_obs = torch.cat([torch.as_tensor(obs[k].copy(), dtype=torch.float32) for k in mlp_keys], dim=-1)
return torch_obs.reshape(num_envs, -1)


Expand All @@ -43,7 +45,7 @@ def test(actor: SACPlayer, fabric: Fabric, cfg: Dict[str, Any], log_dir: str):
obs = env.reset(seed=cfg.seed)[0]
while not done:
# Act greedly through the environment
torch_obs = prepare_obs(fabric, obs)
torch_obs = prepare_obs(fabric, obs, mlp_keys=cfg.algo.mlp_keys.encoder)
action = actor.get_actions(torch_obs, greedy=True)

# Single environment step
Expand Down
2 changes: 2 additions & 0 deletions sheeprl/configs/env/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ grayscale: False
clip_rewards: False
capture_video: True
frame_stack_dilation: 1
action_stack: -1
action_stack_dilation: 1
max_episode_steps: null
reward_as_observation: False
wrapper: ???
64 changes: 64 additions & 0 deletions sheeprl/envs/wrappers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import copy
import time
from collections import deque
Expand Down Expand Up @@ -251,3 +253,65 @@ def render(self) -> Optional[Union[RenderFrame, List[RenderFrame]]]:
if len(frame.shape) == 3 and frame.shape[-1] == 1:
frame = frame.repeat(3, axis=-1)
return frame


class ActionsAsObservationWrapper(gym.Wrapper):
def __init__(self, env: Env, num_stack: int, dilation: int = 1):
super().__init__(env)
if num_stack < 1:
raise ValueError(
"The number of actions to the `action_stack` observation "
f"must be greater or equal than 1, got: {num_stack}"
)
if dilation < 1:
raise ValueError(f"The actions stack dilation argument must be greater than zero, got: {dilation}")
self._num_stack = num_stack
self._dilation = dilation
self._actions = deque(maxlen=num_stack * dilation)
self._is_continuous = isinstance(self.env.action_space, gym.spaces.Box)
self._is_multidiscrete = isinstance(self.env.action_space, gym.spaces.MultiDiscrete)
self.observation_space = copy.deepcopy(self.env.observation_space)
if self._is_continuous:
self._action_shape = self.env.action_space.shape[0]
low = np.resize(self.env.action_space.low, self._action_shape * num_stack)
high = np.resize(self.env.action_space.high, self._action_shape * num_stack)
elif self._is_multidiscrete:
low = 0
high = max(self.env.action_space.nvec) - 1
self._action_shape = self.env.action_space.nvec.shape[0]
else:
low = 0
high = 1 # one-hot encoding
self._action_shape = self.env.action_space.n
self.observation_space["action_stack"] = gym.spaces.Box(
low=low, high=high, shape=(self._action_shape * num_stack,), dtype=np.float32
)

def step(self, action: Any) -> Tuple[Any | SupportsFloat | bool | Dict[str, Any]]:
self._actions.append(action)
obs, reward, done, truncated, info = super().step(action)
obs["action_stack"] = self._get_actions_stack()
return obs, reward, done, truncated, info

def reset(self, *, seed: int | None = None, options: Dict[str, Any] | None = None) -> Tuple[Any | Dict[str, Any]]:
obs, info = super().reset(seed=seed, options=options)
self._actions.clear()
if self._is_multidiscrete or self._is_continuous:
michele-milesi marked this conversation as resolved.
Show resolved Hide resolved
[self._actions.append(np.zeros((self._action_shape,))) for _ in range(self._num_stack * self._dilation)]
else:
[self._actions.append(0) for _ in range(self._num_stack * self._dilation)]
obs["action_stack"] = self._get_actions_stack()
return obs, info

def _get_actions_stack(self) -> np.ndarray:
actions_stack = list(self._actions)[self._dilation - 1 :: self._dilation]
if self._is_continuous or self._is_multidiscrete:
actions = np.concatenate(actions_stack, axis=0)
else:
action_list = []
for action in actions_stack:
one_hot_action = np.zeros(self.env.action_space.n)
one_hot_action[action] = 1
action_list.append(one_hot_action)
actions = np.concatenate(action_list, axis=0)
return actions.astype(np.float32)
4 changes: 4 additions & 0 deletions sheeprl/utils/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from sheeprl.envs.wrappers import (
ActionRepeat,
ActionsAsObservationWrapper,
FrameStack,
GrayscaleRenderWrapper,
MaskVelocityWrapper,
Expand Down Expand Up @@ -207,6 +208,9 @@ def transform_obs(obs: Dict[str, Any]):
)
env = FrameStack(env, cfg.env.frame_stack, cnn_keys, cfg.env.frame_stack_dilation)

if cfg.env.action_stack > 0 and "diambra" not in cfg.env.wrapper._target_:
env = ActionsAsObservationWrapper(env, cfg.env.action_stack, cfg.env.action_stack_dilation)
michele-milesi marked this conversation as resolved.
Show resolved Hide resolved

if cfg.env.reward_as_observation:
env = RewardAsObservationWrapper(env)

Expand Down
Loading