Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update with many small changes/fixes #92

Merged
merged 16 commits into from
May 23, 2024
Merged

Update with many small changes/fixes #92

merged 16 commits into from
May 23, 2024

Conversation

HomerW
Copy link
Collaborator

@HomerW HomerW commented May 13, 2024

This PR adds many small changes/fixes we've made internally.

  • Variable naming changes
  • Dataloader now returns actions that are already chunked
  • Creating the proipio key is now handled by standardization functions
  • New UNet diffusion head
  • Action unnormalization happens in the model now
  • Updated visualizers
  • Fixed bug in attention mask, added use_correct_attention flag to maintain backwards compatibility
  • Added supply_rng function wrapper that can be used to wrap the model's sample_actions function to ensure that a fresh random key is used for each call
  • Fixed issue where image augmentation wasn't being applied because in configs image_augment_kwargs was not a dict

Copy link
Collaborator

@dibyaghosh dibyaghosh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mainly checking that the examples are still consistent

examples/03_eval_finetuned.py Outdated Show resolved Hide resolved
examples/03_eval_finetuned.py Show resolved Hide resolved
octo/data/oxe/oxe_dataset_mixes.py Outdated Show resolved Hide resolved
octo/data/utils/data_utils.py Outdated Show resolved Hide resolved
Copy link
Collaborator

@kpertsch kpertsch left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't verify whether all the proprio computations make sense for the OpenX datasets -- did you use the info from the former configs to figure this out?

octo/data/oxe/oxe_dataset_configs.py Outdated Show resolved Hide resolved
octo/data/oxe/oxe_dataset_mixes.py Outdated Show resolved Hide resolved
octo/data/oxe/oxe_dataset_mixes.py Outdated Show resolved Hide resolved
octo/data/traj_transforms.py Outdated Show resolved Hide resolved
octo/data/utils/data_utils.py Outdated Show resolved Hide resolved
@@ -202,7 +203,7 @@ def __call__(
assert self.proper_pad_mask, "Cannot skip unless using proper pad mask."
return None

if not isinstance(tasks["language_instruction"], jax.Array):
if not isinstance(tasks["language_instruction"], (jax.Array, np.ndarray)):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When is this ever an np array?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@SudeepDasari I think this got added in the new action head PR?

@mees
Copy link
Collaborator

mees commented May 23, 2024

LGTM!

@HomerW HomerW merged commit 5eaa5c6 into main May 23, 2024
1 check passed
@HomerW HomerW deleted the new_release branch May 23, 2024 23:52
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants