-
Notifications
You must be signed in to change notification settings - Fork 34
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Tests/Added tests for env wrappers (#296)
* Tests/Added tests for env wrappers * Tests/Deleted unnecessary file * Tests/Minor fixes to previous commit --------- Co-authored-by: Locatelli Alex Giannino <[email protected]> Co-authored-by: Michele Milesi <[email protected]>
- Loading branch information
1 parent
62b3da0
commit dee8c80
Showing
4 changed files
with
331 additions
and
83 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
from collections import deque | ||
|
||
import gymnasium as gym | ||
import numpy as np | ||
import pytest | ||
|
||
from sheeprl.envs.dummy import ContinuousDummyEnv, DiscreteDummyEnv, MultiDiscreteDummyEnv | ||
from sheeprl.envs.wrappers import ActionsAsObservationWrapper | ||
|
||
ENVIRONMENTS = { | ||
"discrete_dummy": DiscreteDummyEnv, | ||
"multidiscrete_dummy": MultiDiscreteDummyEnv, | ||
"continuous_dummy": ContinuousDummyEnv, | ||
} | ||
|
||
|
||
@pytest.mark.parametrize("num_stack", [1, 4, 8]) | ||
@pytest.mark.parametrize("dilation", [1, 2, 4]) | ||
@pytest.mark.parametrize("env_id", ENVIRONMENTS.keys()) | ||
def test_actions_as_observation_wrapper(env_id: str, num_stack, dilation): | ||
env = ENVIRONMENTS[env_id]() | ||
if isinstance(env.action_space, gym.spaces.MultiDiscrete): | ||
noop = [0, 0] | ||
else: | ||
noop = 0 | ||
env = ActionsAsObservationWrapper(env, num_stack=num_stack, noop=noop, dilation=dilation) | ||
|
||
o = env.reset()[0] | ||
assert len(o["action_stack"].shape) == len(env.observation_space["action_stack"].shape) | ||
for d1, d2 in zip(o["action_stack"].shape, env.observation_space["action_stack"].shape): | ||
assert d1 == d2 | ||
|
||
actions = [] | ||
for _ in range(8): | ||
action = env.action_space.sample() | ||
actions.append(action) | ||
o = env.step(action)[0] | ||
|
||
# Ensure the shapes match | ||
assert len(o["action_stack"].shape) == len(env.observation_space["action_stack"].shape) | ||
for d1, d2 in zip(o["action_stack"].shape, env.observation_space["action_stack"].shape): | ||
assert d1 == d2 | ||
|
||
expected_actions = deque(maxlen=num_stack * dilation) | ||
if len(actions) < num_stack * dilation: | ||
for _ in range(num_stack * dilation - len(actions)): | ||
expected_actions.append(env.noop) | ||
for past_action in actions[-(num_stack * dilation) :]: | ||
if isinstance(env.action_space, gym.spaces.Box): | ||
expected_actions.append(past_action) | ||
elif isinstance(env.action_space, gym.spaces.MultiDiscrete): | ||
one_hot_actions = [] | ||
for act, n in zip(past_action, env.action_space.nvec): | ||
one_hot_actions.append(np.zeros((n,), dtype=np.float32)) | ||
one_hot_actions[-1][act] = 1.0 | ||
expected_actions.append(np.concatenate(one_hot_actions, axis=-1)) | ||
else: | ||
one_hot_action = np.zeros((env.action_space.n,), dtype=np.float32) | ||
one_hot_action[past_action] = 1.0 | ||
expected_actions.append(one_hot_action) | ||
|
||
expected_actions_stack = list(expected_actions)[dilation - 1 :: dilation] | ||
expected_actions_stack = np.concatenate(expected_actions_stack, axis=-1).astype(np.float32) | ||
|
||
np.testing.assert_array_equal(o["action_stack"], expected_actions_stack) | ||
|
||
|
||
@pytest.mark.parametrize("num_stack", [-1, 0]) | ||
@pytest.mark.parametrize("env_id", ENVIRONMENTS.keys()) | ||
def test_actions_as_observation_wrapper_invalid_num_stack(env_id, num_stack): | ||
env = ENVIRONMENTS[env_id]() | ||
if isinstance(env.action_space, gym.spaces.MultiDiscrete): | ||
noop = [0, 0] | ||
else: | ||
noop = 0 | ||
with pytest.raises(ValueError, match="The number of actions to the"): | ||
env = ActionsAsObservationWrapper(env, num_stack=num_stack, noop=noop, dilation=3) | ||
|
||
|
||
@pytest.mark.parametrize("dilation", [-1, 0]) | ||
@pytest.mark.parametrize("env_id", ENVIRONMENTS.keys()) | ||
def test_actions_as_observation_wrapper_invalid_dilation(env_id, dilation): | ||
env = ENVIRONMENTS[env_id]() | ||
if isinstance(env.action_space, gym.spaces.MultiDiscrete): | ||
noop = [0, 0] | ||
else: | ||
noop = 0 | ||
with pytest.raises(ValueError, match="The actions stack dilation argument must be greater than zero"): | ||
env = ActionsAsObservationWrapper(env, num_stack=3, noop=noop, dilation=dilation) | ||
|
||
|
||
@pytest.mark.parametrize("noop", [set([0, 0, 0]), "this is an invalid type", np.array([0, 0, 0])]) | ||
@pytest.mark.parametrize("env_id", ENVIRONMENTS.keys()) | ||
def test_actions_as_observation_wrapper_invalid_noop_type(env_id, noop): | ||
env = ENVIRONMENTS[env_id]() | ||
with pytest.raises(ValueError, match="The noop action must be an integer or float or list"): | ||
env = ActionsAsObservationWrapper(env, num_stack=3, noop=noop, dilation=2) | ||
|
||
|
||
def test_actions_as_observation_wrapper_invalid_noop_continuous_type(): | ||
env = ContinuousDummyEnv() | ||
with pytest.raises(ValueError, match="The noop actions must be a float for continuous action spaces"): | ||
env = ActionsAsObservationWrapper(env, num_stack=3, noop=[0, 0, 0], dilation=2) | ||
|
||
|
||
@pytest.mark.parametrize("noop", [[0, 0, 0], 0.0]) | ||
def test_actions_as_observation_wrapper_invalid_noop_discrete_type(noop): | ||
env = DiscreteDummyEnv() | ||
with pytest.raises(ValueError, match="The noop actions must be an integer for discrete action spaces"): | ||
env = ActionsAsObservationWrapper(env, num_stack=3, noop=noop, dilation=2) | ||
|
||
|
||
@pytest.mark.parametrize("noop", [0, 0.0]) | ||
def test_actions_as_observation_wrapper_invalid_noop_multidiscrete_type(noop): | ||
env = MultiDiscreteDummyEnv() | ||
with pytest.raises(ValueError, match="The noop actions must be a list for multi-discrete action spaces"): | ||
env = ActionsAsObservationWrapper(env, num_stack=3, noop=noop, dilation=2) | ||
|
||
|
||
@pytest.mark.parametrize("noop", [[0], [0, 0, 0]]) | ||
def test_actions_as_observation_wrapper_invalid_noop_multidiscrete_n_actions(noop): | ||
env = MultiDiscreteDummyEnv() | ||
with pytest.raises( | ||
RuntimeError, match="The number of noop actions must be equal to the number of actions of the environment" | ||
): | ||
env = ActionsAsObservationWrapper(env, num_stack=3, noop=noop, dilation=2) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
import numpy as np | ||
import pytest | ||
|
||
from sheeprl.envs.dummy import ContinuousDummyEnv, DiscreteDummyEnv, MultiDiscreteDummyEnv | ||
from sheeprl.envs.wrappers import FrameStack | ||
|
||
ENVIRONMENTS = { | ||
"discrete_dummy": DiscreteDummyEnv, | ||
"multidiscrete_dummy": MultiDiscreteDummyEnv, | ||
"continuous_dummy": ContinuousDummyEnv, | ||
} | ||
|
||
|
||
@pytest.mark.parametrize("dilation", [1, 2, 4]) | ||
@pytest.mark.parametrize("num_stack", [1, 2, 3]) | ||
@pytest.mark.parametrize("env_id", ENVIRONMENTS.keys()) | ||
def test_valid_initialization(env_id, num_stack, dilation): | ||
env = ENVIRONMENTS[env_id]() | ||
|
||
env = FrameStack(env, num_stack=num_stack, cnn_keys=["rgb"], dilation=dilation) | ||
assert env._num_stack == num_stack | ||
assert env._dilation == dilation | ||
assert "rgb" in env._cnn_keys | ||
assert "rgb" in env._frames | ||
|
||
|
||
@pytest.mark.parametrize("num_stack", [-2.4, -1, 0]) | ||
@pytest.mark.parametrize("env_id", ENVIRONMENTS.keys()) | ||
def test_invalid_num_stack(env_id, num_stack): | ||
env = ENVIRONMENTS[env_id]() | ||
|
||
with pytest.raises(ValueError, match="Invalid value for num_stack, expected a value greater"): | ||
FrameStack(env, num_stack=num_stack, cnn_keys=["rgb"], dilation=2) | ||
|
||
|
||
@pytest.mark.parametrize("num_stack", [1, 3, 7]) | ||
@pytest.mark.parametrize("env_id", ENVIRONMENTS.keys()) | ||
def test_invalid_observation_space(env_id, num_stack): | ||
env = ENVIRONMENTS[env_id](dict_obs_space=False) | ||
|
||
with pytest.raises(RuntimeError, match="Expected an observation space of type gym.spaces.Dict"): | ||
FrameStack(env, num_stack=num_stack, cnn_keys=["rgb"], dilation=2) | ||
|
||
|
||
@pytest.mark.parametrize("cnn_keys", [[], None]) | ||
@pytest.mark.parametrize("num_stack", [1, 3, 7]) | ||
@pytest.mark.parametrize("env_id", ENVIRONMENTS.keys()) | ||
def test_invalid_cnn_keys(env_id, num_stack, cnn_keys): | ||
env = ENVIRONMENTS[env_id]() | ||
|
||
with pytest.raises(RuntimeError, match="Specify at least one valid cnn key"): | ||
FrameStack(env, num_stack=num_stack, cnn_keys=cnn_keys, dilation=2) | ||
|
||
|
||
@pytest.mark.parametrize("env_id", ENVIRONMENTS.keys()) | ||
@pytest.mark.parametrize("num_stack", [1, 3, 7]) | ||
def test_reset_method(env_id, num_stack): | ||
env = ENVIRONMENTS[env_id]() | ||
|
||
wrapper = FrameStack(env, num_stack=num_stack, cnn_keys=["rgb"]) | ||
obs, _ = wrapper.reset() | ||
|
||
assert "rgb" in obs | ||
assert obs["rgb"].shape == (num_stack, *env.observation_space["rgb"].shape) | ||
|
||
|
||
@pytest.mark.parametrize("num_stack", [1, 2, 5]) | ||
@pytest.mark.parametrize("dilation", [1, 2, 3]) | ||
def test_framestack(num_stack, dilation): | ||
env = DiscreteDummyEnv() | ||
env = FrameStack(env, num_stack, cnn_keys=["rgb"], dilation=dilation) | ||
|
||
# Reset the environment to initialize the frame stack | ||
obs, _ = env.reset() | ||
|
||
for step in range(1, 64): | ||
obs = env.step(None)[0] | ||
|
||
expected_frame = np.stack( | ||
[ | ||
np.full( | ||
env.env.observation_space["rgb"].shape, | ||
max(0, (step - dilation * (num_stack - i - 1))) % 256, | ||
dtype=np.uint8, | ||
) | ||
for i in range(num_stack) | ||
], | ||
axis=0, | ||
) | ||
np.testing.assert_array_equal(obs["rgb"], expected_frame) | ||
|
||
|
||
@pytest.mark.parametrize("env_id", ENVIRONMENTS.keys()) | ||
@pytest.mark.parametrize("num_stack", [1, 3, 7]) | ||
def test_step_method(env_id, num_stack): | ||
env = ENVIRONMENTS[env_id]() | ||
wrapper = FrameStack(env, num_stack=num_stack, cnn_keys=["rgb"]) | ||
wrapper.reset() | ||
action = wrapper.action_space.sample() | ||
obs = wrapper.step(action)[0] | ||
assert "rgb" in obs | ||
assert obs["rgb"].shape == (num_stack, *env.observation_space["rgb"].shape) |
Oops, something went wrong.