Skip to content

Commit

Permalink
Add ability to set references with numpy unicode strings. This allows…
Browse files Browse the repository at this point in the history
… us directly to use `distributions.UniformChoice(material_names)` to randomize over materials of a geom.

PiperOrigin-RevId: 690640836
Change-Id: I140fa86b22d5a0571fc8777f6c222dceddecaf81
  • Loading branch information
Leonard Hasenclever authored and copybara-github committed Oct 28, 2024
1 parent e3f93b0 commit 8eb70f8
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion dm_control/mjcf/attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,9 @@ def reference_namespace(self):
return self._reference_namespace

def _assign(self, value):
if not isinstance(value, (base.Element, str)):
if not isinstance(value, (base.Element, str)) and not (
isinstance(value, np.ndarray) and value.dtype.kind == 'U'
):
raise ValueError(
'Expect a string or `mjcf.Element` value: got {}'.format(value))
elif not value:
Expand All @@ -330,6 +332,8 @@ def _assign(self, value):
raise ValueError(_INVALID_REFERENCE_TYPE.format(
valid_type=self.reference_namespace,
actual_type=value_namespace))
elif isinstance(value, np.ndarray):
value = value.item()
self._value = value

def _before_clear(self):
Expand Down

0 comments on commit 8eb70f8

Please sign in to comment.