add a 1d sinusoidal positional embedding before last transformer
lucidrains committed Dec 16, 2022
1 parent b0e0fcf commit 38f0077
Showing 3 changed files with 25 additions and 4 deletions.
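For context: the embedding added in this commit is the standard fixed (non-learned) 1D sinusoidal scheme, where position n in channel pair i gets a sin/cos of n scaled by a geometrically decaying frequency. A minimal standalone sketch of the same computation (illustrative sizes; it mirrors the `posemb_sincos_1d` added below, minus the `device` argument):

```python
import torch

def posemb_sincos_1d(seq, dim, temperature = 10000, dtype = torch.float32):
    # per-channel-pair frequencies decay geometrically from 1 to 1 / temperature
    omega = 1. / (temperature ** (torch.arange(dim // 2) / (dim // 2 - 1)))
    n = torch.arange(seq)[:, None] * omega[None, :]  # (seq, dim // 2)
    return torch.cat((n.sin(), n.cos()), dim = 1).type(dtype)

print(posemb_sincos_1d(6, 512).shape)  # torch.Size([6, 512])
```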
6 changes: 3 additions & 3 deletions README.md
@@ -63,10 +63,10 @@ eval_logits = model(video, instructions, cond_scale = 3.) # classifier free guidance

```bibtex
@inproceedings{rt12022arxiv,
-    title = {RT-1: Robotics Transformer for Real-World Control at Scale},
-    author={Anthony Brohan and Noah Brown and Justice Carbajal and Yevgen Chebotar and Joseph Dabis and Chelsea Finn and Keerthana Gopalakrishnan and Karol Hausman and Alex Herzog and Jasmine Hsu and Julian Ibarz and Brian Ichter and Alex Irpan and Tomas Jackson and Sally Jesmonth and Nikhil Joshi and Ryan Julian and Dmitry Kalashnikov and Yuheng Kuang and Isabel Leal and Kuang-Huei Lee and Sergey Levine and Yao Lu and Utsav Malla and Deeksha Manjunath and Igor Mordatch and Ofir Nachum and Carolina Parada and Jodilyn Peralta and Emily Perez and Karl Pertsch and Jornell Quiambao and Kanishka Rao and Michael Ryoo and Grecia Salazar and Pannag Sanketi and Kevin Sayed and Jaspiar Singh and Sumedh Sontakke and Austin Stone and Clayton Tan and Huong Tran and Vincent Vanhoucke and Steve Vega and Quan Vuong and Fei Xia and Ted Xiao and Peng Xu and Sichun Xu and Tianhe Yu and Brianna Zitkovich},
+    title     = {RT-1: Robotics Transformer for Real-World Control at Scale},
+    author    = {Anthony Brohan and Noah Brown and Justice Carbajal and Yevgen Chebotar and Joseph Dabis and Chelsea Finn and Keerthana Gopalakrishnan and Karol Hausman and Alex Herzog and Jasmine Hsu and Julian Ibarz and Brian Ichter and Alex Irpan and Tomas Jackson and Sally Jesmonth and Nikhil Joshi and Ryan Julian and Dmitry Kalashnikov and Yuheng Kuang and Isabel Leal and Kuang-Huei Lee and Sergey Levine and Yao Lu and Utsav Malla and Deeksha Manjunath and Igor Mordatch and Ofir Nachum and Carolina Parada and Jodilyn Peralta and Emily Perez and Karl Pertsch and Jornell Quiambao and Kanishka Rao and Michael Ryoo and Grecia Salazar and Pannag Sanketi and Kevin Sayed and Jaspiar Singh and Sumedh Sontakke and Austin Stone and Clayton Tan and Huong Tran and Vincent Vanhoucke and Steve Vega and Quan Vuong and Fei Xia and Ted Xiao and Peng Xu and Sichun Xu and Tianhe Yu and Brianna Zitkovich},
    booktitle = {arXiv preprint arXiv:2204.01691},
-    year = {2022}
+    year      = {2022}
}
```

21 changes: 21 additions & 0 deletions robotic_transformer_pytorch/robotic_transformer_pytorch.py
@@ -29,6 +29,17 @@ def pack_one(x, pattern):
def unpack_one(x, ps, pattern):
    return unpack(x, ps, pattern)[0]

+# sinusoidal positions
+
+def posemb_sincos_1d(seq, dim, temperature = 10000, device = None, dtype = torch.float32):
+    n = torch.arange(seq, device = device)
+    omega = torch.arange(dim // 2, device = device) / (dim // 2 - 1)
+    omega = 1. / (temperature ** omega)
+
+    n = n[:, None] * omega[None, :]
+    pos_emb = torch.cat((n.sin(), n.cos()), dim = 1)
+    return pos_emb.type(dtype)
+
# helper classes

class Residual(nn.Module):
@@ -560,9 +571,19 @@ def forward(

        learned_tokens = rearrange(learned_tokens, 'b f c n -> b (f n) c')

+        # causal attention mask
+
        attn_mask = torch.ones((frames, frames), dtype = torch.bool, device = device).triu(1)
        attn_mask = repeat(attn_mask, 'i j -> (i r1) (j r2)', r1 = self.num_learned_tokens, r2 = self.num_learned_tokens)

+        # sinusoidal positional embedding
+
+        pos_emb = posemb_sincos_1d(frames, learned_tokens.shape[-1], dtype = learned_tokens.dtype, device = learned_tokens.device)
+
+        learned_tokens = learned_tokens + repeat(pos_emb, 'n d -> (n r) d', r = self.num_learned_tokens)
+
+        # attention
+
        attended_tokens = self.transformer(learned_tokens, attn_mask = ~attn_mask)

        pooled = reduce(attended_tokens, 'b (f n) d -> b f d', 'mean', f = frames)
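To see how the new embedding lines up with the flattened token layout in `forward`, here is a hedged sketch with illustrative sizes (2 frames, 8 learned tokens per frame, dim 512); it assumes the `posemb_sincos_1d` sketch above is in scope. Each frame's embedding is repeated across all of that frame's learned tokens, the same expansion applied to the block-causal attention mask:

```python
import torch
from einops import repeat

frames, num_learned_tokens, dim = 2, 8, 512

# tokens arrive frame-major after the 'b f c n -> b (f n) c' rearrange
learned_tokens = torch.randn(1, frames * num_learned_tokens, dim)

# every learned token within a frame shares that frame's position embedding
pos_emb = posemb_sincos_1d(frames, dim)  # (frames, dim)
learned_tokens = learned_tokens + repeat(pos_emb, 'n d -> (n r) d', r = num_learned_tokens)

# the frame-level causal mask is expanded over learned tokens the same way
attn_mask = torch.ones((frames, frames), dtype = torch.bool).triu(1)
attn_mask = repeat(attn_mask, 'i j -> (i r1) (j r2)', r1 = num_learned_tokens, r2 = num_learned_tokens)

print(learned_tokens.shape, attn_mask.shape)  # torch.Size([1, 16, 512]) torch.Size([16, 16])
```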
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
  name = 'robotic-transformer-pytorch',
  packages = find_packages(exclude=[]),
-  version = '0.0.8',
+  version = '0.0.9',
  license='MIT',
  description = 'Robotic Transformer - Pytorch',
  author = 'Phil Wang',
