Skip to content

Commit

Permalink
Fix test to handle nested observation dicts (#1172)
Browse files Browse the repository at this point in the history
  • Loading branch information
dm-ackerman authored Mar 22, 2024
1 parent b19c02c commit 0cdf49e
Showing 1 changed file with 45 additions and 7 deletions.
52 changes: 45 additions & 7 deletions pettingzoo/test/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import re
import warnings
from collections import defaultdict
from typing import Any

import gymnasium
import numpy as np
Expand Down Expand Up @@ -383,6 +384,46 @@ def test_rewards_terminations_truncations(env, agent_0):
test_reward(env.rewards[agent])


def _test_observation_space_compatibility(
expected: gymnasium.spaces.Space[Any],
seen: gymnasium.spaces.Space[Any] | dict,
recursed_keys: list[str],
) -> None:
"""Ensure observation's dtypes are same as in observation_space.
This tests that the dtypes of the spaces are the same.
The function will recursively check observation dicts to ensure that
all components have the same dtype as declared in the observation space.
Args:
expected: Observation space that is expected.
seen: The observation actually seen.
recursed_keys: A list of all the dict keys that led to the current
observations. This enables a more helpful error message if
an assert fails. The initial call should have an empty list.
"""
if isinstance(expected, gymnasium.spaces.Dict):
for key in expected.keys():
if not recursed_keys and key != "observation":
# For the top level, we only care about the 'observation' key.
continue
# We know a dict is expected. Anything else is an error.
assert isinstance(
seen, dict
), f"observation at [{']['.join(recursed_keys)}] is {seen.dtype}, but expected dict."

# note: a previous test (expected.contains(seen)) ensures that
# the two dicts have the same keys.
_test_observation_space_compatibility(
expected[key], seen[key], recursed_keys + [key]
)
else:
# done recursing, now the actual space types should match
assert (
expected.dtype == seen.dtype
), f"dtype for observation at [{']['.join(recursed_keys)}] is {seen.dtype}, but observation space specifies {expected.dtype}."


def play_test(env, observation_0, num_cycles):
"""
plays through environment and does dynamic checks to make
Expand Down Expand Up @@ -466,13 +507,10 @@ def play_test(env, observation_0, num_cycles):
prev_observe
), "Out of bounds observation: " + str(prev_observe)

if isinstance(env.observation_space(agent), gymnasium.spaces.Box):
assert env.observation_space(agent).dtype == prev_observe.dtype
elif isinstance(env.observation_space(agent), gymnasium.spaces.Dict):
assert (
env.observation_space(agent)["observation"].dtype
== prev_observe["observation"].dtype
)
_test_observation_space_compatibility(
env.observation_space(agent), prev_observe, recursed_keys=[]
)

test_observation(prev_observe, observation_0, str(env.unwrapped))
if not isinstance(env.infos[env.agent_selection], dict):
warnings.warn(
Expand Down

0 comments on commit 0cdf49e

Please sign in to comment.