diff --git a/test/mocking_classes.py b/test/mocking_classes.py index b6f4ac7069b..6f666290376 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -1931,14 +1931,18 @@ def __init__(self): tensor=Unbounded(3), non_tensor=NonTensor(shape=()), ) + self._saved_obs_spec = self.observation_spec.clone() self.state_spec = Composite( non_tensor=NonTensor(shape=()), ) + self._saved_state_spec = self.state_spec.clone() self.reward_spec = Unbounded(1) + self._saved_full_reward_spec = self.full_reward_spec.clone() self.action_spec = Unbounded(1) + self._saved_full_action_spec = self.full_action_spec.clone() def _reset(self, tensordict): - data = self.observation_spec.zero() + data = self._saved_obs_spec.zero() data.set_non_tensor("non_tensor", 0) data.update(self.full_done_spec.zero()) return data @@ -1947,10 +1951,10 @@ def _step( self, tensordict: TensorDictBase, ) -> TensorDictBase: - data = self.observation_spec.zero() + data = self._saved_obs_spec.zero() data.set_non_tensor("non_tensor", tensordict["non_tensor"] + 1) data.update(self.full_done_spec.zero()) - data.update(self.full_reward_spec.zero()) + data.update(self._saved_full_reward_spec.zero()) return data def _set_seed(self, seed: Optional[int]): diff --git a/test/test_env.py b/test/test_env.py index cef7a507f2a..415c973b6fb 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -3553,8 +3553,13 @@ def test_single_env_spec(): assert env.input_spec.is_in(env.input_spec_unbatched.zeros(env.shape)) -def test_auto_spec(): - env = CountingEnv() +@pytest.mark.parametrize("env_type", [CountingEnv, EnvWithMetadata]) +def test_auto_spec(env_type): + if env_type is EnvWithMetadata: + obs_vals = ["tensor", "non_tensor"] + else: + obs_vals = "observation" + env = env_type() td = env.reset() policy = lambda td, action_spec=env.full_action_spec.clone(): td.update( @@ -3577,7 +3582,7 @@ def test_auto_spec(): shape=env.full_state_spec.shape, device=env.full_state_spec.device ) env._action_keys = ["action"] - env.auto_specs_(policy, tensordict=td.copy()) + env.auto_specs_(policy, tensordict=td.copy(), observation_key=obs_vals) env.check_env_specs(tensordict=td.copy())