Commit message:
* feat: added actions as obs wrapper
* fix: actions shape
* fix: action_stack key
* feat: added controls
* fix: multi-discrete action stack
* test: update
* feat: added mlp_keys to prepare obs of sac, droq and a2c
* feat: added from __future__ import annotations
* feat: added noop + test
* feat: update tests + added controls in wrapper + update docs
* fix: typo
1 parent 28d14ba, commit 878620a. Showing 12 changed files with 244 additions and 13 deletions.
@@ -0,0 +1,27 @@
# Actions as Observations Wrapper

This how-to explains how to use the Actions as Observations Wrapper.
When you want to add the last `n` actions to the observations, you must specify three parameters in the [`./configs/env/default.yaml`](../sheeprl/configs/env/default.yaml) file:
- `actions_as_observation.num_stack` (integer greater than 0): the number of actions to add to the observations.
- `actions_as_observation.dilation` (integer greater than 0): the dilation (number of steps) between one stacked action and the next.
- `actions_as_observation.noop` (integer, float, or list of integers): the noop action used when the environment is reset; the action buffer is filled with this action. Every environment has its own NOOP action, and it is strongly recommended to use that action so the algorithm can learn correctly.

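As an illustration, a configuration excerpt for a discrete-action environment such as Atari might look like the following (a sketch only; the surrounding structure of `default.yaml` may differ in your checkout):

```yaml
# ./configs/env/default.yaml (illustrative excerpt)
actions_as_observation:
  num_stack: 4   # add the last 4 actions to the observations
  dilation: 1    # consecutive actions, no gap between them
  noop: 0        # NOOP action of the environment (integer for a discrete space)
```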
## NOOP Parameter

The NOOP parameter must be:
- An integer for discrete action spaces.
- A float for continuous action spaces.
- A list of integers for multi-discrete action spaces: the length of the list must be equal to the number of actions of the environment.

Each environment has its own NOOP action, usually specified in its documentation. Below is the list of NOOP actions for the environments supported in SheepRL:
- MuJoCo (both gymnasium and DMC) environments: `0.0`.
- Atari environments: `0`.
- Crafter: `0`.
- MineRL: `0`.
- MineDojo: `[0, 0, 0]`.
- Super Mario Bros: `0`.
- Diambra:
  - Discrete: `0`.
  - Multi-discrete: `[0, 0]`.
- Box2D (gymnasium):
  - Discrete: `0`.
  - Continuous: `0.0`.
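The mechanics described above (a buffer filled with the NOOP action on reset, then one action kept every `dilation` steps) can be sketched as follows. `ActionStack` is a hypothetical helper written for illustration, not the SheepRL implementation:

```python
# Sketch of the action-stacking logic behind an "actions as observations"
# wrapper: expose the last `num_stack` actions, spaced `dilation` steps apart.
from collections import deque


class ActionStack:
    def __init__(self, num_stack: int, dilation: int, noop):
        if num_stack < 1:
            raise ValueError("The number of actions to stack must be greater than zero")
        if dilation < 1:
            raise ValueError("The actions stack dilation argument must be greater than zero")
        self._num_stack = num_stack
        self._dilation = dilation
        self._noop = noop
        # Keep enough history to pick one action every `dilation` steps
        self._buffer = deque(maxlen=num_stack * dilation)

    def reset(self):
        # On reset the whole buffer is filled with the NOOP action
        self._buffer.extend([self._noop] * self._buffer.maxlen)

    def push(self, action):
        self._buffer.append(action)

    def stack(self):
        # Most recent action first, then every `dilation`-th older action
        return [self._buffer[-(1 + i * self._dilation)] for i in range(self._num_stack)]


stack = ActionStack(num_stack=3, dilation=2, noop=0)
stack.reset()
print(stack.stack())  # all NOOP right after reset -> [0, 0, 0]
for a in [1, 2, 3, 4]:
    stack.push(a)
print(stack.stack())  # newest action, then actions 2 steps apart -> [4, 2, 0]
```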
@@ -1,10 +1,102 @@
```python
import gymnasium as gym
import numpy as np
import pytest

from sheeprl.envs.dummy import ContinuousDummyEnv, DiscreteDummyEnv, MultiDiscreteDummyEnv
from sheeprl.envs.wrappers import ActionsAsObservationWrapper, MaskVelocityWrapper

ENVIRONMENTS = {
    "discrete_dummy": DiscreteDummyEnv,
    "multidiscrete_dummy": MultiDiscreteDummyEnv,
    "continuous_dummy": ContinuousDummyEnv,
}


def test_mask_velocities_fail():
    with pytest.raises(NotImplementedError):
        env = gym.make("CarRacing-v2")
        env = MaskVelocityWrapper(env)


@pytest.mark.parametrize("num_stack", [1, 4, 8])
@pytest.mark.parametrize("dilation", [1, 2, 4])
@pytest.mark.parametrize("env_id", ["discrete_dummy", "multidiscrete_dummy", "continuous_dummy"])
def test_actions_as_observation_wrapper(env_id: str, num_stack, dilation):
    env = ENVIRONMENTS[env_id]()
    if isinstance(env.action_space, gym.spaces.MultiDiscrete):
        noop = [0, 0]
    else:
        noop = 0
    env = ActionsAsObservationWrapper(env, num_stack=num_stack, noop=noop, dilation=dilation)

    o = env.reset()[0]
    assert len(o["action_stack"].shape) == len(env.observation_space["action_stack"].shape)
    for d1, d2 in zip(o["action_stack"].shape, env.observation_space["action_stack"].shape):
        assert d1 == d2

    for _ in range(64):
        o = env.step(env.action_space.sample())[0]
        assert len(o["action_stack"].shape) == len(env.observation_space["action_stack"].shape)
        for d1, d2 in zip(o["action_stack"].shape, env.observation_space["action_stack"].shape):
            assert d1 == d2


@pytest.mark.parametrize("num_stack", [-1, 0])
@pytest.mark.parametrize("env_id", ["discrete_dummy", "multidiscrete_dummy", "continuous_dummy"])
def test_actions_as_observation_wrapper_invalid_num_stack(env_id, num_stack):
    env = ENVIRONMENTS[env_id]()
    if isinstance(env.action_space, gym.spaces.MultiDiscrete):
        noop = [0, 0]
    else:
        noop = 0
    with pytest.raises(ValueError, match="The number of actions to the"):
        env = ActionsAsObservationWrapper(env, num_stack=num_stack, noop=noop, dilation=3)


@pytest.mark.parametrize("dilation", [-1, 0])
@pytest.mark.parametrize("env_id", ["discrete_dummy", "multidiscrete_dummy", "continuous_dummy"])
def test_actions_as_observation_wrapper_invalid_dilation(env_id, dilation):
    env = ENVIRONMENTS[env_id]()
    if isinstance(env.action_space, gym.spaces.MultiDiscrete):
        noop = [0, 0]
    else:
        noop = 0
    with pytest.raises(ValueError, match="The actions stack dilation argument must be greater than zero"):
        env = ActionsAsObservationWrapper(env, num_stack=3, noop=noop, dilation=dilation)


@pytest.mark.parametrize("noop", [set([0, 0, 0]), "this is an invalid type", np.array([0, 0, 0])])
@pytest.mark.parametrize("env_id", ["discrete_dummy", "multidiscrete_dummy", "continuous_dummy"])
def test_actions_as_observation_wrapper_invalid_noop_type(env_id, noop):
    env = ENVIRONMENTS[env_id]()
    with pytest.raises(ValueError, match="The noop action must be an integer or float or list"):
        env = ActionsAsObservationWrapper(env, num_stack=3, noop=noop, dilation=2)


def test_actions_as_observation_wrapper_invalid_noop_continuous_type():
    env = ContinuousDummyEnv()
    with pytest.raises(ValueError, match="The noop actions must be a float for continuous action spaces"):
        env = ActionsAsObservationWrapper(env, num_stack=3, noop=[0, 0, 0], dilation=2)


@pytest.mark.parametrize("noop", [[0, 0, 0], 0.0])
def test_actions_as_observation_wrapper_invalid_noop_discrete_type(noop):
    env = DiscreteDummyEnv()
    with pytest.raises(ValueError, match="The noop actions must be an integer for discrete action spaces"):
        env = ActionsAsObservationWrapper(env, num_stack=3, noop=noop, dilation=2)


@pytest.mark.parametrize("noop", [0, 0.0])
def test_actions_as_observation_wrapper_invalid_noop_multidiscrete_type(noop):
    env = MultiDiscreteDummyEnv()
    with pytest.raises(ValueError, match="The noop actions must be a list for multi-discrete action spaces"):
        env = ActionsAsObservationWrapper(env, num_stack=3, noop=noop, dilation=2)


@pytest.mark.parametrize("noop", [[0], [0, 0, 0]])
def test_actions_as_observation_wrapper_invalid_noop_multidiscrete_n_actions(noop):
    env = MultiDiscreteDummyEnv()
    with pytest.raises(
        RuntimeError, match="The number of noop actions must be equal to the number of actions of the environment"
    ):
        env = ActionsAsObservationWrapper(env, num_stack=3, noop=noop, dilation=2)
```
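The NOOP validation rules these tests exercise can be summarized in a standalone sketch. `validate_noop` is a hypothetical helper written for illustration; it mirrors the error messages matched above, not the actual wrapper code:

```python
# Sketch of the NOOP validation rules checked by the tests above
# (hypothetical helper, for illustration only).
def validate_noop(noop, action_space_type: str, n_actions: int = 1):
    # Only integers, floats, and lists are acceptable noop types
    if not isinstance(noop, (int, float, list)):
        raise ValueError("The noop action must be an integer or float or list")
    if action_space_type == "continuous" and not isinstance(noop, float):
        raise ValueError("The noop actions must be a float for continuous action spaces")
    if action_space_type == "discrete" and not isinstance(noop, int):
        raise ValueError("The noop actions must be an integer for discrete action spaces")
    if action_space_type == "multidiscrete":
        if not isinstance(noop, list):
            raise ValueError("The noop actions must be a list for multi-discrete action spaces")
        if len(noop) != n_actions:
            raise RuntimeError(
                "The number of noop actions must be equal to the number of actions of the environment"
            )


validate_noop(0, "discrete")                     # ok: integer for discrete spaces
validate_noop(0.0, "continuous")                 # ok: float for continuous spaces
validate_noop([0, 0], "multidiscrete", n_actions=2)  # ok: list of the right length
```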