diff --git a/octo/model/octo_model.py b/octo/model/octo_model.py
index 6841b396..1cd261c2 100644
--- a/octo/model/octo_model.py
+++ b/octo/model/octo_model.py
@@ -250,6 +250,162 @@ def sample_actions(
         else:
             raise ValueError(f"Unknown normalization type: {normalization_type}")
         return action
+
+    @partial(jax.jit, static_argnames=("train", "sample_shape", "argmax", "beam"))
+    def sample_future_actions(
+        self,
+        observations: Data,
+        tasks: Data,
+        unnormalization_statistics: Optional[Data] = None,
+        normalization_type: NormalizationType = NormalizationType.NORMAL,
+        beam: int = 1,
+        timestep_pad_mask: Optional[ArrayLike] = None,
+        train: bool = False,
+        argmax: bool = False,
+        sample_shape: Tuple[int, ...] = (),
+        rng: Optional[PRNGKey] = None,
+        temperature: float = 1.0,
+    ):
+        """Samples the top-`beam` action tokens from the model and returns them together
+        with their confidences. See `action_heads.py` for more info.
+
+        Args:
+            observations: dictionary of arrays of shape (batch_size, window_size, *)
+            tasks: dict of tasks of shape (batch_size, *)
+            unnormalization_statistics: dict of statistics for unnormalizing actions (must
+                contain "mean", "std", and optionally "mask")
+            normalization_type: type of normalization applied to the actions
+            beam: number of highest-probability action tokens to keep per position
+            timestep_pad_mask: (batch_size, window_size) boolean mask that is False when
+                the timestep corresponds to padding
+            train: whether to run in train mode
+            ...see `action_heads.py` for the rest of the kwargs.
+        Returns:
+            actions: (*sample_shape, batch_size, action_horizon, action_dim)
+            confidence: softmax probability of each selected action token
+        """
+        if timestep_pad_mask is None:
+            timestep_pad_mask = observations["timestep_pad_mask"]
+
+        transformer_outputs = self.run_transformer(
+            observations, tasks, timestep_pad_mask, train=train
+        )
+        action_head = self.module.bind({"params": self.params}).heads["action"]
+
+        # Logits for the final timestep of the context window
+        action_logits = action_head(transformer_outputs, train=train)[:, -1]
+
+        action_distribution = jax.nn.softmax(action_logits, axis=-1)
+
+        # Indices of the `beam` most likely tokens (ascending in probability),
+        # plus their associated confidences
+        action_tokens = jnp.argsort(action_distribution, axis=-1)[..., -beam:].astype(
+            jnp.int32
+        )
+        confidence = jnp.take_along_axis(action_distribution, action_tokens, axis=-1)
+
+        action_tokens = jnp.broadcast_to(
+            action_tokens, sample_shape + action_tokens.shape
+        )
+
+        action = action_head.action_tokenizer.decode(action_tokens)
+
+        if unnormalization_statistics is not None:
+            if normalization_type == NormalizationType.NORMAL:
+                mask = unnormalization_statistics.get(
+                    "mask",
+                    jnp.ones_like(unnormalization_statistics["mean"], dtype=bool),
+                )
+                action = action[..., : len(mask)]
+                action = jnp.where(
+                    mask,
+                    (action * unnormalization_statistics["std"])
+                    + unnormalization_statistics["mean"],
+                    action,
+                )
+            elif normalization_type == NormalizationType.BOUNDS:
+                mask = unnormalization_statistics.get(
+                    "mask",
+                    jnp.ones_like(unnormalization_statistics["p01"], dtype=bool),
+                )
+                action = action[..., : len(mask)]
+                action = jnp.where(
+                    mask,
+                    (action + 1)
+                    * (
+                        unnormalization_statistics["p99"]
+                        - unnormalization_statistics["p01"]
+                    )
+                    / 2
+                    + unnormalization_statistics["p01"],
+                    action,
+                )
+            else:
+                raise ValueError(f"Unknown normalization type: {normalization_type}")
+
+        return action, confidence
+
+    @partial(jax.jit, static_argnames=("train", "sample_shape", "beam"))
+    def sample_trajectory(
+        self,
+        observations: Data,
+        next_action,
+        tasks: Data,
+        unnormalization_statistics: Optional[Data] = None,
+        normalization_type: NormalizationType = NormalizationType.NORMAL,
+        beam: int = 1,
+        timestep_pad_mask: Optional[ArrayLike] = None,
+        train: bool = False,
+        argmax: bool = False,
+        sample_shape: Tuple[int, ...] = (),
+        rng: Optional[PRNGKey] = None,
+        temperature: float = 1.0,
+    ):
+        """Samples a full trajectory from the `trajectory` head; mirrors `sample_actions`.
+        `next_action` is currently unused. See `action_heads.py` for kwarg details.
+        """
+        if timestep_pad_mask is None:
+            timestep_pad_mask = observations["timestep_pad_mask"]
+
+        transformer_outputs = self.run_transformer(
+            observations, tasks, timestep_pad_mask, train=train
+        )
+
+        trajectory_head = self.module.bind({"params": self.params}).heads["trajectory"]
+
+        action = trajectory_head.predict_action(
+            transformer_outputs,
+            train=train,
+            argmax=argmax,
+            sample_shape=sample_shape,
+            rng=rng,
+            temperature=temperature,
+        )
+
+        if unnormalization_statistics is not None:
+            if normalization_type == NormalizationType.NORMAL:
+                mask = unnormalization_statistics.get(
+                    "mask",
+                    jnp.ones_like(unnormalization_statistics["mean"], dtype=bool),
+                )
+                action = action[..., : len(mask)]
+                action = jnp.where(
+                    mask,
+                    (action * unnormalization_statistics["std"])
+                    + unnormalization_statistics["mean"],
+                    action,
+                )
+            elif normalization_type == NormalizationType.BOUNDS:
+                mask = unnormalization_statistics.get(
+                    "mask",
+                    jnp.ones_like(unnormalization_statistics["p01"], dtype=bool),
+                )
+                action = action[..., : len(mask)]
+                action = jnp.where(
+                    mask,
+                    (action + 1)
+                    * (
+                        unnormalization_statistics["p99"]
+                        - unnormalization_statistics["p01"]
+                    )
+                    / 2
+                    + unnormalization_statistics["p01"],
+                    action,
+                )
+            else:
+                raise ValueError(f"Unknown normalization type: {normalization_type}")
+
+        return action
 
     @classmethod
     def load_pretrained(
diff --git a/octo/model/octo_module.py b/octo/model/octo_module.py
index 921d2196..7abf9b47 100644
--- a/octo/model/octo_module.py
+++ b/octo/model/octo_module.py
@@ -258,7 +258,7 @@ def __call__(
             obs_tokens += self._create_positional_embedding(group_name, obs_tokens)
 
             # Update mask to account for which timesteps are padding
-            obs_pad_mask = jnp.logical_and(pad_mask[:, :, None], tokenizer_output.mask)
+            obs_pad_mask = jnp.logical_and(timestep_pad_mask[:, :, None], tokenizer_output.mask)
 
             all_timestep_groups.append(
                 TimestepGroup(
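
A minimal usage sketch for reviewers, not part of the diff. It assumes a checkpoint trained with both an "action" and a "trajectory" head; the checkpoint name, the "bridge_dataset" statistics key, and the dummy observation shapes are illustrative placeholders borrowed from the existing `sample_actions` examples:

    import jax
    import numpy as np
    from octo.model.octo_model import OctoModel

    # Illustrative checkpoint; any OctoModel checkpoint whose heads include
    # "action" and "trajectory" (an assumption of this sketch) would do.
    model = OctoModel.load_pretrained("hf://rail-berkeley/octo-small")

    # Dummy single-frame observation, prepared the same way as for sample_actions
    observations = {
        "image_primary": np.zeros((1, 1, 256, 256, 3), dtype=np.uint8),
        "timestep_pad_mask": np.array([[True]]),
    }
    tasks = model.create_tasks(texts=["pick up the spoon"])

    # Top-3 candidate actions for the last timestep, with softmax confidences
    actions, confidence = model.sample_future_actions(
        observations,
        tasks,
        unnormalization_statistics=model.dataset_statistics["bridge_dataset"]["action"],
        beam=3,
        rng=jax.random.PRNGKey(0),
    )

    # Full trajectory from the new trajectory head; next_action is currently unused
    trajectory = model.sample_trajectory(
        observations,
        None,
        tasks,
        unnormalization_statistics=model.dataset_statistics["bridge_dataset"]["action"],
        rng=jax.random.PRNGKey(0),
    )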