Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add Hash transform #2648

Open
wants to merge 1 commit into
base: gh/kurtamohler/1/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 153 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@
FrameSkipTransform,
GrayScale,
gSDENoise,
Hash,
InitTracker,
MultiStepTransform,
NoopResetEnv,
Expand Down Expand Up @@ -2177,6 +2178,158 @@ def test_transform_no_env(self, device, batch):
pytest.skip("TrajCounter cannot be called without env")


# TODO: Add tests that hash NonTensorStacks of strings
class TestHash(TransformBase):
@pytest.mark.parametrize("datatype", ["tensor", "str"])
def test_transform_no_env(self, datatype):
if datatype == "tensor":
obs = torch.tensor(10)
elif datatype == "str":
obs = "abcdefg"
else:
raise RuntimeError(f"please add a test case for datatype {datatype}")

td = TensorDict(
{
"observation": obs,
}
)
t = Hash(in_keys=["observation"], out_keys=["hash"])
td_hashed = t(td)

assert td_hashed["observation"] is td["observation"]
assert td_hashed["hash"] == hash(td["observation"])

def test_single_trans_env_check(self):
t = Hash(in_keys=["observation"], out_keys=["hash"])
env = TransformedEnv(CountingEnv(), t)
check_env_specs(env)

def test_serial_trans_env_check(self):
def make_env():
t = Hash(
in_keys=["observation"],
out_keys=["hash"],
)
return TransformedEnv(CountingEnv(), t)

env = SerialEnv(2, make_env)
check_env_specs(env)

def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv):
def make_env():
t = Hash(in_keys=["observation"], out_keys=["hash"])
return TransformedEnv(CountingEnv(), t)

env = maybe_fork_ParallelEnv(2, make_env)
try:
check_env_specs(env)
finally:
try:
env.close()
except RuntimeError:
pass

def test_trans_serial_env_check(self):
t = Hash(
in_keys=["observation"],
out_keys=["hash"],
)

env = TransformedEnv(SerialEnv(2, CountingEnv), t)
check_env_specs(env)

def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv):
t = Hash(
in_keys=["observation"],
out_keys=["hash"],
)

env = TransformedEnv(maybe_fork_ParallelEnv(2, CountingEnv), t)
try:
check_env_specs(env)
finally:
try:
env.close()
except RuntimeError:
pass

@pytest.mark.parametrize("datatype", ["tensor", "str"])
def test_transform_compose(self, datatype):
if datatype == "tensor":
obs = torch.tensor(10)
elif datatype == "str":
obs = "abcdefg"
else:
raise RuntimeError(f"please add a test case for datatype {datatype}")

td = TensorDict(
{
"observation": obs,
}
)
t = Hash(in_keys=["observation"], out_keys=["hash"])
t = Compose(t)
td_hashed = t(td)

assert td_hashed["observation"] is td["observation"]
assert td_hashed["hash"] == hash(td["observation"])

def test_transform_model(self):
t = Hash(
in_keys=[("next", "observation"), ("observation",)],
out_keys=[("next", "hash"), ("hash",)],
)
model = nn.Sequential(t, nn.Identity())
td = TensorDict(
{("next", "observation"): torch.randn(3), "observation": torch.randn(3)}, []
)
td_out = model(td)
assert ("next", "hash") in td_out.keys(True)
assert ("hash",) in td_out.keys(True)
assert td_out["next", "hash"] == hash(td["next", "observation"])
assert td_out["hash"] == hash(td["observation"])

@pytest.mark.skipif(not _has_gym, reason="Gym not found")
def test_transform_env(self):
t = Hash(
in_keys=["observation"],
out_keys=["hash"],
)
env = TransformedEnv(GymEnv(PENDULUM_VERSIONED()), t)
assert env.observation_spec["hash"]
assert "observation" in env.observation_spec
assert "observation" in env.base_env.observation_spec
check_env_specs(env)

@pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
def test_transform_rb(self, rbclass):
t = Hash(
in_keys=[("next", "observation"), ("observation",)],
out_keys=[("next", "hash"), ("hash",)],
)
rb = rbclass(storage=LazyTensorStorage(10))
rb.append_transform(t)
td = TensorDict(
{
"observation": torch.randn(3, 4),
"next": TensorDict(
{"observation": torch.randn(3, 4)},
[],
),
},
[],
).expand(10)
rb.extend(td)
td = rb.sample(2)
assert "observation_out" in td.keys()
assert "observation" not in td.keys()
assert ("next", "observation") not in td.keys(True)

def test_transform_inverse(self):
raise pytest.skip("No inverse for Hash")


class TestStack(TransformBase):
def test_single_trans_env_check(self):
t = Stack(
Expand Down
1 change: 1 addition & 0 deletions torchrl/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
FrameSkipTransform,
GrayScale,
gSDENoise,
Hash,
InitTracker,
KLRewardTransform,
MultiStepTransform,
Expand Down
1 change: 1 addition & 0 deletions torchrl/envs/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
FrameSkipTransform,
GrayScale,
gSDENoise,
Hash,
InitTracker,
NoopResetEnv,
ObservationNorm,
Expand Down
46 changes: 46 additions & 0 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4400,6 +4400,52 @@ def __repr__(self) -> str:
)


class Hash(Transform):
"""Adds a hash value to a tensordict.

Args:
in_keys (sequence of NestedKey): the key of the data to create the hash from.
out_key (sequence of NestedKey): the key of the resulting hash.
"""

def __init__(
self,
in_keys: Sequence[NestedKey],
out_keys: Sequence[NestedKey],
):
super().__init__(in_keys=in_keys, out_keys=out_keys)

# TODO: If this transform is run on a tensordict like
# `TensorDict({"obs": # tensor.rand(2)}, batch_size=[2])`, then
# `_apply_transform` will create only one hash value for the tensor of size
# 2. Then, when `forward` tries to add the hash to the tensordict, an error
# is raised since the hash doesn't have a leading dimension of size 2.
# TODO: Add support for NonTensorStack inputs.
def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor:
if isinstance(observation, NonTensorData):
obs = observation.get("data")
else:
obs = observation
return hash(obs)

def _reset(
self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
) -> TensorDictBase:
with _set_missing_tolerance(self, True):
tensordict_reset = self._call(tensordict_reset)
return tensordict_reset

def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
if not isinstance(observation_spec, Composite):
raise TypeError(f"{self}: Only specs of type Composite can be transformed")
for out_key in self.out_keys:
observation_spec.set(
out_key,
Unbounded(shape=(), dtype=torch.int64),
)
return observation_spec


class Stack(Transform):
"""Stacks tensors and tensordicts.

Expand Down
Loading