diff --git a/robotic_transformer_pytorch/robotic_transformer_pytorch.py b/robotic_transformer_pytorch/robotic_transformer_pytorch.py
index 6134bdd..32fca10 100644
--- a/robotic_transformer_pytorch/robotic_transformer_pytorch.py
+++ b/robotic_transformer_pytorch/robotic_transformer_pytorch.py
@@ -29,6 +29,17 @@ def pack_one(x, pattern):
 def unpack_one(x, ps, pattern):
     return unpack(x, ps, pattern)[0]
 
+# sinusoidal positions
+
+def posemb_sincos_1d(seq, dim, temperature = 10000, device = None, dtype = torch.float32):
+    n = torch.arange(seq, device = device)
+    omega = torch.arange(dim // 2, device = device) / (dim // 2 - 1)
+    omega = 1. / (temperature ** omega)
+
+    n = n[:, None] * omega[None, :]
+    pos_emb = torch.cat((n.sin(), n.cos()), dim = 1)
+    return pos_emb.type(dtype)
+
 # helper classes
 
 class Residual(nn.Module):
@@ -560,9 +571,19 @@ def forward(
 
         learned_tokens = rearrange(learned_tokens, 'b f c n -> b (f n) c')
 
+        # causal attention mask
+
         attn_mask = torch.ones((frames, frames), dtype = torch.bool, device = device).triu(1)
         attn_mask = repeat(attn_mask, 'i j -> (i r1) (j r2)', r1 = self.num_learned_tokens, r2 = self.num_learned_tokens)
 
+        # sinusoidal positional embedding
+
+        pos_emb = posemb_sincos_1d(frames, learned_tokens.shape[-1], dtype = learned_tokens.dtype, device = learned_tokens.device)
+
+        learned_tokens = learned_tokens + repeat(pos_emb, 'n d -> (n r) d', r = self.num_learned_tokens)
+
+        # attention
+
         attended_tokens = self.transformer(learned_tokens, attn_mask = ~attn_mask)
 
         pooled = reduce(attended_tokens, 'b (f n) d -> b f d', 'mean', f = frames)
diff --git a/setup.py b/setup.py
index 6299480..f8ca8c4 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'robotic-transformer-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.8',
+  version = '0.0.9',
   license='MIT',
   description = 'Robotic Transformer - Pytorch',
   author = 'Phil Wang',
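
For reference, a minimal sketch of how the added posemb_sincos_1d composes with the learned tokens in the forward pass above. The concrete sizes (frames = 6, num_learned_tokens = 8, dim = 512) are hypothetical and chosen only for illustration, and the import assumes the new function is reachable at module level in the patched file:

import torch
from einops import repeat

# assumes the patched module exposes the new helper at module level
from robotic_transformer_pytorch.robotic_transformer_pytorch import posemb_sincos_1d

frames, num_learned_tokens, dim = 6, 8, 512    # hypothetical sizes, for illustration only

# one sinusoidal embedding per frame -> (frames, dim)
pos_emb = posemb_sincos_1d(frames, dim)
assert pos_emb.shape == (frames, dim)

# as in the forward pass, each frame's embedding is repeated across that
# frame's learned tokens before being added to the token sequence
learned_tokens = torch.randn(1, frames * num_learned_tokens, dim)
learned_tokens = learned_tokens + repeat(pos_emb, 'n d -> (n r) d', r = num_learned_tokens)
assert learned_tokens.shape == (1, frames * num_learned_tokens, dim)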