diff --git a/test/mocking_classes.py b/test/mocking_classes.py
index d78e2f27184..bb902f879b1 100644
--- a/test/mocking_classes.py
+++ b/test/mocking_classes.py
@@ -1927,3 +1927,72 @@ def _step(
     def _set_seed(self, seed: Optional[int]):
         self.manual_seed = seed
         return seed
+
+
+class EnvWithScalarAction(EnvBase):
+    def __init__(self, singleton: bool = False, **kwargs):
+        super().__init__(**kwargs)
+        self.singleton = singleton
+        self.action_spec = Bounded(
+            -1,
+            1,
+            shape=(
+                *self.batch_size,
+                1,
+            )
+            if self.singleton
+            else self.batch_size,
+        )
+        self.observation_spec = Composite(
+            observation=Unbounded(
+                shape=(
+                    *self.batch_size,
+                    3,
+                )
+            ),
+            shape=self.batch_size,
+        )
+        self.done_spec = Composite(
+            done=Unbounded(self.batch_size + (1,), dtype=torch.bool),
+            terminated=Unbounded(self.batch_size + (1,), dtype=torch.bool),
+            truncated=Unbounded(self.batch_size + (1,), dtype=torch.bool),
+            shape=self.batch_size,
+        )
+        self.reward_spec = Unbounded(
+            shape=(
+                *self.batch_size,
+                1,
+            )
+        )
+
+    def _reset(self, td: TensorDict):
+        return TensorDict(
+            observation=torch.randn(*self.batch_size, 3, device=self.device),
+            done=torch.zeros(*self.batch_size, 1, dtype=torch.bool, device=self.device),
+            truncated=torch.zeros(
+                *self.batch_size, 1, dtype=torch.bool, device=self.device
+            ),
+            terminated=torch.zeros(
+                *self.batch_size, 1, dtype=torch.bool, device=self.device
+            ),
+            device=self.device,
+        )
+
+    def _step(
+        self,
+        tensordict: TensorDictBase,
+    ) -> TensorDictBase:
+        return TensorDict(
+            observation=torch.randn(*self.batch_size, 3, device=self.device),
+            reward=torch.zeros(*self.batch_size, 1, device=self.device),
+            done=torch.zeros(*self.batch_size, 1, dtype=torch.bool, device=self.device),
+            truncated=torch.zeros(
+                *self.batch_size, 1, dtype=torch.bool, device=self.device
+            ),
+            terminated=torch.zeros(
+                *self.batch_size, 1, dtype=torch.bool, device=self.device
+            ),
+        )
+
+    def _set_seed(self, seed: Optional[int]):
+        ...
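Note for reviewers: a minimal usage sketch of the new mock env. This is illustrative only, not part of the patch; it assumes it is run from the repo's test directory so that mocking_classes is importable.

    import torch
    from torchrl.envs.utils import check_env_specs

    from mocking_classes import EnvWithScalarAction

    # singleton=False: the action spec has an empty (scalar) shape.
    env = EnvWithScalarAction(singleton=False)
    assert env.action_spec.shape == torch.Size([])
    check_env_specs(env)  # rollout + spec consistency check

    # singleton=True: a trailing singleton dim is kept, as in vector envs.
    env = EnvWithScalarAction(singleton=True)
    assert env.action_spec.shape == torch.Size([1])
    check_env_specs(env)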
diff --git a/test/test_transforms.py b/test/test_transforms.py
index 8b2ada8c93a..ae428d35d97 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -41,6 +41,7 @@
     CountingEnvCountPolicy,
     DiscreteActionConvMockEnv,
     DiscreteActionConvMockEnvNumpy,
+    EnvWithScalarAction,
     IncrementingEnv,
     MockBatchedLockedEnv,
     MockBatchedUnLockedEnv,
@@ -66,6 +67,7 @@
     CountingEnvCountPolicy,
     DiscreteActionConvMockEnv,
     DiscreteActionConvMockEnvNumpy,
+    EnvWithScalarAction,
     IncrementingEnv,
     MockBatchedLockedEnv,
     MockBatchedUnLockedEnv,
@@ -11781,17 +11783,33 @@ def test_transform_inverse(self):
 class TestActionDiscretizer(TransformBase):
     @pytest.mark.parametrize("categorical", [True, False])
-    def test_single_trans_env_check(self, categorical):
-        base_env = ContinuousActionVecMockEnv()
+    @pytest.mark.parametrize(
+        "env_cls",
+        [
+            ContinuousActionVecMockEnv,
+            partial(EnvWithScalarAction, singleton=True),
+            partial(EnvWithScalarAction, singleton=False),
+        ],
+    )
+    def test_single_trans_env_check(self, categorical, env_cls):
+        base_env = env_cls()
         env = base_env.append_transform(
             ActionDiscretizer(num_intervals=5, categorical=categorical)
         )
         check_env_specs(env)
 
     @pytest.mark.parametrize("categorical", [True, False])
-    def test_serial_trans_env_check(self, categorical):
+    @pytest.mark.parametrize(
+        "env_cls",
+        [
+            ContinuousActionVecMockEnv,
+            partial(EnvWithScalarAction, singleton=True),
+            partial(EnvWithScalarAction, singleton=False),
+        ],
+    )
+    def test_serial_trans_env_check(self, categorical, env_cls):
         def make_env():
-            base_env = ContinuousActionVecMockEnv()
+            base_env = env_cls()
             return base_env.append_transform(
                 ActionDiscretizer(num_intervals=5, categorical=categorical)
             )
@@ -11800,9 +11818,17 @@ def make_env():
         check_env_specs(env)
 
     @pytest.mark.parametrize("categorical", [True, False])
-    def test_parallel_trans_env_check(self, categorical):
+    @pytest.mark.parametrize(
+        "env_cls",
+        [
+            ContinuousActionVecMockEnv,
+            partial(EnvWithScalarAction, singleton=True),
+            partial(EnvWithScalarAction, singleton=False),
+        ],
+    )
+    def test_parallel_trans_env_check(self, categorical, env_cls):
         def make_env():
-            base_env = ContinuousActionVecMockEnv()
+            base_env = env_cls()
             env = base_env.append_transform(
                 ActionDiscretizer(num_intervals=5, categorical=categorical)
             )
@@ -11812,17 +11838,33 @@ def make_env():
         check_env_specs(env)
 
     @pytest.mark.parametrize("categorical", [True, False])
-    def test_trans_serial_env_check(self, categorical):
-        env = SerialEnv(2, ContinuousActionVecMockEnv).append_transform(
+    @pytest.mark.parametrize(
+        "env_cls",
+        [
+            ContinuousActionVecMockEnv,
+            partial(EnvWithScalarAction, singleton=True),
+            partial(EnvWithScalarAction, singleton=False),
+        ],
+    )
+    def test_trans_serial_env_check(self, categorical, env_cls):
+        env = SerialEnv(2, env_cls).append_transform(
             ActionDiscretizer(num_intervals=5, categorical=categorical)
         )
         check_env_specs(env)
 
     @pytest.mark.parametrize("categorical", [True, False])
-    def test_trans_parallel_env_check(self, categorical):
-        env = ParallelEnv(
-            2, ContinuousActionVecMockEnv, mp_start_method=mp_ctx
-        ).append_transform(ActionDiscretizer(num_intervals=5, categorical=categorical))
+    @pytest.mark.parametrize(
+        "env_cls",
+        [
+            ContinuousActionVecMockEnv,
+            partial(EnvWithScalarAction, singleton=True),
+            partial(EnvWithScalarAction, singleton=False),
+        ],
+    )
+    def test_trans_parallel_env_check(self, categorical, env_cls):
+        env = ParallelEnv(2, env_cls, mp_start_method=mp_ctx).append_transform(
+            ActionDiscretizer(num_intervals=5, categorical=categorical)
+        )
         check_env_specs(env)
 
     def test_transform_no_env(self):
@@ -11838,7 +11880,6 @@ def test_transform_compose(self):
         check_env_specs(env)
 
     @pytest.mark.skipif(not _has_gym, reason="gym required for this test")
-    @pytest.mark.parametrize("envname", ["cheetah", "pendulum"])
     @pytest.mark.parametrize("interval_as_tensor", [False, True])
     @pytest.mark.parametrize("categorical", [True, False])
     @pytest.mark.parametrize(
@@ -11851,15 +11892,37 @@ def test_transform_compose(self):
             ActionDiscretizer.SamplingStrategy.RANDOM,
         ],
     )
-    def test_transform_env(self, envname, interval_as_tensor, categorical, sampling):
+    @pytest.mark.parametrize(
+        "env_cls",
+        [
+            "cheetah",
+            "pendulum",
+            partial(EnvWithScalarAction, singleton=True),
+            partial(EnvWithScalarAction, singleton=False),
+        ],
+    )
+    def test_transform_env(self, env_cls, interval_as_tensor, categorical, sampling):
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        base_env = GymEnv(
-            HALFCHEETAH_VERSIONED() if envname == "cheetah" else PENDULUM_VERSIONED(),
-            device=device,
-        )
-        if interval_as_tensor:
-            num_intervals = torch.arange(5, 11 if envname == "cheetah" else 6)
+        if env_cls == "cheetah":
+            base_env = GymEnv(
+                HALFCHEETAH_VERSIONED(),
+                device=device,
+            )
+            num_intervals = torch.arange(5, 11)
+        elif env_cls == "pendulum":
+            base_env = GymEnv(
+                PENDULUM_VERSIONED(),
+                device=device,
+            )
+            num_intervals = torch.arange(5, 6)
         else:
+            base_env = env_cls(
+                device=device,
+            )
+            num_intervals = torch.arange(5, 6)
+
+        if not interval_as_tensor:
+            # override
             num_intervals = 5
         t = ActionDiscretizer(
             num_intervals=num_intervals,
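Note: the parametrized checks above all reduce to the same pattern. A rough standalone equivalent of one test_single_trans_env_check configuration, as an illustrative sketch (assumes the same imports the test suite uses and that mocking_classes is on the path):

    from functools import partial

    from torchrl.envs import ActionDiscretizer
    from torchrl.envs.utils import check_env_specs

    from mocking_classes import EnvWithScalarAction

    env_cls = partial(EnvWithScalarAction, singleton=True)
    base_env = env_cls()
    env = base_env.append_transform(
        ActionDiscretizer(num_intervals=5, categorical=True)
    )
    # Exercises the scalar-action path this patch adds to ActionDiscretizer.
    check_env_specs(env)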
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index 7bdd25591cd..7ab5a2deb72 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -8585,24 +8585,32 @@ def _indent(s):
 
     def transform_input_spec(self, input_spec):
         try:
-            action_spec = input_spec["full_action_spec", self.in_keys_inv[0]]
+            action_spec = self.parent.full_action_spec_unbatched[self.in_keys_inv[0]]
             if not isinstance(action_spec, Bounded):
                 raise TypeError(
-                    f"action spec type {type(action_spec)} is not supported."
+                    f"action spec type {type(action_spec)} is not supported. The action spec type must be Bounded."
                 )
             n_act = action_spec.shape
             if not n_act:
-                n_act = 1
+                n_act = ()
+                empty_shape = True
             else:
-                n_act = n_act[-1]
+                n_act = (n_act[-1],)
+                empty_shape = False
             self.n_act = n_act
 
             self.dtype = action_spec.dtype
-            interval = (action_spec.high - action_spec.low).unsqueeze(-1)
+            interval = action_spec.high - action_spec.low
 
             num_intervals = self.num_intervals
 
+            if not empty_shape:
+                interval = interval.unsqueeze(-1)
+            elif isinstance(num_intervals, torch.Tensor):
+                num_intervals = int(num_intervals.squeeze())
+                self.num_intervals = torch.as_tensor(num_intervals)
+
             def custom_arange(nint):
                 result = torch.arange(
                     start=0.0,
@@ -8625,11 +8633,13 @@ def custom_arange(nint):
 
             if isinstance(num_intervals, int):
                 arange = (
-                    custom_arange(num_intervals).expand(n_act, num_intervals) * interval
-                )
-                self.register_buffer(
-                    "intervals", action_spec.low.unsqueeze(-1) + arange
+                    custom_arange(num_intervals).expand((*n_act, num_intervals))
+                    * interval
                 )
+                low = action_spec.low
+                if not empty_shape:
+                    low = low.unsqueeze(-1)
+                self.register_buffer("intervals", low + arange)
             else:
                 arange = [
                     custom_arange(_num_intervals) * interval
@@ -8644,12 +8654,6 @@ def custom_arange(nint):
                     )
                 ]
 
-            cls = (
-                functools.partial(MultiCategorical, remove_singleton=False)
-                if self.categorical
-                else MultiOneHot
-            )
-
             if not isinstance(num_intervals, torch.Tensor):
                 nvec = torch.as_tensor(num_intervals, device=action_spec.device)
             else:
@@ -8657,7 +8661,10 @@ def custom_arange(nint):
             if nvec.ndim > 1:
                 raise RuntimeError(f"Cannot use num_intervals with shape {nvec.shape}")
             if nvec.ndim == 0 or nvec.numel() == 1:
-                nvec = nvec.expand(action_spec.shape[-1])
+                if not empty_shape:
+                    nvec = nvec.expand(action_spec.shape[-1])
+                else:
+                    nvec = nvec.squeeze()
             self.register_buffer("nvec", nvec)
             if self.sampling == self.SamplingStrategy.RANDOM:
                 # compute jitters
@@ -8667,7 +8674,22 @@ def custom_arange(nint):
                 if self.categorical
                 else (*action_spec.shape[:-1], nvec.sum())
             )
-            action_spec = cls(nvec=nvec, shape=shape, device=action_spec.device)
+
+            if not empty_shape:
+                cls = (
+                    functools.partial(MultiCategorical, remove_singleton=False)
+                    if self.categorical
+                    else MultiOneHot
+                )
+                action_spec = cls(nvec=nvec, shape=shape, device=action_spec.device)
+
+            else:
+                cls = Categorical if self.categorical else OneHot
+                action_spec = cls(n=int(nvec), shape=shape, device=action_spec.device)
+
+            batch_size = self.parent.batch_size
+            if batch_size:
+                action_spec = action_spec.expand(batch_size + action_spec.shape)
 
             input_spec["full_action_spec", self.out_keys_inv[0]] = action_spec
             if self.out_keys_inv[0] != self.in_keys_inv[0]:
@@ -8705,6 +8727,8 @@ def _inv_call(self, tensordict):
         if self.categorical:
             action = action.unsqueeze(-1)
             if isinstance(intervals, torch.Tensor):
+                shape = action.shape[: -intervals.ndim]
+                intervals = intervals.expand(shape + intervals.shape)
                 action = intervals.gather(index=action, dim=-1).squeeze(-1)
             else:
                 action = torch.stack(
@@ -8715,17 +8739,26 @@ def _inv_call(self, tensordict):
                     -1,
                 )
         else:
-            nvec = self.nvec.tolist()
-            action = action.split(nvec, dim=-1)
-            if isinstance(intervals, torch.Tensor):
-                intervals = intervals.unbind(-2)
-            action = torch.stack(
-                [
-                    intervals[action].view(action.shape[:-1])
-                    for (intervals, action) in zip(intervals, action)
-                ],
-                -1,
-            )
+            nvec = self.nvec
+            empty_shape = not nvec.ndim
+            if not empty_shape:
+                nvec = nvec.tolist()
+                if isinstance(intervals, torch.Tensor):
+                    shape = action.shape[: (-intervals.ndim + 1)]
+                    intervals = intervals.expand(shape + intervals.shape)
+                    intervals = intervals.unbind(-2)
+                action = action.split(nvec, dim=-1)
+                action = torch.stack(
+                    [
+                        intervals[action].view(action.shape[:-1])
+                        for (intervals, action) in zip(intervals, action)
+                    ],
+                    -1,
+                )
+            else:
+                shape = action.shape[: -intervals.ndim]
+                intervals = intervals.expand(shape + intervals.shape)
+                action = intervals[action].squeeze(-1)
         if self.sampling == self.SamplingStrategy.RANDOM:
             action = action + self.jitters * torch.rand_like(self.jitters)
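Note: a sketch of the net effect of the transforms.py changes (assumed behavior inferred from the diff; the Categorical/OneHot spec classes are the ones the scalar branch instantiates, and mocking_classes is assumed importable from the test directory):

    import torch

    from torchrl.data import Categorical, OneHot
    from torchrl.envs import ActionDiscretizer

    from mocking_classes import EnvWithScalarAction

    # A scalar Bounded action spec is now discretized into a scalar
    # Categorical (or flat OneHot) spec, instead of the MultiCategorical /
    # MultiOneHot specs used for vector action specs.
    env = EnvWithScalarAction(singleton=False).append_transform(
        ActionDiscretizer(num_intervals=5, categorical=True)
    )
    assert isinstance(env.action_spec, Categorical)
    assert env.action_spec.shape == torch.Size([])

    env = EnvWithScalarAction(singleton=False).append_transform(
        ActionDiscretizer(num_intervals=5, categorical=False)
    )
    assert isinstance(env.action_spec, OneHot)

    # _inv_call maps the discrete action back to a continuous value in
    # [-1, 1] before it reaches the base env, so a rollout runs end to end.
    env.rollout(3)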