[Refactor] Refactor trees #2634

Merged 2 commits on Dec 12, 2024
104 changes: 101 additions & 3 deletions test/test_storage_map.py
@@ -301,6 +301,7 @@ def _state0(self) -> TensorDict:
def _make_td(state: torch.Tensor, action: torch.Tensor) -> TensorDict:
done = torch.zeros_like(action, dtype=torch.bool).unsqueeze(-1)
reward = action.clone()
action = action + torch.arange(action.shape[-1]) / action.shape[-1]

return TensorDict(
{
@@ -326,7 +327,7 @@ def _make_forest(self) -> MCTSForest:
forest.extend(r4)
return forest

- def _make_forest_intersect(self) -> MCTSForest:
+ def _make_forest_rebranching(self) -> MCTSForest:
"""
├── 0
│ ├── 16
@@ -449,7 +450,7 @@ def test_forest_check_ids(self):

def test_forest_intersect(self):
state0 = self._state0()
- forest = self._make_forest_intersect()
+ forest = self._make_forest_rebranching()
tree = forest.get_tree(state0)
subtree = forest.get_tree(TensorDict(observation=19))

@@ -467,13 +468,110 @@ def test_forest_intersect(self):

def test_forest_intersect_vertices(self):
state0 = self._state0()
- forest = self._make_forest_intersect()
+ forest = self._make_forest_rebranching()
tree = forest.get_tree(state0)
assert len(tree.vertices(key_type="path")) > len(tree.vertices(key_type="hash"))
assert len(tree.vertices(key_type="id")) == len(tree.vertices(key_type="hash"))
with pytest.raises(ValueError, match="key_type must be"):
tree.vertices(key_type="another key type")

@pytest.mark.skipif(not _has_gym, reason="requires gym")
def test_simple_tree(self):
from torchrl.envs import GymEnv

env = GymEnv("Pendulum-v1")
r = env.rollout(10)
state0 = r[0]
forest = MCTSForest()
forest.extend(r)
# forest = self._make_forest_intersect()
tree = forest.get_tree(state0, compact=False)
assert tree.max_length() == 9
for p in tree.valid_paths():
assert len(p) == 9

@pytest.mark.parametrize(
"tree_type,compact",
[
["simple", False],
["forest", False],
# parents of rebranching trees are still buggy
# ["rebranching", False],
# ["rebranching", True],
],
)
def test_forest_parent(self, tree_type, compact):
if tree_type == "simple":
if not _has_gym:
pytest.skip("requires gym")
from torchrl.envs import GymEnv

env = GymEnv("Pendulum-v1")
r = env.rollout(10)
state0 = r[0]
forest = MCTSForest()
forest.extend(r)
tree = forest.get_tree(state0, compact=compact)
elif tree_type == "forest":
state0 = self._state0()
forest = self._make_forest()
tree = forest.get_tree(state0, compact=compact)
else:
state0 = self._state0()
forest = self._make_forest_rebranching()
tree = forest.get_tree(state0, compact=compact)
# Check access
tree.subtree.parent
tree.subtree.subtree.parent
tree.subtree.subtree.subtree.parent

# check presence of weakref
assert tree.subtree[0]._parent is not None
assert tree.subtree[0].subtree[0]._parent is not None

# Check content
assert_close(tree.subtree.parent, tree)
for p in tree.valid_paths():
root = tree
for it in p:
node = root.subtree[it]
assert_close(node.parent, root)
root = node

def test_forest_action_attr(self):
state0 = self._state0()
forest = self._make_forest()
tree = forest.get_tree(state0)
assert tree.branching_action is None
assert (tree.subtree.branching_action != tree.subtree.prev_action).any()
assert (
tree.subtree[0].subtree.branching_action
!= tree.subtree[0].subtree.prev_action
).any()
assert tree.prev_action is None

@pytest.mark.parametrize("intersect", [False, True])
def test_forest_check_obs_match(self, intersect):
state0 = self._state0()
if intersect:
forest = self._make_forest_rebranching()
else:
forest = self._make_forest()
tree = forest.get_tree(state0)
for path in tree.valid_paths():
prev_tree = tree
for p in path:
subtree = prev_tree.subtree[p]
assert (
subtree.node_data["observation"]
== subtree.rollout[..., -1]["next", "observation"]
).all()
assert (
subtree.node_observation
== subtree.rollout[..., -1]["next", "observation"]
).all()
prev_tree = subtree


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
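For readers skimming the diff, the pattern these new tests exercise can be summarized with a short sketch. This is a hedged reconstruction, not code from the PR: it assumes gym/gymnasium is installed and that `MCTSForest` is importable from `torchrl.data`.

```python
# Minimal sketch of what the new parent-tracking tests check; mirrors
# test_simple_tree and test_forest_parent above.
from torchrl.data import MCTSForest
from torchrl.envs import GymEnv

env = GymEnv("Pendulum-v1")
rollout = env.rollout(10)

forest = MCTSForest()
forest.extend(rollout)
tree = forest.get_tree(rollout[0], compact=False)

assert tree.max_length() == 9  # a 10-step rollout yields 9 transitions
for path in tree.valid_paths():
    node = tree
    for branch in path:
        child = node.subtree[branch]
        # every subtree node now keeps a (weak) reference to its parent
        assert child._parent is not None
        node = child
```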
3 changes: 2 additions & 1 deletion torchrl/data/map/hash.py
@@ -75,7 +75,8 @@ def forward(self, features: torch.Tensor) -> torch.Tensor:
class SipHash(Module):
"""A Module to Compute SipHash values for given tensors.

- A hash function module based on SipHash implementation in python.
+ A hash function module based on the SipHash implementation in Python. Input tensors should have shape ``[batch_size, num_features]``
+ and the output shape will be ``[batch_size]``.

Args:
as_tensor (bool, optional): if ``True``, the bytes will be turned into integers
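To make the documented shape contract concrete, here is a small usage sketch. It is an illustration based on the docstring above, assuming `SipHash` is importable from `torchrl.data.map.hash` as the file path suggests.

```python
# Hedged usage sketch of the [batch_size, num_features] -> [batch_size] contract.
import torch
from torchrl.data.map.hash import SipHash

hasher = SipHash()
features = torch.randn(4, 8)  # [batch_size=4, num_features=8]
hashes = hasher(features)     # one hash value per batch element
assert hashes.shape == (4,)   # [batch_size]
```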
32 changes: 28 additions & 4 deletions torchrl/data/map/tdstorage.py
@@ -138,6 +138,10 @@ def __init__(
self.collate_fn = collate_fn
self.write_fn = write_fn

@property
def max_size(self):
return self.storage.max_size

@property
def out_keys(self) -> List[NestedKey]:
out_keys = self.__dict__.get("_out_keys_and_lazy")
@@ -177,7 +181,7 @@ def from_tensordict_pair(
collate_fn: Callable[[Any], Any] | None = None,
write_fn: Callable[[Any, Any], Any] | None = None,
consolidated: bool | None = None,
- ):
+ ) -> TensorDictMap:
"""Creates a new TensorDictStorage from a pair of tensordicts (source and dest) using pre-defined rules of thumb.

Args:
@@ -238,7 +242,13 @@ def from_tensordict_pair(
n_feat = 0
hash_module = []
for in_key in in_keys:
- n_feat = source[in_key].shape[-1]
+ entry = source[in_key]
+ if entry.ndim == source.ndim:
+     # this is a good example of why td/tc are useful - carrying metadata
+     # allows us to know if there's a feature dim or not
+     n_feat = 0
+ else:
+     n_feat = entry.shape[-1]
if n_feat > RandomProjectionHash._N_COMPONENTS_DEFAULT:
_hash_module = RandomProjectionHash()
else:
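As an aside, the branch added above can be illustrated in isolation. The following is a hypothetical standalone sketch (not part of the PR) of how comparing an entry's ndim against the TensorDict's batch ndim distinguishes entries with and without a feature dimension.

```python
# Standalone sketch of the feature-dim detection: an entry whose ndim equals
# the TensorDict's batch ndim carries no feature dimension.
import torch
from tensordict import TensorDict

source = TensorDict(
    {
        "observation": torch.randn(4, 3),                # batch + feature dim
        "terminated": torch.zeros(4, dtype=torch.bool),  # batch dim only
    },
    batch_size=[4],
)

for in_key in ("observation", "terminated"):
    entry = source[in_key]
    n_feat = 0 if entry.ndim == source.ndim else entry.shape[-1]
    print(in_key, "->", n_feat)  # observation -> 3, terminated -> 0
```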
@@ -308,7 +318,23 @@ def __setitem__(self, item: TensorDictBase, value: TensorDictBase):
if not self._has_lazy_out_keys():
# TODO: make this work with pytrees and avoid calling select if keys match
value = value.select(*self.out_keys, strict=False)
item, value = self._maybe_add_batch(item, value)
index = self._to_index(item, extend=True)
if index.unique().numel() < index.numel():
# If multiple values point to the same place in the storage, we cannot process them by batch
# There could be a better way to deal with this, using unique ids.
vals = []
for it, val in zip(item.split(1), value.split(1)):
self[it] = val
vals.append(val)
# __setitem__ may affect the content of the input data
value.update(TensorDictBase.lazy_stack(vals))
return
if self.write_fn is not None:
# We use this block in the following context: the value written in the storage is already present,
# but it needs to be updated.
# We first check if the value is already there using `contains`. If so, we pass the new value and the
# previous one to write_fn. The values that are not present are passed alone.
if len(self):
modifiable = self.contains(item)
if modifiable.any():
@@ -322,8 +348,6 @@ def __setitem__(self, item: TensorDictBase, value: TensorDictBase):
value = self.write_fn(value)
else:
value = self.write_fn(value)
- item, value = self._maybe_add_batch(item, value)
- index = self._to_index(item, extend=True)
self.storage.set(index, value)

def __len__(self):
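To see why the collision check added to `__setitem__` matters, consider a toy version of the problem (a hypothetical illustration using plain tensors rather than a `TensorDictMap`): when several items map to the same slot, a single batched write keeps only one of the colliding values, so the new code falls back to writing one element at a time.

```python
# Toy illustration of the duplicate-index fallback in __setitem__.
import torch

index = torch.tensor([0, 1, 1])  # items 1 and 2 collide on slot 1
values = torch.tensor([10.0, 20.0, 30.0])
storage = torch.zeros(4)

if index.unique().numel() < index.numel():
    # Collisions detected: write one element at a time so the last
    # write for a slot wins deterministically.
    for i, v in zip(index.tolist(), values.tolist()):
        storage[i] = v
else:
    # No collisions: a single batched write is safe.
    storage[index] = values

print(storage)  # tensor([10., 30.,  0.,  0.])
```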