adapted sample_action/trajectory to work with new octo version
andrearosasco committed Jul 5, 2024
1 parent 07cc35f commit 014f5f1
Showing 2 changed files with 6 additions and 6 deletions.
octo/data/dataset.py (4 changes: 2 additions & 2 deletions)

@@ -393,9 +393,9 @@ def is_nonzero_length(traj):
         full_dataset = full_dataset.filter(ModuleSpec.instantiate(filter_fcn_spec))
     if ignore_errors:
         full_dataset = full_dataset.ignore_errors()
-    # for traj in full_dataset:
-    #     restructure(traj)
+
+    full_dataset = full_dataset.traj_map(restructure).filter(is_nonzero_length)
 
     # tries to load from cache, otherwise computes on the fly
     dataset_statistics = get_dataset_statistics(
         full_dataset,
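For context, a minimal sketch of the trajectory-level pipeline this change restores. It assumes `full_dataset` is the dlimp DLataset built earlier in `make_dataset_from_rlds`, and that `restructure` and `is_nonzero_length` are the helpers defined just above this hunk; it is not part of the commit itself.

# Sketch only: `traj_map` applies `restructure` lazily to whole trajectories
# (rather than single frames), and the filter drops trajectories that end up
# empty, replacing the removed eager `for traj in full_dataset` debug loop.
full_dataset = (
    full_dataset
    .traj_map(restructure)
    .filter(is_nonzero_length)
)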
octo/model/octo_model.py (8 changes: 4 additions & 4 deletions)

@@ -259,7 +259,7 @@ def sample_future_actions(
         unnormalization_statistics: Optional[Data] = None,
         normalization_type: NormalizationType = NormalizationType.NORMAL,
         beam:int = 1,
-        pad_mask: Optional[ArrayLike] = None,
+        timestep_pad_mask: Optional[ArrayLike] = None,
         train: bool = False,
         argmax: bool = False,
         sample_shape: Tuple[int, ...] = (),
@@ -281,7 +281,7 @@
             actions: (*sample_shape, batch_size, action_horizon, action_dim)
         """
         if timestep_pad_mask is None:
-            timestep_pad_mask = observations["timestep_pad_mask"]
+            timestep_pad_mask = observations["pad_mask"]
 
         transformer_outputs = self.run_transformer(
             observations, tasks, timestep_pad_mask, train=train
@@ -346,14 +346,14 @@ def sample_trajectory(
         unnormalization_statistics: Optional[Data] = None,
         normalization_type: NormalizationType = NormalizationType.NORMAL,
         beam: int = 1,
-        pad_mask: Optional[ArrayLike] = None,
+        timestep_pad_mask: Optional[ArrayLike] = None,
         train: bool = False,
         argmax: bool = False,
         sample_shape: Tuple[int, ...] = (),
         rng: Optional[PRNGKey] = None,
         temperature: float = 1.0,
     ):
-        if pad_mask is None:
+        if timestep_pad_mask is None:
             pad_mask = observations["pad_mask"]
 
         transformer_outputs = self.run_transformer(
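For reference, a hedged usage sketch of the renamed argument. It assumes the fork's sample_future_actions takes observations and tasks as its first arguments, like upstream octo's sample_actions; the checkpoint path, the image_primary key, the instruction text, and the array shapes are illustrative assumptions, not something this commit defines.

import jax
import numpy as np
from octo.model.octo_model import OctoModel

# Illustrative only: checkpoint path and observation contents are assumptions.
model = OctoModel.load_pretrained("hf://rail-berkeley/octo-small-1.5")

window_size = 2
observations = {
    "image_primary": np.zeros((1, window_size, 256, 256, 3), dtype=np.uint8),
    "pad_mask": np.ones((1, window_size), dtype=bool),  # per-timestep validity mask
}
tasks = model.create_tasks(texts=["pick up the fork"])

# After this commit the keyword is `timestep_pad_mask`, matching upstream
# sample_actions; if it is omitted, the method falls back to
# observations["pad_mask"] (see the diff above).
actions = model.sample_future_actions(
    observations,
    tasks,
    timestep_pad_mask=observations["pad_mask"],
    rng=jax.random.PRNGKey(0),
)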
