[Feature] spec.cardinality #2638

Merged (4 commits, Dec 12, 2024)

Changes from all commits

79 changes: 79 additions & 0 deletions test/test_specs.py
@@ -1689,6 +1689,85 @@ def test_unboundeddiscrete(
assert spec is not spec.clone()


class TestCardinality:
@pytest.mark.parametrize("shape1", [(5, 4)])
def test_binary(self, shape1):
spec = Binary(n=4, shape=shape1, device="cpu", dtype=torch.bool)
assert spec.cardinality() == len(list(spec.enumerate()))

@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
def test_discrete(
self,
shape1,
):
spec = Categorical(n=4, shape=shape1, device="cpu", dtype=torch.long)
assert spec.cardinality() == len(list(spec.enumerate()))

@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
def test_multidiscrete(
self,
shape1,
):
if shape1 is None:
shape1 = (3,)
else:
shape1 = (*shape1, 3)
spec = MultiCategorical(
nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long
)
assert spec.cardinality() == len(spec.enumerate())

@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
def test_multionehot(
self,
shape1,
):
if shape1 is None:
shape1 = (15,)
else:
shape1 = (*shape1, 15)
spec = MultiOneHot(nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long)
assert spec.cardinality() == len(list(spec.enumerate()))

def test_non_tensor(self):
spec = NonTensor(shape=(3, 4), device="cpu")
with pytest.raises(RuntimeError, match="Cannot enumerate a NonTensorSpec."):
spec.cardinality()

@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
def test_onehot(
self,
shape1,
):
if shape1 is None:
shape1 = (15,)
else:
shape1 = (*shape1, 15)
spec = OneHot(n=15, shape=shape1, device="cpu", dtype=torch.long)
assert spec.cardinality() == len(list(spec.enumerate()))

def test_composite(self):
batch_size = (5,)
spec2 = Binary(n=4, shape=(*batch_size, 4), device="cpu", dtype=torch.bool)
spec3 = Categorical(n=4, shape=batch_size, device="cpu", dtype=torch.long)
spec4 = MultiCategorical(
nvec=(4, 5, 6), shape=(*batch_size, 3), device="cpu", dtype=torch.long
)
spec5 = MultiOneHot(
nvec=(4, 5, 6), shape=(*batch_size, 15), device="cpu", dtype=torch.long
)
spec6 = OneHot(n=15, shape=(*batch_size, 15), device="cpu", dtype=torch.long)
spec = Composite(
spec2=spec2,
spec3=spec3,
spec4=spec4,
spec5=spec5,
spec6=spec6,
shape=batch_size,
)
assert spec.cardinality() == len(spec.enumerate())


class TestUnbind:
@pytest.mark.parametrize("shape1", [(5, 4)])
def test_binary(self, shape1):
124 changes: 109 additions & 15 deletions torchrl/data/tensor_specs.py
@@ -41,7 +41,7 @@
unravel_key,
)
from tensordict.base import NO_DEFAULT
- from tensordict.utils import _getitem_batch_size, NestedKey
+ from tensordict.utils import _getitem_batch_size, is_non_tensor, NestedKey
from torchrl._utils import _make_ordinal_device, get_binary_env_var, implement_for

DEVICE_TYPING = Union[torch.device, str, int]
@@ -582,6 +582,16 @@ def clear_device_(self) -> T:
"""
return self

@abc.abstractmethod
def cardinality(self) -> int:
"""The cardinality of the spec.

This refers to the number of possible outcomes in a spec. The cardinality of a composite
spec is the size of the Cartesian product of its sub-specs' outcome sets, i.e., the product of their cardinalities.

"""
...
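
For illustration, the intended contract for leaf specs (a minimal sketch, assuming the ``torchrl.data`` constructors exercised in the tests above):

    from torchrl.data import Categorical, OneHot

    # A leaf spec reports its number of distinct outcomes.
    assert Categorical(3).cardinality() == 3
    assert OneHot(4).cardinality() == 4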

def encode(
self,
val: np.ndarray | torch.Tensor | TensorDictBase,
@@ -1515,6 +1525,9 @@ def __init__(
def n(self):
return self.space.n

def cardinality(self) -> int:
return self.n

def update_mask(self, mask):
"""Sets a mask to prevent some of the possible outcomes when a sample is taken.

@@ -2107,6 +2120,9 @@ def enumerate(self) -> Any:
f"enumerate is not implemented for spec of class {type(self).__name__}."
)

def cardinality(self) -> int:
return float("inf")

def __eq__(self, other):
return (
type(other) == type(self)
@@ -2426,8 +2442,11 @@ def __init__(
shape=shape, space=None, device=device, dtype=dtype, domain=domain, **kwargs
)

def cardinality(self) -> Any:
raise RuntimeError("Cannot enumerate a NonTensorSpec.")

def enumerate(self) -> Any:
- raise NotImplementedError("Cannot enumerate a NonTensorSpec.")
+ raise RuntimeError("Cannot enumerate a NonTensorSpec.")

def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensor:
if isinstance(dest, torch.dtype):
@@ -2466,10 +2485,10 @@ def one(self, shape=None):
data=None, batch_size=(*shape, *self._safe_shape), device=self.device
)

- def is_in(self, val: torch.Tensor) -> bool:
+ def is_in(self, val: Any) -> bool:
shape = torch.broadcast_shapes(self._safe_shape, val.shape)
return (
- isinstance(val, NonTensorData)
+ is_non_tensor(val)
and val.shape == shape
# We relax constrains on device as they're hard to enforce for non-tensor
# tensordicts and pointless
@@ -2832,6 +2851,9 @@ def __init__(
)
self.update_mask(mask)

def cardinality(self) -> int:
return torch.as_tensor(self.nvec).prod()
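
The product-of-``nvec`` rule in practice (a minimal sketch, assuming ``torchrl.data.MultiOneHot`` keeps the constructor used in the tests above):

    from torchrl.data import MultiOneHot

    spec = MultiOneHot(nvec=(4, 5, 6))
    # Three independent categorical choices: 4 * 5 * 6 joint outcomes.
    assert spec.cardinality() == 120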

def enumerate(self) -> torch.Tensor:
nvec = self.nvec
enum_disc = self.to_categorical_spec().enumerate()
@@ -3220,13 +3242,20 @@ class Categorical(TensorSpec):
The spec will have the shape defined by the ``shape`` argument: if a singleton dimension is
desired for the training dimension, one should specify it explicitly.

Attributes:
n (int): The number of possible outcomes.
shape (torch.Size): The shape of the variable.
device (torch.device): The device of the tensors.
dtype (torch.dtype): The dtype of the tensors.

Args:
n (int): number of possible outcomes.
n (int): number of possible outcomes. If set to -1, the cardinality of the categorical spec is undefined,
and `set_provisional_n` must be called before sampling from this spec.
shape (torch.Size, optional): shape of the variable; defaults to ``torch.Size([])``.
- device (str, int or torch.device, optional): device of the tensors.
- dtype (str or torch.dtype, optional): dtype of the tensors.
- mask (torch.Tensor or None): mask some of the possible outcomes when a
-     sample is taken. See :meth:`~.update_mask` for more information.
+ device (str, int or torch.device, optional): the device of the tensors.
+ dtype (str or torch.dtype, optional): the dtype of the tensors.
+ mask (torch.Tensor or None): A boolean mask to prevent some of the possible outcomes when a sample is taken.
+     See :meth:`~.update_mask` for more information.

Examples:
>>> categ = Categorical(3)
@@ -3249,6 +3278,13 @@ class Categorical(TensorSpec):
domain=discrete)
>>> categ.rand()
tensor([1])
>>> categ = Categorical(-1)
>>> categ.set_provisional_n(5)
>>> categ.rand()
tensor(3)

.. note:: When n is set to -1, calling `rand` without first setting a provisional n using `set_provisional_n`
will raise a ``RuntimeError``.

"""

@@ -3276,16 +3312,31 @@ def __init__(
shape=shape, space=space, device=device, dtype=dtype, domain="discrete"
)
self.update_mask(mask)
self._provisional_n = None

def enumerate(self) -> torch.Tensor:
- arange = torch.arange(self.n, dtype=self.dtype, device=self.device)
+ dtype = self.dtype
+ if dtype is torch.bool:
+     dtype = torch.uint8
+ arange = torch.arange(self.n, dtype=dtype, device=self.device)
if self.ndim:
arange = arange.view(-1, *(1,) * self.ndim)
return arange.expand(self.n, *self.shape)

@property
def n(self):
- return self.space.n
+ n = self.space.n
+ if n == -1:
+     n = self._provisional_n
+     if n is None:
+         raise RuntimeError(
+             f"Undefined cardinality for {type(self)}. Please call "
+             f"spec.set_provisional_n(int)."
+         )
+ return n

def cardinality(self) -> int:
return self.n

def update_mask(self, mask):
"""Sets a mask to prevent some of the possible outcomes when a sample is taken.
@@ -3316,13 +3367,33 @@ def update_mask(self, mask):
raise ValueError("Only boolean masks are accepted.")
self.mask = mask

def set_provisional_n(self, n: int):
"""Set the cardinality of the Categorical spec temporarily.

This method must be called before sampling from the spec when ``n`` is set to ``-1``.

Args:
n (int): The cardinality of the Categorical spec.

"""
self._provisional_n = n

def rand(self, shape: torch.Size = None) -> torch.Tensor:
if self.space.n < 0:
if self._provisional_n is None:
raise RuntimeError(
"Cannot generate random categorical samples for undefined cardinality (n=-1). "
"To sample from this class, first call Categorical.set_provisional_n(n) before calling rand()."
)
n = self._provisional_n
else:
n = self.space.n
if shape is None:
shape = _size([])
if self.mask is None:
return torch.randint(
0,
self.space.n,
n,
_size([*shape, *self.shape]),
device=self.device,
dtype=self.dtype,
@@ -3334,6 +3405,12 @@ def rand(self, shape: torch.Size = None) -> torch.Tensor:
else:
mask_flat = mask
shape_out = mask.shape[:-1]
# Check that the mask has the right size
if mask_flat.shape[-1] != n:
raise ValueError(
"The last dimension of the mask must match the number of action allowed by the "
f"Categorical spec. Got mask.shape={self.mask.shape} and n={n}."
)
out = torch.multinomial(mask_flat.float(), 1).reshape(shape_out)
return out
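
How the mask interacts with sampling, for reference (a minimal sketch, assuming the ``update_mask`` API shown in this class):

    import torch
    from torchrl.data import Categorical

    spec = Categorical(4)
    spec.update_mask(torch.tensor([True, False, True, False]))
    sample = spec.rand()  # uniform multinomial over the unmasked indices
    assert sample.item() in (0, 2)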

@@ -3360,6 +3437,8 @@ def is_in(self, val: torch.Tensor) -> bool:
dtype_match = val.dtype == self.dtype
if not dtype_match:
return False
if self.space.n == -1:
return True
return (0 <= val).all() and (val < self.space.n).all()
shape = self.mask.shape
shape = _size([*torch.broadcast_shapes(shape[:-1], val.shape), shape[-1]])
@@ -3607,7 +3686,7 @@ def __init__(
device: Optional[DEVICE_TYPING] = None,
dtype: Union[str, torch.dtype] = torch.int8,
):
- if n is None and not shape:
+ if n is None and shape is None:
raise TypeError("Must provide either n or shape.")
if n is None:
n = shape[-1]
@@ -3813,6 +3892,9 @@ def enumerate(self) -> torch.Tensor:
arange = arange.expand(arange.shape[0], *self.shape)
return arange

def cardinality(self) -> int:
return self.nvec._base.prod()

def update_mask(self, mask):
"""Sets a mask to prevent some of the possible outcomes when a sample is taken.

@@ -4373,7 +4455,7 @@ def set(self, name, spec):
shape = spec.shape
if shape[: self.ndim] != self.shape:
if (
- isinstance(spec, Composite)
+ isinstance(spec, (Composite, NonTensor))
and spec.ndim < self.ndim
and self.shape[: spec.ndim] == spec.shape
):
@@ -4382,7 +4464,7 @@ def set(self, name, spec):
spec.shape = self.shape
else:
raise ValueError(
"The shape of the spec and the Composite mismatch: the first "
f"The shape of the spec {type(spec).__name__} and the Composite {type(self).__name__} mismatch: the first "
f"{self.ndim} dimensions should match but got spec.shape={spec.shape} and "
f"Composite.shape={self.shape}."
)
@@ -4798,6 +4880,18 @@ def clone(self) -> Composite:
shape=self.shape,
)

def cardinality(self) -> int:
n = None
for spec in self.values():
if spec is None:
continue
if n is None:
n = 1
n = n * spec.cardinality()
if n is None:
n = 0
return n
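
As defined above, an empty ``Composite`` has cardinality 0, and a populated one multiplies the cardinalities of its leaves (a minimal sketch, assuming the same constructors as in ``test_composite``):

    from torchrl.data import Categorical, Composite, OneHot

    assert Composite().cardinality() == 0
    spec = Composite(act=Categorical(3), obs=OneHot(4))
    # 3 categorical outcomes * 4 one-hot outcomes = 12 joint outcomes.
    assert spec.cardinality() == 12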

def enumerate(self) -> TensorDictBase:
# We are going to use meshgrid to create samples of all the subspecs in here
# but first let's get rid of the batch size, we'll put it back later
19 changes: 19 additions & 0 deletions torchrl/envs/common.py
@@ -561,6 +561,25 @@ def check_env_specs(self, *args, **kwargs):

check_env_specs.__doc__ = check_env_specs_func.__doc__

def cardinality(self, tensordict: TensorDictBase | None = None) -> int:
"""The cardinality of the action space.

By default, this is just a wrapper around :meth:`env.action_space.cardinality <torchrl.data.TensorSpec.cardinality>`.

This method is useful when the action spec is variable:

- The number of actions can be undefined, e.g., ``Categorical(n=-1)``;
- The action cardinality may depend on the action mask;
- The shape can be dynamic, as in ``Unbounded(shape=(-1,))``.

In these cases, :meth:`~.cardinality` should be overridden.

Args:
tensordict (TensorDictBase, optional): a tensordict containing the data required to compute the cardinality.

"""
return self.full_action_spec.cardinality()
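
For instance, an environment with an action mask might override the hook like this (a minimal sketch; the ``action_mask`` key and the surrounding env definition are hypothetical):

    from __future__ import annotations

    from tensordict import TensorDictBase
    from torchrl.envs import EnvBase

    class MaskedActionEnv(EnvBase):
        # _reset, _step and _set_seed are assumed to be implemented elsewhere.
        def cardinality(self, tensordict: TensorDictBase | None = None) -> int:
            if tensordict is not None and "action_mask" in tensordict.keys():
                # Only the currently allowed actions count.
                return int(tensordict["action_mask"].sum())
            return self.full_action_spec.cardinality()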

@classmethod
def __new__(cls, *args, _inplace_update=False, _batch_locked=True, **kwargs):
# inplace update will write tensors in-place on the provided tensordict.