Skip to content

Commit

Permalink
stress test classifier free guidance package
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Dec 16, 2022
1 parent 38f0077 commit 5011b9d
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 6 deletions.
10 changes: 7 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,14 @@ model = RT1(
cond_drop_prob = 0.2
)

video = torch.randn(1, 3, 6, 224, 224)
instructions = ['bring me that apple sitting on the table']
video = torch.randn(2, 3, 6, 224, 224)

train_logits = model(video, instructions) # (1, 6, 11, 256) # (batch, frames, actions, bins)
instructions = [
'bring me that apple sitting on the table',
'please pass the butter'
]

train_logits = model(video, instructions) # (2, 6, 11, 256) # (batch, frames, actions, bins)

# after much training

Expand Down
6 changes: 5 additions & 1 deletion robotic_transformer_pytorch/robotic_transformer_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,11 @@ def forward(

cond_fns = (None,) * len(self.layers)
if exists(texts):
cond_fns = self.conditioner(texts, cond_drop_prob = cond_drop_prob)
cond_fns = self.conditioner(
texts,
cond_drop_prob = cond_drop_prob,
repeat_batch = x.shape[0] // len(texts) # text conditionig across multiple frames of video
)

for stage, cond_fn in zip(self.layers, cond_fns):
if exists(cond_fn):
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'robotic-transformer-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.9',
version = '0.0.10',
license='MIT',
description = 'Robotic Transformer - Pytorch',
author = 'Phil Wang',
Expand All @@ -18,7 +18,7 @@
'robotics'
],
install_requires=[
'classifier-free-guidance-pytorch>=0.0.14',
'classifier-free-guidance-pytorch>=0.0.15',
'einops>=0.6',
'torch>=1.6',
],
Expand Down

0 comments on commit 5011b9d

Please sign in to comment.