
Inconsistencies in MBConv, with corrected code provided #19

Open
swarajnanda2021 opened this issue Mar 8, 2024 · 3 comments

Comments


swarajnanda2021 commented Mar 8, 2024

I found some computational inconsistencies in MBConv. The corrected code below works. Compared with the repository version, I have:

- changed the stride of the projection operation (`self.proj`) and moved it out of the `if downsample` statement;
- initialized the squeeze-and-excitation block appropriately, constructing it on `hidden_dim` channels (my `SqueezeAndExcite` block is included for completeness);
- added the channel projection to the `downsample=False` branch of the `MBConv.forward` method.

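```python
import torch.nn as nn


class PreNorm(nn.Module):
    # PreNorm is not included in this comment. Assumption: a minimal version
    # matching the keyword arguments used in MBConv below, which normalizes
    # the input with `norm(dimension)` and then applies `model`.
    def __init__(self, norm, model, dimension):
        super().__init__()
        self.norm = norm(dimension)
        self.model = model

    def forward(self, x):
        return self.model(self.norm(x))


class SqueezeAndExcite(nn.Module):
    def __init__(self, in_channels, expansion=0.25):  # keep the reduction fixed
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # Bottleneck MLP: in_channels -> int(in_channels * expansion) -> in_channels
        self.fc = nn.Sequential(
            nn.Linear(in_channels, int(in_channels * expansion)),
            nn.GELU(),
            nn.Linear(int(in_channels * expansion), in_channels),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avgpool(x).view(b, c)   # squeeze: global average per channel
        y = self.fc(y).view(b, c, 1, 1)  # excite: per-channel gates in (0, 1)
        return x * y.expand_as(x)        # rescale each channel of the input


class MBConv(nn.Module):
    def __init__(self, inp, oup, expansion, downsample):
        super().__init__()
        self.downsample = downsample
        stride = 1 if not downsample else 2
        hidden_dim = int(expansion * inp)

        if self.downsample:
            # Shortcut path: halve H and W before the projection
            self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Shortcut projection to `oup` channels with stride 1, used in both branches
        self.proj = nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False)

        if expansion == 1:
            # hidden_dim == inp here; the stride sits on the depthwise conv
            self.conv = nn.Sequential(
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=stride,
                          padding=1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.GELU(),
                nn.Conv2d(hidden_dim, oup, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(oup)
            )
        else:
            self.conv = nn.Sequential(
                # Pointwise expansion; the stride sits here when downsampling
                nn.Conv2d(inp, hidden_dim, kernel_size=1, stride=stride, padding=0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.GELU(),
                # Depthwise 3x3
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1,
                          padding=1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.GELU(),
                SqueezeAndExcite(hidden_dim, expansion=0.25),
                # Pointwise projection to `oup` channels
                nn.Conv2d(hidden_dim, oup, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(oup)
            )

        # Pre-normalize the input of the residual branch
        self.conv = PreNorm(norm=nn.BatchNorm2d, model=self.conv, dimension=inp)

    def forward(self, x):
        if self.downsample:
            return self.proj(self.pool(x)) + self.conv(x)
        else:
            return self.proj(x) + self.conv(x)
```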
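To see why the shortcut needs a projection in both branches, and why that projection should have stride 1 after the pool, the shape bookkeeping can be sketched in plain Python. This is a minimal sketch under stated assumptions: it only tracks `(N, C, H, W)` tuples, it uses the `nn.Conv2d`/`nn.MaxPool2d` output-size formula with dilation 1 and floor rounding, and the helper names (`out_size`, `residual_shape`, `shortcut_shape`) and example sizes are hypothetical, not part of the repository.

```python
def out_size(size, kernel, stride, padding):
    """Output length along one spatial axis of nn.Conv2d / nn.MaxPool2d
    with dilation=1 and floor rounding: (size + 2*padding - kernel) // stride + 1."""
    return (size + 2 * padding - kernel) // stride + 1


def residual_shape(x, oup, expansion, downsample):
    """Shape after the residual branch (self.conv) of the corrected MBConv."""
    n, _, h, w = x
    stride = 2 if downsample else 1
    if expansion == 1:
        k, p = 3, 1  # the stride sits on the depthwise 3x3 conv (padding 1)
    else:
        k, p = 1, 0  # the stride sits on the first pointwise 1x1 conv (padding 0)
    # All other convs in the branch use stride 1 and keep H and W unchanged;
    # the last 1x1 conv outputs `oup` channels.
    return (n, oup, out_size(h, k, stride, p), out_size(w, k, stride, p))


def shortcut_shape(x, oup, downsample, project=True, proj_stride=1):
    """Shape after the shortcut path: optional max-pool, then optional 1x1 projection."""
    n, c, h, w = x
    if downsample:
        # nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        h, w = out_size(h, 3, 2, 1), out_size(w, 3, 2, 1)
    if project:
        # nn.Conv2d(inp, oup, kernel_size=1, stride=proj_stride, padding=0)
        c = oup
        h, w = out_size(h, 1, proj_stride, 0), out_size(w, 1, proj_stride, 0)
    return (n, c, h, w)


x = (8, 64, 56, 56)  # hypothetical input: batch 8, 64 channels, 56x56

# Downsampling block: pool + stride-1 projection matches the residual branch
print(shortcut_shape(x, 96, downsample=True))     # (8, 96, 28, 28)
print(residual_shape(x, 96, 4, downsample=True))  # (8, 96, 28, 28)

# Non-downsampling block with 64 -> 96 channels: an identity shortcut cannot be added
print(shortcut_shape(x, 96, downsample=False, project=False))  # (8, 64, 56, 56)
print(residual_shape(x, 96, 4, downsample=False))              # (8, 96, 56, 56)

# Hypothetical stride-2 projection after the pool shrinks the shortcut twice
print(shortcut_shape(x, 96, downsample=True, proj_stride=2))   # (8, 96, 14, 14)
```

Because `MaxPool2d(3, 2, 1)`, a stride-2 1x1 convolution with padding 0, and a stride-2 3x3 convolution with padding 1 all give `(H - 1) // 2 + 1`, the two paths also agree for odd input sizes.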


Uljibuh commented Mar 14, 2024

Hi! Just curious: did you run into this issue, and how did you solve it?

https://github.com/chinhsuanwu/coatnet-pytorch/issues/20

swarajnanda2021 (author) commented

I was implementing CoAtNet myself and looked at this repo for inspiration. It did not work, so while debugging I had to re-read the paper several times. Eventually I understood the problems and found a solution. Of course, GPT-4 helped a lot here.


Uljibuh commented Mar 20, 2024

How were the training results of the model? Did you use downsampling? Which gives better results, with or without downsampling?
