Update with many small changes/fixes #92
Conversation
Mainly checking that the examples are still consistent
I didn't verify whether all the proprio computations make sense for the OpenX datasets -- did you use the info from the former configs to figure this out?
```diff
@@ -202,7 +203,7 @@ def __call__(
             assert self.proper_pad_mask, "Cannot skip unless using proper pad mask."
             return None

-        if not isinstance(tasks["language_instruction"], jax.Array):
+        if not isinstance(tasks["language_instruction"], (jax.Array, np.ndarray)):
```
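For context on why the check was broadened: a NumPy array is not an instance of `jax.Array`, so instructions that arrive as host-side NumPy arrays (e.g. straight out of a data loader, before being moved to device) would fail the original check. A minimal sketch illustrating the difference (variable names here are illustrative, not from the PR):

```python
import jax
import jax.numpy as jnp
import numpy as np

# Host-side tokens (e.g. from a tf.data / NumPy pipeline) vs. device arrays.
tokens_np = np.zeros((1, 16), dtype=np.int32)
tokens_jax = jnp.zeros((1, 16), dtype=jnp.int32)

# NumPy arrays do not satisfy the original jax.Array-only check...
assert not isinstance(tokens_np, jax.Array)
assert isinstance(tokens_jax, jax.Array)

# ...but both pass the broadened check from the diff above.
assert isinstance(tokens_np, (jax.Array, np.ndarray))
assert isinstance(tokens_jax, (jax.Array, np.ndarray))
```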
When is this ever an np array?
@SudeepDasari I think this got added in the new action head PR?
LGTM!
This PR adds many small changes/fixes we've made internally.

- `use_correct_attention` flag to maintain backwards compatibility
- `supply_rng` function wrapper that can be used to wrap the model's `sample_actions` function to ensure that a fresh random key is used for each call
- Fix for when `image_augment_kwargs` was not a dict
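To make the `supply_rng` item concrete, here is a minimal sketch of what such a wrapper might look like; the real implementation lives in the repository and may differ in signature and details:

```python
import jax


def supply_rng(f, rng=None):
    """Hypothetical sketch: return a wrapper around f that splits an
    internal PRNG key on every call and passes a fresh subkey via the
    `rng` keyword, so the caller never reuses a key by accident."""
    if rng is None:
        rng = jax.random.PRNGKey(0)
    state = {"rng": rng}

    def wrapped(*args, **kwargs):
        # Split the stored key: keep one half, hand the other to f.
        state["rng"], subkey = jax.random.split(state["rng"])
        return f(*args, rng=subkey, **kwargs)

    return wrapped


# Usage: each call sees a different key, so samples differ.
def sample(rng):
    return jax.random.uniform(rng)


sampler = supply_rng(sample)
a, b = sampler(), sampler()
assert float(a) != float(b)
```

Without such a wrapper, naively calling `sample_actions` with the same key twice would return identical samples, which is an easy mistake to make at inference time.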