Skip to content

Commit

Permalink
Merge branch 'variant/kraken' of https://github.com/andrearosasco/octo
Browse files Browse the repository at this point in the history
…into variant/kraken
  • Loading branch information
andrearosasco committed May 31, 2024
2 parents 2e7b277 + dc6801a commit 07cc35f
Show file tree
Hide file tree
Showing 2 changed files with 157 additions and 1 deletion.
156 changes: 156 additions & 0 deletions octo/model/octo_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,162 @@ def sample_actions(
else:
raise ValueError(f"Unknown normalization type: {normalization_type}")
return action

@partial(jax.jit, static_argnames=("train", "sample_shape", "argmax", "beam"))
def sample_future_actions(
    self,
    observations: Data,
    tasks: Data,
    unnormalization_statistics: Optional[Data] = None,
    normalization_type: NormalizationType = NormalizationType.NORMAL,
    beam: int = 1,
    pad_mask: Optional[ArrayLike] = None,
    train: bool = False,
    argmax: bool = False,
    sample_shape: Tuple[int, ...] = (),
    rng: Optional[PRNGKey] = None,
    temperature: float = 1.0,
):
    """Samples the `beam` most likely next actions from the "action" head.

    Runs the transformer on the last timestep's outputs, takes the softmax
    over the action-token logits, and keeps the `beam` highest-probability
    tokens together with their probabilities ("confidence").

    Args:
        observations: dictionary of arrays of shape (batch_size, window_size, *);
            must contain key "timestep_pad_mask" when `pad_mask` is not given.
        tasks: dict of tasks of shape (batch_size, *)
        unnormalization_statistics: dict of statistics for unnormalizing actions
            (must contain "mean"/"std" for NORMAL or "p01"/"p99" for BOUNDS,
            and optionally "mask")
        normalization_type: type of normalization applied to the actions
        beam: number of top-probability action tokens to keep per position
        pad_mask: (batch_size, window_size) boolean mask that is False where the
            timestep is padding; defaults to observations["timestep_pad_mask"]
        train: whether to run in train mode
        argmax, sample_shape, rng, temperature: see `action_heads.py`.
            NOTE(review): `argmax`, `rng`, and `temperature` are accepted for
            interface parity but are not used by this beam-style sampler.

    Returns:
        (action, confidence):
            action: decoded actions of shape (*sample_shape, batch, ..., beam)
            confidence: softmax probability of each kept token
    """
    # BUG FIX: the original body referenced `timestep_pad_mask`, a name that
    # was never defined (the parameter is `pad_mask`), so every call raised
    # NameError. The parameter name is kept for backward compatibility.
    if pad_mask is None:
        pad_mask = observations["timestep_pad_mask"]

    transformer_outputs = self.run_transformer(
        observations, tasks, pad_mask, train=train
    )
    # Bind the trained params so the head can be called outside `apply`.
    action_head = self.module.bind({"params": self.params}).heads["action"]

    # Only the last window position's logits are used for the next action.
    action_logits = action_head(transformer_outputs, train=train)[:, -1]

    action_distribution = jax.nn.softmax(action_logits, axis=-1)

    # argsort is ascending, so the last `beam` entries are the most likely
    # tokens; confidence is their corresponding probability.
    action_tokens = jnp.argsort(action_distribution, axis=-1)[..., -beam:].astype(
        jnp.int32
    )
    confidence = jnp.take_along_axis(action_distribution, action_tokens, axis=-1)

    action_tokens = jnp.broadcast_to(action_tokens, sample_shape + action_tokens.shape)

    # Map discrete tokens back to continuous (normalized) action values.
    action = action_head.action_tokenizer.decode(action_tokens)

    if unnormalization_statistics is not None:
        if normalization_type == NormalizationType.NORMAL:
            mask = unnormalization_statistics.get(
                "mask",
                jnp.ones_like(unnormalization_statistics["mean"], dtype=bool),
            )
            # Statistics may cover fewer dims than the head predicts; trim.
            action = action[..., : len(mask)]
            action = jnp.where(
                mask,
                (action * unnormalization_statistics["std"])
                + unnormalization_statistics["mean"],
                action,
            )
        elif normalization_type == NormalizationType.BOUNDS:
            mask = unnormalization_statistics.get(
                "mask", jnp.ones_like(unnormalization_statistics["p01"], dtype=bool)
            )
            action = action[..., : len(mask)]
            # Inverse of scaling [p01, p99] to [-1, 1].
            action = jnp.where(
                mask,
                (action + 1)
                * (
                    unnormalization_statistics["p99"]
                    - unnormalization_statistics["p01"]
                )
                / 2
                + unnormalization_statistics["p01"],
                action,
            )
        else:
            raise ValueError(f"Unknown normalization type: {normalization_type}")

    return action, confidence

@partial(jax.jit, static_argnames=("train", "sample_shape", "beam"))
def sample_trajectory(
    self,
    observations: Data,
    next_action,
    tasks: Data,
    unnormalization_statistics: Optional[Data] = None,
    normalization_type: NormalizationType = NormalizationType.NORMAL,
    beam: int = 1,
    pad_mask: Optional[ArrayLike] = None,
    train: bool = False,
    argmax: bool = False,
    sample_shape: Tuple[int, ...] = (),
    rng: Optional[PRNGKey] = None,
    temperature: float = 1.0,
):
    """Samples actions from the "trajectory" head.

    Args:
        observations: dictionary of arrays of shape (batch_size, window_size, *);
            must contain key "pad_mask" when `pad_mask` is not given.
        next_action: forwarded context for the trajectory head.
            NOTE(review): currently unused in this body — confirm intent.
        tasks: dict of tasks of shape (batch_size, *)
        unnormalization_statistics: dict of statistics for unnormalizing actions
            ("mean"/"std" for NORMAL, "p01"/"p99" for BOUNDS, optional "mask")
        normalization_type: type of normalization applied to the actions
        pad_mask: (batch_size, window_size) boolean mask, False on padding
        train, argmax, sample_shape, rng, temperature: sampling options
            forwarded to `predict_action`; see `action_heads.py`.

    Returns:
        actions, unnormalized when statistics are provided.
    """
    if pad_mask is None:
        pad_mask = observations["pad_mask"]

    embeddings = self.run_transformer(observations, tasks, pad_mask, train=train)

    # Bind trained params and sample from the trajectory head.
    trajectory_head = self.module.bind({"params": self.params}).heads["trajectory"]
    action = trajectory_head.predict_action(
        embeddings,
        train=train,
        argmax=argmax,
        sample_shape=sample_shape,
        rng=rng,
        temperature=temperature,
    )

    # No statistics: return the model-space (normalized) actions as-is.
    if unnormalization_statistics is None:
        return action

    stats = unnormalization_statistics
    if normalization_type == NormalizationType.NORMAL:
        mask = stats.get("mask", jnp.ones_like(stats["mean"], dtype=bool))
        # Statistics may describe fewer action dims than predicted; trim.
        action = action[..., : len(mask)]
        return jnp.where(mask, action * stats["std"] + stats["mean"], action)

    if normalization_type == NormalizationType.BOUNDS:
        mask = stats.get("mask", jnp.ones_like(stats["p01"], dtype=bool))
        action = action[..., : len(mask)]
        # Inverse of mapping [p01, p99] onto [-1, 1].
        return jnp.where(
            mask,
            (action + 1) * (stats["p99"] - stats["p01"]) / 2 + stats["p01"],
            action,
        )

    raise ValueError(f"Unknown normalization type: {normalization_type}")

@classmethod
def load_pretrained(
Expand Down
2 changes: 1 addition & 1 deletion octo/model/octo_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def __call__(
obs_tokens += self._create_positional_embedding(group_name, obs_tokens)

# Update mask to account for which timesteps are padding
obs_pad_mask = jnp.logical_and(pad_mask[:, :, None], tokenizer_output.mask)
obs_pad_mask = jnp.logical_and(timestep_pad_mask[:, :, None], tokenizer_output.mask)

all_timestep_groups.append(
TimestepGroup(
Expand Down

0 comments on commit 07cc35f

Please sign in to comment.