more informative error message
lucidrains committed Sep 1, 2024
1 parent b1e64ee commit 0acca12
Showing 2 changed files with 41 additions and 27 deletions.
robotic_transformer_pytorch/robotic_transformer_pytorch.py (39 additions & 25 deletions)
@@ -1,8 +1,11 @@
 from __future__ import annotations

 import torch
+from torch.nn import Module, ModuleList
 import torch.nn.functional as F
 from torch import nn, einsum, Tensor

-from typing import List, Optional, Callable, Tuple
+from typing import Callable
 from beartype import beartype

 from einops import pack, unpack, repeat, reduce, rearrange
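
Two mechanical changes run through the whole diff: Module and ModuleList are now imported directly from torch.nn instead of being referenced as nn.Module/nn.ModuleList, and the typing.Optional/List/Tuple annotations give way to PEP 604 unions and builtin generics. The existing from __future__ import annotations line makes the new spelling safe to evaluate lazily, and beartype resolves such annotations at call time. A minimal sketch of the before/after annotation style (the embed function is invented for illustration, not part of the repo):

from __future__ import annotations  # annotations are stored as strings, evaluated lazily

from beartype import beartype

# old spelling, removed by this commit:
#   from typing import List, Optional
#   def embed(texts: Optional[List[str]] = None) -> Optional[List[float]]: ...

@beartype
def embed(texts: list[str] | None = None) -> list[float] | None:
    # hypothetical stand-in for RT1.embed_texts, only to show the annotation style
    if texts is None:
        return None
    return [float(len(t)) for t in texts]

print(embed(['pick up the apple']))  # [17.0]
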
@@ -42,15 +45,15 @@ def posemb_sincos_1d(seq, dim, temperature = 10000, device = None, dtype = torch

 # helper classes

-class Residual(nn.Module):
+class Residual(Module):
     def __init__(self, fn):
         super().__init__()
         self.fn = fn

     def forward(self, x):
         return self.fn(x) + x

-class LayerNorm(nn.Module):
+class LayerNorm(Module):
     def __init__(self, dim):
         super().__init__()
         self.gamma = nn.Parameter(torch.ones(dim))
@@ -59,7 +62,7 @@ def __init__(self, dim):
     def forward(self, x):
         return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

-class FeedForward(nn.Module):
+class FeedForward(Module):
     def __init__(self, dim, mult = 4, dropout = 0.):
         super().__init__()
         inner_dim = int(dim * mult)
@@ -83,7 +86,7 @@ def forward(self, x, cond_fn = None):

 # MBConv

-class SqueezeExcitation(nn.Module):
+class SqueezeExcitation(Module):
     def __init__(self, dim, shrinkage_rate = 0.25):
         super().__init__()
         hidden_dim = int(dim * shrinkage_rate)
@@ -101,7 +104,7 @@ def forward(self, x):
         return x * self.gate(x)


-class MBConvResidual(nn.Module):
+class MBConvResidual(Module):
     def __init__(self, fn, dropout = 0.):
         super().__init__()
         self.fn = fn
@@ -112,7 +115,7 @@ def forward(self, x):
         out = self.dropsample(out)
         return out + x

-class Dropsample(nn.Module):
+class Dropsample(Module):
     def __init__(self, prob = 0):
         super().__init__()
         self.prob = prob
@@ -157,7 +160,7 @@ def MBConv(

 # attention related classes

-class Attention(nn.Module):
+class Attention(Module):
     def __init__(
         self,
         dim,
@@ -259,7 +262,7 @@ def forward(self, x):
         out = self.to_out(out)
         return rearrange(out, '(b x y) ... -> b x y ...', x = height, y = width)

-class MaxViT(nn.Module):
+class MaxViT(Module):
     def __init__(
         self,
         *,
@@ -294,7 +297,7 @@ def __init__(
         dims = (dim_conv_stem, *dims)
         dim_pairs = tuple(zip(dims[:-1], dims[1:]))

-        self.layers = nn.ModuleList([])
+        self.layers = ModuleList([])

         # shorthand for window size for efficient block - grid like attention

@@ -349,8 +352,8 @@ def __init__(
     def forward(
         self,
         x,
-        texts: Optional[List[str]] = None,
-        cond_fns: Optional[Tuple[Callable, ...]] = None,
+        texts: list[str] | None = None,
+        cond_fns: tuple[Callable, ...] | None = None,
         cond_drop_prob = 0.,
         return_embeddings = False
     ):
@@ -373,7 +376,7 @@ def forward(

 # attention

-class TransformerAttention(nn.Module):
+class TransformerAttention(Module):
     def __init__(
         self,
         dim,
@@ -411,7 +414,7 @@ def forward(
         mask = None,
         attn_bias = None,
         attn_mask = None,
-        cond_fn: Optional[Callable] = None
+        cond_fn: Callable | None = None
     ):
         b = x.shape[0]

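The cond_fn threaded through attention here (and through Transformer.forward below via cond_fns) is a per-layer conditioning callable; in this repo the tuple of callables is produced by the text conditioner from classifier-free-guidance-pytorch. A rough sketch of the shape of such a callable, with a made-up FiLM-style scaling function (not the library's actual API):

import torch

# hypothetical cond_fn: scale the token features by a conditioning vector.
# the real cond_fns come from classifier-free-guidance-pytorch's conditioner.
def make_scale_cond_fn(gamma: torch.Tensor):
    def cond_fn(x: torch.Tensor) -> torch.Tensor:
        return x * gamma  # broadcasts over batch and sequence dimensions
    return cond_fn

x = torch.randn(2, 16, 512)                    # (batch, tokens, dim)
cond_fn = make_scale_cond_fn(torch.rand(512))  # one conditioning value per feature
print(cond_fn(x).shape)                        # torch.Size([2, 16, 512])
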
@@ -457,8 +460,8 @@ def forward(
         out = rearrange(out, 'b h n d -> b n (h d)')
         return self.to_out(out)

-@beartype
-class Transformer(nn.Module):
+class Transformer(Module):
+    @beartype
     def __init__(
         self,
         dim,
@@ -469,17 +472,18 @@ def __init__(
         ff_dropout = 0.
     ):
         super().__init__()
-        self.layers = nn.ModuleList([])
+        self.layers = ModuleList([])
         for _ in range(depth):
-            self.layers.append(nn.ModuleList([
+            self.layers.append(ModuleList([
                 TransformerAttention(dim = dim, heads = heads, dropout = attn_dropout),
                 FeedForward(dim = dim, dropout = ff_dropout)
             ]))

+    @beartype
     def forward(
         self,
         x,
-        cond_fns: Optional[Tuple[Callable, ...]] = None,
+        cond_fns: tuple[Callable, ...] | None = None,
         attn_mask = None
     ):
         cond_fns = iter(default(cond_fns, []))
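
Note the pattern here and on RT1 below: @beartype moves from decorating the whole class to decorating __init__ and forward individually, so runtime type checking applies only at the entry points rather than to every method. A toy illustration of the two placements (invented class, not from the repo; assumes Python 3.9+ for the builtin-generic annotation):

from beartype import beartype
from torch.nn import Module

# before: @beartype on the class wraps every method
# after: decorate just the methods whose arguments should be checked

class Toy(Module):
    @beartype
    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    @beartype
    def forward(self, x: list[int]) -> list[int]:
        return [v * self.dim for v in x]

toy = Toy(3)
print(toy(list(range(4))))  # [0, 3, 6, 9]
# toy('oops') would raise a beartype violation: list[int] expected
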
@@ -491,7 +495,7 @@ def forward(

 # token learner module

-class TokenLearner(nn.Module):
+class TokenLearner(Module):
     """
     https://arxiv.org/abs/2106.11297
     using the 1.1 version with the MLP (2 dense layers with gelu) for generating attention map
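
For orientation, TokenLearner v1.1 (the variant the docstring cites) distills a spatial feature map into a handful of learned tokens: an MLP of two dense layers with a gelu produces one attention map per output token, and each map pools the feature map into that token. A loose sketch of the idea, not this repo's implementation (which differs in details such as how the maps are normalized):

import torch
from torch import nn
from torch.nn import Module
from einops import rearrange

class TokenLearnerSketch(Module):
    def __init__(self, dim, num_tokens = 8):
        super().__init__()
        # two dense layers with gelu, emitting one attention map per learned token
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Linear(dim, num_tokens)
        )

    def forward(self, x):                      # x: (batch, height, width, dim)
        attn = self.mlp(x)                     # (b, h, w, num_tokens)
        attn = rearrange(attn, 'b h w t -> b t (h w)').softmax(dim = -1)
        x = rearrange(x, 'b h w d -> b (h w) d')
        return torch.einsum('b t n, b n d -> b t d', attn, x)  # (b, num_tokens, dim)

x = torch.randn(2, 14, 14, 512)
print(TokenLearnerSketch(512)(x).shape)  # torch.Size([2, 8, 512])
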
@@ -529,8 +533,8 @@ def forward(self, x):

 # Robotic Transformer

-@beartype
-class RT1(nn.Module):
+class RT1(Module):
+    @beartype
     def __init__(
         self,
         *,
@@ -587,18 +591,28 @@ def __init__(
             Rearrange('... (a b) -> ... a b', b = action_bins)
         )

-    def embed_texts(self, texts: List[str]):
+    @beartype
+    def embed_texts(self, texts: list[str]):
         return self.conditioner.embed_texts(texts)

     @classifier_free_guidance
+    @beartype
     def forward(
         self,
         video,
-        texts: Optional[List[str]] = None,
-        text_embeds: Optional[Tensor] = None,
+        texts: list[str] | None = None,
+        text_embeds: Tensor | None = None,
         cond_drop_prob = 0.
     ):
         assert exists(texts) ^ exists(text_embeds)

+        if exists(texts):
+            num_texts = len(texts)
+        elif exists(text_embeds):
+            num_texts = text_embeds.shape[0]
+
+        assert num_texts == video.shape[0], f'you only passed in {num_texts} strings for guiding the robot actions, but received batch size of {video.shape[0]} videos'
+
         cond_kwargs = dict(texts = texts, text_embeds = text_embeds)

         depth = self.transformer_depth
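
This last hunk is the commit's namesake: forward now counts the guiding texts (or text embeddings) and checks the count against the video batch size up front, so a mismatch fails with a readable message instead of an obscure shape error further down. A standalone sketch of what the new check does (the helper function is illustrative, not the actual RT1.forward; the video shape follows the README's (batch, channels, frames, height, width) usage):

import torch

def check_text_video_match(video, texts = None, text_embeds = None):
    # exactly one of texts / text_embeds must be supplied (the existing XOR assert)
    assert (texts is not None) ^ (text_embeds is not None)

    num_texts = len(texts) if texts is not None else text_embeds.shape[0]

    # the new, more informative error message
    assert num_texts == video.shape[0], f'you only passed in {num_texts} strings for guiding the robot actions, but received batch size of {video.shape[0]} videos'

video = torch.randn(2, 3, 6, 224, 224)  # batch of 2 videos

try:
    check_text_video_match(video, texts = ['bring me that apple'])
except AssertionError as e:
    print(e)  # you only passed in 1 strings for guiding the robot actions, but received batch size of 2 videos
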
setup.py (2 additions & 2 deletions)
@@ -3,7 +3,7 @@
 setup(
   name = 'robotic-transformer-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.2.1',
+  version = '0.2.2',
   license='MIT',
   description = 'Robotic Transformer - Pytorch',
   author = 'Phil Wang',
@@ -19,7 +19,7 @@
   ],
   install_requires=[
     'classifier-free-guidance-pytorch>=0.4.0',
-    'einops>=0.7',
+    'einops>=0.8',
     'torch>=2.0',
   ],
   classifiers=[