Skip to content

Commit

Permalink
classifier free guidance
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Dec 15, 2022
1 parent 6a4fa42 commit 4a8991b
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 7 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,13 @@ model = RT1(
depth = 6,
heads = 8,
dim_head = 64,
cond_drop_prob = 0.25 # classifier free guidance conditional dropout
)

video = torch.randn(1, 3, 6, 224, 224)
instructions = ['bring me that apple sitting on the table']

pred = model(video, instructions)
pred = model(video, instructions, cond_scale = 3) # classifier free guidance by scale of 3 times. 1 means disabled
pred.shape # (1, 6, 11, 256) # (batch, frames, actions, bins)
```

Expand All @@ -51,7 +52,7 @@ pred.shape # (1, 6, 11, 256) # (batch, frames, actions, bins)

## Todo

- [ ] add classifier free guidance option
- [x] add classifier free guidance option
- [ ] add cross attention based conditioning

## Citations
Expand Down
25 changes: 21 additions & 4 deletions robotic_transformer_pytorch/robotic_transformer_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from functools import partial

from classifier_free_guidance_pytorch import TextConditioner
from classifier_free_guidance_pytorch import TextConditioner, classifier_free_guidance

# helpers

Expand Down Expand Up @@ -313,13 +313,14 @@ def forward(
self,
x,
texts: Optional[List[str]] = None,
cond_drop_prob = 0.,
return_embeddings = False
):
x = self.conv_stem(x)

cond_fns = (None,) * len(self.layers)
if exists(texts):
cond_fns = self.conditioner(texts)
cond_fns = self.conditioner(texts, cond_drop_prob = cond_drop_prob)

for stage, cond_fn in zip(self.layers, cond_fns):
if exists(cond_fn):
Expand Down Expand Up @@ -495,6 +496,7 @@ def __init__(
token_learner_ff_mult = 2,
token_learner_num_layers = 2,
token_learner_num_output_tokens = 8,
cond_drop_prob = 0.2
):
super().__init__()
self.vit = vit
Expand All @@ -515,19 +517,34 @@ def __init__(
depth = depth
)

self.cond_drop_prob = cond_drop_prob

self.to_logits = nn.Sequential(
nn.LayerNorm(vit.embed_dim),
nn.Linear(vit.embed_dim, num_actions * action_bins),
Rearrange('... (a b) -> ... a b', b = action_bins)
)

def forward(self, video, texts: Optional[List[str]] = None):
@classifier_free_guidance
def forward(
self,
video,
texts: Optional[List[str]] = None,
cond_drop_prob = 0.
):
cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob)

frames, device = video.shape[2], video.device

video = rearrange(video, 'b c f h w -> b f c h w')
images, packed_shape = pack_one(video, '* c h w')

tokens = self.vit(images, texts = texts, return_embeddings = True)
tokens = self.vit(
images,
texts = texts,
cond_drop_prob = cond_drop_prob,
return_embeddings = True
)

tokens = unpack_one(tokens, packed_shape, '* c h w')
learned_tokens = self.token_learner(tokens)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
'robotics'
],
install_requires=[
'classifier-free-guidance-pytorch>=0.0.3',
'classifier-free-guidance-pytorch>=0.0.7',
'einops>=0.6',
'torch>=1.6',
],
Expand Down

0 comments on commit 4a8991b

Please sign in to comment.