Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update with many small changes/fixes #92

Merged
merged 16 commits into from
May 23, 2024
35 changes: 22 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Octo
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1z0vELj_lX9OWeoMG_WvXnQs43aPOEAhz?usp=sharing)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/octo-models/octo/blob/main/examples/01_inference_pretrained.ipynb)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
[![Static Badge](https://img.shields.io/badge/Project-Page-a)](https://octo-models.github.io/)
![](https://github.com/rail-berkeley/octo/workflows/run-debug/badge.svg)
Expand All @@ -15,7 +15,7 @@ for an inference example.

```python
from octo.model.octo_model import OctoModel
model = OctoModel.load_pretrained("hf://rail-berkeley/octo-base")
model = OctoModel.load_pretrained("hf://rail-berkeley/octo-base-1.5")
print(model.get_pretty_spec())
```

Expand Down Expand Up @@ -48,7 +48,7 @@ See the [Jax Github page](https://github.com/google/jax) for more details on ins

Test the installation by finetuning on the debug dataset:
```bash
python scripts/finetune.py --config.pretrained_path=hf://rail-berkeley/octo-small --debug
python scripts/finetune.py --config.pretrained_path=hf://rail-berkeley/octo-small-1.5 --debug
```

## Checkpoints
Expand Down Expand Up @@ -99,7 +99,7 @@ We provide a [minimal example](examples/02_finetune_new_observation_action.py) f
We also provide a more advanced finetuning script that allows you to change hyperparameters via a config file and logs finetuning
metrics. To run advanced finetuning, use:
```bash
python scripts/finetune.py --config.pretrained_path=hf://rail-berkeley/octo-small
python scripts/finetune.py --config.pretrained_path=hf://rail-berkeley/octo-small-1.5
```

We offer three finetuning modes depending on the parts of the model that are kept frozen: ```head_only```, ```head_mlp_only```, and ```full``` to finetune the full model.
Expand All @@ -114,9 +114,9 @@ Loading and running a trained Octo model is as easy as:
```python
from octo.model import OctoModel

model = OctoModel.load_pretrained("hf://rail-berkeley/octo-small")
model = OctoModel.load_pretrained("hf://rail-berkeley/octo-small-1.5")
task = model.create_tasks(texts=["pick up the spoon"])
action = model.sample_action(observation, task, rng=jax.random.PRNGKey(0))
action = model.sample_actions(observation, task, rng=jax.random.PRNGKey(0))
```

We provide examples for evaluating Octo [in a simulated Gym environment](examples/03_eval_finetuned.py) as well
Expand All @@ -140,20 +140,29 @@ To evaluate on your own environment, simply wrap it in a Gym interface and follo
| Visualization | [visualization_lib.py](octo/utils/visualization_lib.py) | Utilities for offline qualitative & quantitative eval. |

## FAQ
#### What is the `pad_mask` in the observation dictionary?
The `pad_mask` indicates which observations should be attended to, which is important when using multiple timesteps of observation history. Octo was trained with a history window size of 2, meaning the model can predict an action using both the current observation and the previous observation. However, at the very beginning of the trajectory, there is no previous observation, so we need to set `pad_mask=False` at the corresponding index. If you use Octo with a window size of 1, pad_mask should always just be `[True]`, indicating that the one and only observation in the window should be attended to. Note that if you wrap your robot environment with the `HistoryWrapper` (see [gym_wrappers.py](octo/utils/gym_wrappers.py)), the `pad_mask` key will be added to the observation dictionary for you.
#### What is the `timestep_pad_mask` in the observation dictionary?
The `timestep_pad_mask` indicates which observations should be attended to, which is important when using multiple timesteps of observation history. Octo was trained with a history window size of 2, meaning the model can predict an action using both the current observation and the previous observation. However, at the very beginning of the trajectory, there is no previous observation, so we need to set `timestep_pad_mask=False` at the corresponding index. If you use Octo with a window size of 1, `timestep_pad_mask` should always just be `[True]`, indicating that the one and only observation in the window should be attended to. Note that if you wrap your robot environment with the `HistoryWrapper` (see [gym_wrappers.py](octo/utils/gym_wrappers.py)), the `timestep_pad_mask` key will be added to the observation dictionary for you.
#### What is `pad_mask_dict` in the observation dictionary?
While `pad_mask` indicates which observations should be attended to on a timestep level, `pad_mask_dict` indicates which elements of the observation should be attended to within a single timestep. For example, for datasets without language labels, `pad_mask_dict["language_instruction"]` is set to `False`. For datasets without a wrist camera, `pad_mask_dict["image_wrist"]` is set to `False`. For convenience, if a key is missing from the observation dict, it is equivalent to setting `pad_mask_dict` to `False` for that key.
While `timestep_pad_mask` indicates which observations should be attended to on a timestep level, `pad_mask_dict` indicates which elements of the observation should be attended to within a single timestep. For example, for datasets without language labels, `pad_mask_dict["language_instruction"]` is set to `False`. For datasets without a wrist camera, `pad_mask_dict["image_wrist"]` is set to `False`. For convenience, if a key is missing from the observation dict, it is equivalent to setting `pad_mask_dict` to `False` for that key.
#### Does `model.sample_actions([...])` return the full trajectory to solve a task?
Octo was pretrained with an action chunking size of 4, meaning it predicts the next 4 actions at once. You can choose to execute all these actions before sampling new ones, or only execute the first action before sampling new ones (also known as receding horizon control). You can also do something more advanced like [temporal ensembling](octo/utils/gym_wrappers.py).

## Updates for Version 1.5
- Improved cross-attention between visual and language tokens by repeating language tokens at every timestep in the context window.
- Augmented the language instructions in the data with rephrasings from GPT-3.5.
- Bug fixes:
- Turned off dropout in the diffusion head due to incompatibility with layer norm.
- Fixed an off-by-one error with the attention mask.
- Fixed an issue where different image augmentations did not get fresh random seeds.

## Citation

```
@misc{octo_2023,
@inproceedings{octo_2023,
title={Octo: An Open-Source Generalist Robot Policy},
author = {{Octo Model Team} and Dibya Ghosh and Homer Walke and Karl Pertsch and Kevin Black and Oier Mees and Sudeep Dasari and Joey Hejna and Charles Xu and Jianlan Luo and Tobias Kreiman and {You Liang} Tan and Dorsa Sadigh and Chelsea Finn and Sergey Levine},
howpublished = {\url{https://octo-models.github.io}},
year = {2023},
author = {{Octo Model Team} and Dibya Ghosh and Homer Walke and Karl Pertsch and Kevin Black and Oier Mees and Sudeep Dasari and Joey Hejna and Charles Xu and Jianlan Luo and Tobias Kreiman and {You Liang} Tan and Pannag Sanketi and Quan Vuong and Ted Xiao and Dorsa Sadigh and Chelsea Finn and Sergey Levine},
booktitle = {Proceedings of Robotics: Science and Systems},
address = {Delft, Netherlands},
year = {2024},
}
```
Binary file modified docs/assets/teaser.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
190 changes: 55 additions & 135 deletions examples/01_inference_pretrained.ipynb

Large diffs are not rendered by default.

18 changes: 8 additions & 10 deletions examples/02_finetune_new_observation_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

To run this example, first download and extract the dataset from here: https://rail.eecs.berkeley.edu/datasets/example_sim_data.zip

python examples/02_finetune_new_observation_action.py --pretrained_path=hf://rail-berkeley/octo-small --data_dir=...
python examples/02_finetune_new_observation_action.py --pretrained_path=hf://rail-berkeley/octo-small-1.5 --data_dir=...
"""
from absl import app, flags, logging
import flax
Expand All @@ -15,7 +15,6 @@
import wandb

from octo.data.dataset import make_single_dataset
from octo.data.utils.data_utils import NormalizationType
from octo.model.components.action_heads import L1ActionHead
from octo.model.components.tokenizers import LowdimObsTokenizer
from octo.model.octo_model import OctoModel
Expand Down Expand Up @@ -70,14 +69,12 @@ def main(_):
name="aloha_sim_cube_scripted_dataset",
data_dir=FLAGS.data_dir,
image_obs_keys={"primary": "top"},
state_obs_keys=["state"],
proprio_obs_key="state",
language_key="language_instruction",
action_proprio_normalization_type=NormalizationType.NORMAL,
absolute_action_mask=[True] * 14,
),
traj_transform_kwargs=dict(
window_size=1,
future_action_window_size=49, # so we get 50 actions for our action chunk
action_horizon=50,
),
frame_transform_kwargs=dict(
resize_size={"primary": (256, 256)},
Expand Down Expand Up @@ -116,10 +113,10 @@ def process_batch(batch):
high=2.0,
obs_keys=["proprio"],
)
# Fully override the old action head with a new one (for smaller changes, you can use update_module_config)
# Fully override the old action head with a new one (for smaller changes, you can use update_config)
config["model"]["heads"]["action"] = ModuleSpec.create(
L1ActionHead,
pred_horizon=50,
action_horizon=50,
action_dim=14,
readout_key="readout_action",
)
Expand Down Expand Up @@ -162,13 +159,14 @@ def loss_fn(params, batch, rng, train=True):
transformer_embeddings = bound_module.octo_transformer(
batch["observation"],
batch["task"],
batch["observation"]["pad_mask"],
batch["observation"]["timestep_pad_mask"],
train=train,
)
action_loss, action_metrics = bound_module.heads["action"].loss(
transformer_embeddings, # Action head knows to pull out the action readout_key
batch["action"],
pad_mask=batch["observation"]["pad_mask"],
batch["observation"]["timestep_pad_mask"],
batch["action_pad_mask"],
train=train,
)
return action_loss, action_metrics
Expand Down
31 changes: 18 additions & 13 deletions examples/03_eval_finetuned.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
This script demonstrates how to load and rollout a finetuned Octo model.
We use the Octo model finetuned on ALOHA sim data from the examples/finetune_new_observation_action.py script.
We use the Octo model finetuned on ALOHA sim data from the examples/02_finetune_new_observation_action.py script.

For installing the ALOHA sim environment, clone: https://github.com/tonyzhaozh/act
Then run:
Expand All @@ -15,6 +15,7 @@
cd examples
python3 03_eval_finetuned.py --finetuned_path=<path_to_finetuned_aloha_checkpoint>
"""
from functools import partial
import sys

from absl import app, flags, logging
Expand All @@ -25,10 +26,12 @@

sys.path.append("path/to/your/act")

from envs.aloha_sim_env import AlohaGymEnv # keep this to register ALOHA sim env
# keep this to register ALOHA sim env
from envs.aloha_sim_env import AlohaGymEnv # noqa

from octo.model.octo_model import OctoModel
from octo.utils.gym_wrappers import HistoryWrapper, RHCWrapper, UnnormalizeActionProprio
from octo.utils.gym_wrappers import HistoryWrapper, NormalizeProprio, RHCWrapper
from octo.utils.train_callbacks import supply_rng

FLAGS = flags.FLAGS

Expand All @@ -49,27 +52,31 @@ def main(_):
##################################################################################################################
# environment needs to implement standard gym interface + return observations of the following form:
# obs = {
# "image_0": ...
# "image_1": ...
# "image_primary": ...
# }
# it should also implement an env.get_task() function that returns a task dict with goal and/or language instruct.
# task = {
# "language_instruction": "some string"
# "goal": {
# "image_0": ...
# "image_1": ...
# "image_primary": ...
# }
# }
##################################################################################################################
env = gym.make("aloha-sim-cube-v0")

# wrap env to normalize proprio
env = NormalizeProprio(env, model.dataset_statistics)

# add wrappers for history and "receding horizon control", i.e. action chunking
env = HistoryWrapper(env, horizon=1)
env = RHCWrapper(env, exec_horizon=50)

# wrap env to handle action/proprio normalization -- match normalization type to the one used during finetuning
env = UnnormalizeActionProprio(
env, model.dataset_statistics, normalization_type="normal"
# the supply_rng wrapper supplies a new random key to sample_actions every time it's called
policy_fn = supply_rng(
HomerW marked this conversation as resolved.
Show resolved Hide resolved
partial(
model.sample_actions,
unnormalization_statistics=model.dataset_statistics["action"],
),
)

# running rollouts
Expand All @@ -85,9 +92,7 @@ def main(_):
episode_return = 0.0
while len(images) < 400:
# model returns actions of shape [batch, pred_horizon, action_dim] -- remove batch
actions = model.sample_actions(
jax.tree_map(lambda x: x[None], obs), task, rng=jax.random.PRNGKey(0)
)
actions = policy_fn(jax.tree_map(lambda x: x[None], obs), task)
actions = actions[0]

# step env -- info contains full "chunk" of observations for logging
Expand Down
44 changes: 17 additions & 27 deletions examples/04_eval_finetuned_on_robot.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,8 @@
from widowx_envs.widowx_env_service import WidowXClient, WidowXConfigs, WidowXStatus

from octo.model.octo_model import OctoModel
from octo.utils.gym_wrappers import (
HistoryWrapper,
TemporalEnsembleWrapper,
UnnormalizeActionProprio,
)
from octo.utils.gym_wrappers import HistoryWrapper, TemporalEnsembleWrapper
from octo.utils.train_callbacks import supply_rng

np.set_printoptions(suppress=True)

Expand All @@ -50,9 +47,10 @@
flags.DEFINE_integer("im_size", None, "Image size", required=True)
flags.DEFINE_string("video_save_path", None, "Path to save video")
flags.DEFINE_integer("num_timesteps", 120, "num timesteps")
flags.DEFINE_integer("horizon", 1, "Observation history length")
flags.DEFINE_integer("pred_horizon", 1, "Length of action sequence from model")
flags.DEFINE_integer("exec_horizon", 1, "Length of action sequence to execute")
flags.DEFINE_integer("window_size", 2, "Observation history length")
flags.DEFINE_integer(
"action_horizon", 4, "Length of action sequence to execute/ensemble"
)


# show image flag
Expand All @@ -64,10 +62,9 @@
Bridge data was collected with non-blocking control and a step duration of 0.2s.
However, we relabel the actions to make it look like the data was collected with
blocking control and we evaluate with blocking control.
We also use a step duration of 0.4s to reduce the jerkiness of the policy.
Be sure to change the step duration back to 0.2 if evaluating with non-blocking control.
Be sure to use a step duration of 0.2 if evaluating with non-blocking control.
"""
STEP_DURATION = 0.4
STEP_DURATION = 0.2
STICKY_GRIPPER_NUM_STEPS = 1
WORKSPACE_BOUNDS = [[0.1, -0.15, -0.01, -1.57, 0], [0.45, 0.25, 0.25, 1.57, 0]]
CAMERA_TOPICS = [{"name": "/blue/image_raw"}]
Expand Down Expand Up @@ -107,16 +104,12 @@ def main(_):
)

# wrap the robot environment
env = UnnormalizeActionProprio(
env, model.dataset_statistics["bridge_dataset"], normalization_type="normal"
)
env = HistoryWrapper(env, FLAGS.horizon)
env = TemporalEnsembleWrapper(env, FLAGS.pred_horizon)
env = HistoryWrapper(env, FLAGS.window_size)
env = TemporalEnsembleWrapper(env, FLAGS.action_horizon)
# switch TemporalEnsembleWrapper with RHCWrapper for receding horizon control
# env = RHCWrapper(env, FLAGS.exec_horizon)
# env = RHCWrapper(env, FLAGS.action_horizon)

# create policy function
@jax.jit
# create policy functions
def sample_actions(
pretrained_model: OctoModel,
observations,
Expand All @@ -129,22 +122,19 @@ def sample_actions(
observations,
tasks,
rng=rng,
unnormalization_statistics=pretrained_model.dataset_statistics[
"bridge_dataset"
]["action"],
)
# remove batch dim
return actions[0]

def supply_rng(f, rng=jax.random.PRNGKey(0)):
def wrapped(*args, **kwargs):
nonlocal rng
rng, key = jax.random.split(rng)
return f(*args, rng=key, **kwargs)

return wrapped

policy_fn = supply_rng(
partial(
sample_actions,
model,
argmax=FLAGS.deterministic,
temperature=FLAGS.temperature,
)
)

Expand Down
271 changes: 18 additions & 253 deletions examples/05_dataloading.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion examples/06_pytorch_oxe_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __len__(self):
traj_transform_kwargs=dict(
goal_relabeling_strategy="uniform",
window_size=2,
future_action_window_size=3,
action_horizon=4,
subsample_length=100,
),
frame_transform_kwargs=dict(
Expand Down
2 changes: 0 additions & 2 deletions examples/envs/widowx_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,12 @@ def convert_obs(obs, im_size):
# NOTE: assume image_1 is not available
return {
"image_primary": image_obs,
"proprio": proprio,
}


def null_obs(img_size):
return {
"image_primary": np.zeros((img_size, img_size, 3), dtype=np.uint8),
"proprio": np.zeros((8,), dtype=np.float64),
}


Expand Down
Loading
Loading