diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 7ab5a2deb72..980273af96c 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -8583,6 +8583,26 @@ def _indent(s): f"\n{_indent(out_action_key)},\n{_indent(sampling)},\n{_indent(categorical)})" ) + def _custom_arange(self, nint, device): + result = torch.arange( + start=0.0, + end=1.0, + step=1 / nint, + dtype=self.dtype, + device=device, + ) + result_ = result + if self.sampling in ( + self.SamplingStrategy.HIGH, + self.SamplingStrategy.MEDIAN, + ): + result_ = (1 - result).flip(0) + if self.sampling == self.SamplingStrategy.MEDIAN: + result = (result + result_) / 2 + else: + result = result_ + return result + def transform_input_spec(self, input_spec): try: action_spec = self.parent.full_action_spec_unbatched[self.in_keys_inv[0]] @@ -8611,29 +8631,11 @@ def transform_input_spec(self, input_spec): num_intervals = int(num_intervals.squeeze()) self.num_intervals = torch.as_tensor(num_intervals) - def custom_arange(nint): - result = torch.arange( - start=0.0, - end=1.0, - step=1 / nint, - dtype=self.dtype, - device=action_spec.device, - ) - result_ = result - if self.sampling in ( - self.SamplingStrategy.HIGH, - self.SamplingStrategy.MEDIAN, - ): - result_ = (1 - result).flip(0) - if self.sampling == self.SamplingStrategy.MEDIAN: - result = (result + result_) / 2 - else: - result = result_ - return result - if isinstance(num_intervals, int): arange = ( - custom_arange(num_intervals).expand((*n_act, num_intervals)) + self._custom_arange(num_intervals, action_spec.device).expand( + (*n_act, num_intervals) + ) * interval ) low = action_spec.low @@ -8642,7 +8644,7 @@ def custom_arange(nint): self.register_buffer("intervals", low + arange) else: arange = [ - custom_arange(_num_intervals) * interval + self._custom_arange(_num_intervals, action_spec.device) * interval for _num_intervals, interval in zip( num_intervals.tolist(), interval.unbind(-2) )