Skip to content

Commit

Permalink
Move custom arange
Browse files Browse the repository at this point in the history
  • Loading branch information
oslumbersh authored and vmoens committed Dec 3, 2024
1 parent 607ebc5 commit d767eb3
Showing 1 changed file with 24 additions and 22 deletions.
46 changes: 24 additions & 22 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8583,6 +8583,26 @@ def _indent(s):
f"\n{_indent(out_action_key)},\n{_indent(sampling)},\n{_indent(categorical)})"
)

def _custom_arange(self, nint, device):
    """Return ``nint`` fractional positions in ``[0, 1]`` for the bins.

    The unit interval is split into ``nint`` equal bins and one
    representative point per bin is returned, chosen by ``self.sampling``:

    * ``SamplingStrategy.HIGH``: upper bin edges, ``(i + 1) / nint``.
    * ``SamplingStrategy.MEDIAN``: bin mid-points, ``(i + 0.5) / nint``.
    * any other strategy: lower bin edges, ``i / nint``.

    Args:
        nint: number of bins (a positive integer).
        device: device on which the tensor is created.

    Returns:
        A 1-D tensor with exactly ``nint`` elements in increasing order.
    """
    # Build the grid from integer indices rather than a fractional step:
    # ``torch.arange(0.0, 1.0, 1 / nint)`` sizes its output with
    # ``ceil(1.0 / (1 / nint))``, which rounding can push to ``nint + 1``
    # (e.g. nint == 49), yielding one element too many.
    lower = torch.arange(nint, dtype=self.dtype, device=device) / nint

    strategy = self.SamplingStrategy
    if self.sampling not in (strategy.HIGH, strategy.MEDIAN):
        return lower

    # Mirroring the lower edges around 0.5 and flipping gives the upper
    # edges in increasing order.
    upper = (1 - lower).flip(0)
    if self.sampling == strategy.MEDIAN:
        return (lower + upper) / 2
    return upper

def transform_input_spec(self, input_spec):
try:
action_spec = self.parent.full_action_spec_unbatched[self.in_keys_inv[0]]
Expand Down Expand Up @@ -8611,29 +8631,11 @@ def transform_input_spec(self, input_spec):
num_intervals = int(num_intervals.squeeze())
self.num_intervals = torch.as_tensor(num_intervals)

def custom_arange(nint):
result = torch.arange(
start=0.0,
end=1.0,
step=1 / nint,
dtype=self.dtype,
device=action_spec.device,
)
result_ = result
if self.sampling in (
self.SamplingStrategy.HIGH,
self.SamplingStrategy.MEDIAN,
):
result_ = (1 - result).flip(0)
if self.sampling == self.SamplingStrategy.MEDIAN:
result = (result + result_) / 2
else:
result = result_
return result

if isinstance(num_intervals, int):
arange = (
custom_arange(num_intervals).expand((*n_act, num_intervals))
self._custom_arange(num_intervals, action_spec.device).expand(
(*n_act, num_intervals)
)
* interval
)
low = action_spec.low
Expand All @@ -8642,7 +8644,7 @@ def custom_arange(nint):
self.register_buffer("intervals", low + arange)
else:
arange = [
custom_arange(_num_intervals) * interval
self._custom_arange(_num_intervals, action_spec.device) * interval
for _num_intervals, interval in zip(
num_intervals.tolist(), interval.unbind(-2)
)
Expand Down

0 comments on commit d767eb3

Please sign in to comment.