Skip to content

Commit

Permalink
[Test] More comprehensive tests for auto_spec
Browse files Browse the repository at this point in the history
ghstack-source-id: 75352490436fd706af3d36f9b8016e80a8a3f46a
Pull Request resolved: #2640
  • Loading branch information
vmoens committed Dec 12, 2024
1 parent ef5a37d commit 6c7d233
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
10 changes: 7 additions & 3 deletions test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1931,14 +1931,18 @@ def __init__(self):
tensor=Unbounded(3),
non_tensor=NonTensor(shape=()),
)
self._saved_obs_spec = self.observation_spec.clone()
self.state_spec = Composite(
non_tensor=NonTensor(shape=()),
)
self._saved_state_spec = self.state_spec.clone()
self.reward_spec = Unbounded(1)
self._saved_full_reward_spec = self.full_reward_spec.clone()
self.action_spec = Unbounded(1)
self._saved_full_action_spec = self.full_action_spec.clone()

def _reset(self, tensordict):
data = self.observation_spec.zero()
data = self._saved_obs_spec.zero()
data.set_non_tensor("non_tensor", 0)
data.update(self.full_done_spec.zero())
return data
Expand All @@ -1947,10 +1951,10 @@ def _step(
self,
tensordict: TensorDictBase,
) -> TensorDictBase:
data = self.observation_spec.zero()
data = self._saved_obs_spec.zero()
data.set_non_tensor("non_tensor", tensordict["non_tensor"] + 1)
data.update(self.full_done_spec.zero())
data.update(self.full_reward_spec.zero())
data.update(self._saved_full_reward_spec.zero())
return data

def _set_seed(self, seed: Optional[int]):
Expand Down
11 changes: 8 additions & 3 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3553,8 +3553,13 @@ def test_single_env_spec():
assert env.input_spec.is_in(env.input_spec_unbatched.zeros(env.shape))


def test_auto_spec():
env = CountingEnv()
@pytest.mark.parametrize("env_type", [CountingEnv, EnvWithMetadata])
def test_auto_spec(env_type):
if env_type is EnvWithMetadata:
obs_vals = ["tensor", "non_tensor"]
else:
obs_vals = "observation"
env = env_type()
td = env.reset()

policy = lambda td, action_spec=env.full_action_spec.clone(): td.update(
Expand All @@ -3577,7 +3582,7 @@ def test_auto_spec():
shape=env.full_state_spec.shape, device=env.full_state_spec.device
)
env._action_keys = ["action"]
env.auto_specs_(policy, tensordict=td.copy())
env.auto_specs_(policy, tensordict=td.copy(), observation_key=obs_vals)
env.check_env_specs(tensordict=td.copy())


Expand Down

0 comments on commit 6c7d233

Please sign in to comment.