diff --git a/robotic_transformer_pytorch/robotic_transformer_pytorch.py b/robotic_transformer_pytorch/robotic_transformer_pytorch.py
index 73c1159..6d8ebd6 100644
--- a/robotic_transformer_pytorch/robotic_transformer_pytorch.py
+++ b/robotic_transformer_pytorch/robotic_transformer_pytorch.py
@@ -1,8 +1,11 @@
+from __future__ import annotations
+
 import torch
+from torch.nn import Module, ModuleList
 import torch.nn.functional as F
 from torch import nn, einsum, Tensor
-from typing import List, Optional, Callable, Tuple
+from typing import Callable
 
 from beartype import beartype
 
 from einops import pack, unpack, repeat, reduce, rearrange
@@ -42,7 +45,7 @@ def posemb_sincos_1d(seq, dim, temperature = 10000, device = None, dtype = torch
 
 # helper classes
 
-class Residual(nn.Module):
+class Residual(Module):
     def __init__(self, fn):
         super().__init__()
         self.fn = fn
@@ -50,7 +53,7 @@ def __init__(self, fn):
     def forward(self, x):
         return self.fn(x) + x
 
-class LayerNorm(nn.Module):
+class LayerNorm(Module):
     def __init__(self, dim):
         super().__init__()
         self.gamma = nn.Parameter(torch.ones(dim))
@@ -59,7 +62,7 @@ def __init__(self, dim):
     def forward(self, x):
         return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
 
-class FeedForward(nn.Module):
+class FeedForward(Module):
     def __init__(self, dim, mult = 4, dropout = 0.):
         super().__init__()
         inner_dim = int(dim * mult)
@@ -83,7 +86,7 @@ def forward(self, x, cond_fn = None):
 
 # MBConv
 
-class SqueezeExcitation(nn.Module):
+class SqueezeExcitation(Module):
     def __init__(self, dim, shrinkage_rate = 0.25):
         super().__init__()
         hidden_dim = int(dim * shrinkage_rate)
@@ -101,7 +104,7 @@ def forward(self, x):
         return x * self.gate(x)
 
-class MBConvResidual(nn.Module):
+class MBConvResidual(Module):
     def __init__(self, fn, dropout = 0.):
         super().__init__()
         self.fn = fn
@@ -112,7 +115,7 @@ def forward(self, x):
         out = self.dropsample(out)
         return out + x
 
-class Dropsample(nn.Module):
+class Dropsample(Module):
     def __init__(self, prob = 0):
         super().__init__()
         self.prob = prob
@@ -157,7 +160,7 @@ def MBConv(
 
 # attention related classes
 
-class Attention(nn.Module):
+class Attention(Module):
     def __init__(
         self,
         dim,
@@ -259,7 +262,7 @@ def forward(self, x):
         out = self.to_out(out)
         return rearrange(out, '(b x y) ... -> b x y ...', x = height, y = width)
 
-class MaxViT(nn.Module):
+class MaxViT(Module):
     def __init__(
         self,
         *,
@@ -294,7 +297,7 @@ def __init__(
         dims = (dim_conv_stem, *dims)
         dim_pairs = tuple(zip(dims[:-1], dims[1:]))
 
-        self.layers = nn.ModuleList([])
+        self.layers = ModuleList([])
 
         # shorthand for window size for efficient block - grid like attention
 
@@ -349,8 +352,8 @@ def __init__(
     def forward(
         self,
         x,
-        texts: Optional[List[str]] = None,
-        cond_fns: Optional[Tuple[Callable, ...]] = None,
+        texts: list[str] | None = None,
+        cond_fns: tuple[Callable, ...] | None = None,
         cond_drop_prob = 0.,
         return_embeddings = False
     ):
@@ -373,7 +376,7 @@ def forward(
 
 # attention
 
-class TransformerAttention(nn.Module):
+class TransformerAttention(Module):
     def __init__(
         self,
         dim,
@@ -411,7 +414,7 @@ def forward(
         mask = None,
         attn_bias = None,
         attn_mask = None,
-        cond_fn: Optional[Callable] = None
+        cond_fn: Callable | None = None
     ):
         b = x.shape[0]
@@ -457,8 +460,8 @@ def forward(
         out = rearrange(out, 'b h n d -> b n (h d)')
         return self.to_out(out)
 
-@beartype
-class Transformer(nn.Module):
+class Transformer(Module):
+    @beartype
     def __init__(
         self,
         dim,
@@ -469,17 +472,18 @@ def __init__(
         ff_dropout = 0.
     ):
         super().__init__()
-        self.layers = nn.ModuleList([])
+        self.layers = ModuleList([])
         for _ in range(depth):
-            self.layers.append(nn.ModuleList([
+            self.layers.append(ModuleList([
                 TransformerAttention(dim = dim, heads = heads, dropout = attn_dropout),
                 FeedForward(dim = dim, dropout = ff_dropout)
             ]))
 
+    @beartype
     def forward(
         self,
         x,
-        cond_fns: Optional[Tuple[Callable, ...]] = None,
+        cond_fns: tuple[Callable, ...] | None = None,
         attn_mask = None
     ):
         cond_fns = iter(default(cond_fns, []))
@@ -491,7 +495,7 @@ def forward(
 
 # token learner module
 
-class TokenLearner(nn.Module):
+class TokenLearner(Module):
     """
     https://arxiv.org/abs/2106.11297
     using the 1.1 version with the MLP (2 dense layers with gelu) for generating attention map
@@ -529,8 +533,8 @@ def forward(self, x):
 
 # Robotic Transformer
 
-@beartype
-class RT1(nn.Module):
+class RT1(Module):
+    @beartype
     def __init__(
         self,
         *,
@@ -587,18 +591,28 @@ def __init__(
             Rearrange('... (a b) -> ... a b', b = action_bins)
         )
 
-    def embed_texts(self, texts: List[str]):
+    @beartype
+    def embed_texts(self, texts: list[str]):
         return self.conditioner.embed_texts(texts)
 
     @classifier_free_guidance
+    @beartype
     def forward(
         self,
         video,
-        texts: Optional[List[str]] = None,
-        text_embeds: Optional[Tensor] = None,
+        texts: list[str] | None = None,
+        text_embeds: Tensor | None = None,
         cond_drop_prob = 0.
     ):
         assert exists(texts) ^ exists(text_embeds)
+
+        if exists(texts):
+            num_texts = len(texts)
+        elif exists(text_embeds):
+            num_texts = text_embeds.shape[0]
+
+        assert num_texts == video.shape[0], f'you only passed in {num_texts} strings for guiding the robot actions, but received batch size of {video.shape[0]} videos'
+
         cond_kwargs = dict(texts = texts, text_embeds = text_embeds)
 
         depth = self.transformer_depth
diff --git a/setup.py b/setup.py
index b8066eb..f3bb0a2 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'robotic-transformer-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.2.1',
+  version = '0.2.2',
   license='MIT',
   description = 'Robotic Transformer - Pytorch',
   author = 'Phil Wang',
@@ -19,7 +19,7 @@
   ],
   install_requires=[
     'classifier-free-guidance-pytorch>=0.4.0',
-    'einops>=0.7',
+    'einops>=0.8',
     'torch>=2.0',
   ],
  classifiers=[
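
For reference, a minimal usage sketch of the changed RT1.forward path. The constructor arguments below mirror the project README and are illustrative only, not introduced by this diff; the point is the new batch-size check, which requires one guiding instruction per video (len(texts) == video.shape[0]), and the type hints now written in the builtin list[str] | None form.

import torch
from robotic_transformer_pytorch import MaxViT, RT1

# illustrative hyperparameters, following the project README
vit = MaxViT(
    num_classes = 1000,
    dim_conv_stem = 64,
    dim = 96,
    dim_head = 32,
    depth = (2, 2, 5, 2),
    window_size = 7,
    mbconv_expansion_rate = 4,
    mbconv_shrinkage_rate = 0.25,
    dropout = 0.1
)

model = RT1(
    vit = vit,
    num_actions = 11,
    depth = 6,
    heads = 8,
    dim_head = 64,
    cond_drop_prob = 0.2
)

video = torch.randn(2, 3, 6, 224, 224)  # (batch, channels, frames, height, width)

# one instruction per video in the batch - the new assert enforces len(texts) == video.shape[0]
instructions = [
    'bring me that apple sitting on the table',
    'please pass the butter'
]

logits = model(video, instructions)  # (2, 6, 11, 256) - (batch, frames, actions, bins)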