There are issues with support for ConvTranspose2d #946

Open
mortal-Zero opened this issue Sep 5, 2024 · 3 comments

Comments

@mortal-Zero

Hello, and thank you for your outstanding project.
I ran into an error when converting a model that contains a ConvTranspose2d layer with torch2trt. Here are the code and the error.

import torch
import torch.nn as nn
from torch2trt import torch2trt

model = nn.Sequential(
    nn.ConvTranspose2d(in_channels=32, out_channels=64,
                       kernel_size=4, stride=2,
                       padding=1, bias=True),
    nn.BatchNorm2d(num_features=64),
    nn.LeakyReLU()
)
model.to("cuda:0").eval()
x = torch.zeros([1, 32, 16, 16]).to("cuda:0")
y = model(x)
print("=====>> input: {} || output: {}".format(x.shape, y.shape))
model_trt = torch2trt(model, [x])

Output and error:

=====>> input: torch.Size([1, 32, 16, 16]) || output: torch.Size([1, 64, 32, 32])
[09/05/2024-11:09:34] [TRT] [E] 3: 0:0:DECONVOLUTION:GPU:kernel weights has count 32768 but 16384 was expected
[09/05/2024-11:09:34] [TRT] [E] 4: 0:0:DECONVOLUTION:GPU: count of 32768 weights in kernel, but kernel dimensions (4,4) with 32 input channels, 32 output channels and 1 groups were specified. Expected Weights count is 32 * 4*4 * 32 / 1 = 16384
[09/05/2024-11:09:34] [TRT] [E] 4: [graphShapeAnalyzer.cpp::needTypeAndDimensions::2212] Error Code 4: Internal Error (0:0:DECONVOLUTION:GPU: output shape can not be computed)
[09/05/2024-11:09:34] [TRT] [E] 3: [network.cpp::addScaleNd::1162] Error Code 3: API Usage Error (Parameter check failed at: optimizer/api/network.cpp::addScaleNd::1162, condition: qdqScale || basicScale
)
Traceback (most recent call last):
  File "/workspace/baiyixuan/test_cvcuda/digitalhuman_service/debug_codes/test.py", line 30, in <module>
    model_trt = torch2trt(model, [x])
  File "/root/miniconda3/envs/cvcuda/lib/python3.10/site-packages/torch2trt-0.5.0-py3.10-linux-x86_64.egg/torch2trt/torch2trt.py", line 643, in torch2trt
    outputs = module(*inputs)
  File "/root/miniconda3/envs/cvcuda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/envs/cvcuda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/root/miniconda3/envs/cvcuda/lib/python3.10/site-packages/torch/nn/modules/container.py", line 215, in forward
    input = module(input)
  File "/root/miniconda3/envs/cvcuda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/envs/cvcuda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/root/miniconda3/envs/cvcuda/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py", line 171, in forward
    return F.batch_norm(
  File "/root/miniconda3/envs/cvcuda/lib/python3.10/site-packages/torch2trt-0.5.0-py3.10-linux-x86_64.egg/torch2trt/torch2trt.py", line 262, in wrapper
    converter["converter"](ctx)
  File "/root/miniconda3/envs/cvcuda/lib/python3.10/site-packages/torch2trt-0.5.0-py3.10-linux-x86_64.egg/torch2trt/converters/native_converters.py", line 183, in convert_batch_norm
    output._trt = layer.get_output(0)
AttributeError: 'NoneType' object has no attribute 'get_output'

Looking forward to your reply.

@mortal-Zero
Author

For what it's worth, the following code, which uses Conv2d instead of ConvTranspose2d, converts correctly.

import torch
import torch.nn as nn
from torch2trt import torch2trt

model = nn.Sequential(
    nn.Conv2d(in_channels=6, out_channels=32,
              kernel_size=3, stride=1,
              padding=1, bias=True),
    nn.BatchNorm2d(num_features=32),
    nn.LeakyReLU(),
)
model.to("cuda:0").eval()
x = torch.zeros([1, 6, 96, 96]).to("cuda:0")
y = model(x)
print("=====>> input: {} || output: {}".format(x.shape, y.shape))
model_trt = torch2trt(model, [x])

@fasogbon

Any fix, please? I am having the same problem.

@qiangxinglin

I believe the issue is this line in the convolution converter:

out_channels = int(weight.shape[0])

Changing out_channels = int(weight.shape[0]) to out_channels = int(weight.shape[1]) and reinstalling the package fixed the issue for me.
The output channel count is weight.shape[1] for ConvTransposeNd (or weight.shape[1] * groups for grouped layers), but weight.shape[0] for ConvNd.
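
For illustration, here is a minimal check of the two weight layouts (standard PyTorch behaviour, not specific to torch2trt; shown with groups=1):

import torch.nn as nn

conv = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4)
deconv = nn.ConvTranspose2d(in_channels=32, out_channels=64, kernel_size=4)

# Conv2d stores its weight as (out_channels, in_channels, kH, kW)
print(conv.weight.shape)    # torch.Size([64, 32, 4, 4])

# ConvTranspose2d stores its weight as (in_channels, out_channels, kH, kW)
print(deconv.weight.shape)  # torch.Size([32, 64, 4, 4])

Reading shape[0] of the transposed weight therefore yields in_channels (32), which matches the error above: TensorRT was told there are 32 output channels and expected 32 * 4*4 * 32 / 1 = 16384 weights, while the tensor actually holds 32 * 64 * 4 * 4 = 32768 values.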
