_trt.shape overflow when converting from pytorch to TRT #927

Open
Justin020718 opened this issue May 26, 2024 · 5 comments
Comments

Justin020718 commented May 26, 2024

Before operator conversion, some input tensors already had an attribute called '_trt'.
[Screenshot (240)]
To deal with this, I deleted these incorrect '_trt' attributes manually in the related converter functions in "naive_converters.py" (see the sketch after the error log below).
However, I then encountered some errors during conversion, and the log says:

[05/26/2024-22:35:09] [TRT] [E] 3: context_fusion_net.conv3_up.0:1:CONVOLUTION:GPU: at least 4 dimensions are required for input.
[05/26/2024-22:35:09] [TRT] [E] 3: context_fusion_net.conv3_up.0:1:CONVOLUTION:GPU: at least 4 dimensions are required for input.
[05/26/2024-22:35:09] [TRT] [E] 4: [network.cpp::nvinfer1::Network::validate::3478] Error Code 4: Internal Error (Layer context_fusion_net.conv3_up.0:1:CONVOLUTION:GPU failed validation)
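
For reference, the manual workaround described above amounts to something like the following (a hypothetical sketch, not the actual patch; the helper name and call site are illustrative):

def strip_stale_trt(tensors):
    # Drop any leftover '_trt' attribute so the converter re-creates the
    # TensorRT tensor instead of reusing a stale one.
    for t in tensors:
        if hasattr(t, '_trt'):
            del t._trt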

As expected, the TRT module was not generated successfully, and when I tried to save its state dict, I got:

Traceback (most recent call last):
File "E:\DCVC-main\DCVC-DC\quant.py", line 497, in quant_for_p_all
torch.save(p_frame_enc_trt.state_dict(), path)
File "E:\anaconda3\envs\python3.8\lib\site-packages\torch\nn\modules\module.py", line 1918, in state_dict
hook_result = hook(self, destination, prefix, local_metadata)
File "E:\anaconda3\envs\python3.8\lib\site-packages\torch2trt-0.5.0-py3.8.egg\torch2trt\trt_module.py", line 60, in _on_state_dict
state_dict[prefix + "engine"] = bytearray(self.engine.serialize())
AttributeError: 'NoneType' object has no attribute 'serialize'

Finally, here's my environment:
PyTorch 2.3.0+cu121
torch2trt 0.5.0
TensorRT 8.6.1.6 (I also tried TensorRT 10.0.1.6 after renaming some converter methods, and hit the same issue)

Since this is my first issue on GitHub, I would really appreciate any help or clues toward solving it!

@Justin020718
Author

Here are my changes in "naive_converters.py":
[Screenshot (241)]
It isn't pretty, but I had to change it this way; otherwise I ran into a "ValueError: len() should return >= 0" error.
The module itself works fine in PyTorch.

@Justin020718
Author

I've successfully generated a TRT engine by going through ONNX.
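
For reference, one common way to go through ONNX (not necessarily the exact steps used here; the model, file name, input shape, and opset below are placeholders) is roughly:

import torch

# Placeholder module and input just to make the sketch self-contained;
# substitute the real module and a representative input.
model = torch.nn.Conv2d(3, 8, 3).eval().cuda()
x = torch.randn(1, 3, 256, 256).cuda()

torch.onnx.export(model, (x,), 'model.onnx', opset_version=17)

# Then build the engine with TensorRT's trtexec tool, for example:
#   trtexec --onnx=model.onnx --saveEngine=model.engine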

@jaybdub
Contributor

jaybdub commented May 30, 2024

@Justin020718 ,

Thanks for reaching out. Apologies for the delay.

Glad to hear ONNX worked for you. FYI, not sure if this is what you did, but you can run torch2trt(..., use_onnx=True) for convenience.
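
For reference, a minimal sketch of that call (the module and example input here are placeholders, not the model from this issue):

import torch
from torch2trt import torch2trt

# Placeholder module and example input; substitute the real ones.
model = torch.nn.Conv2d(3, 8, 3).eval().cuda()
x = torch.randn(1, 3, 256, 256).cuda()

# Route the conversion through ONNX parsing instead of the per-op converters.
model_trt = torch2trt(model, [x], use_onnx=True)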

In case the issue re-appears / I need to reproduce, is this a publicly available model?

Best,
John

@Justin071802

@jaybdub
Thank you for your reply.
microsoft/DCVC#18
I will try it later.

@Justin020718
Author

@jaybdub
Forgot to say: to reproduce my issue, you might need to add some converters to "native_converters.py":

@tensorrt_converter('torch.nn.functional.pixel_shuffle')
def convert_PixelShuffle(ctx):
    input = ctx.method_args[0]
    scale_factor = ctx.method_args[1]
    input_trt = add_missing_trt_tensors(ctx.network, [input])[0]
    output = ctx.method_return

    batch_size, in_channels, in_height, in_width = input.shape

    assert scale_factor >= 1

    out_channels = in_channels // (scale_factor * scale_factor)
    out_height = in_height * scale_factor
    out_width = in_width * scale_factor

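    # Reshape (with the batch dimension dropped) to (C_out, r, r, H, W), permute
    # to (C_out, H, r, W, r), then collapse back to (N, C_out, H*r, W*r).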
    layer_1 = ctx.network.add_shuffle(input_trt)
    layer_1.reshape_dims = (out_channels, scale_factor, scale_factor, in_height, in_width)

    layer_2 = ctx.network.add_shuffle(layer_1.get_output(0))
    layer_2.first_transpose = (0, 3, 1, 4, 2)
    layer_2.reshape_dims = (batch_size, out_channels, out_height, out_width)

    output._trt = layer_2.get_output(0)

def _set_layer_precision(ctx, layer):
    # Supported TRT precisions as given by torch2trt_kwargs.
    INT8_MODE = "int8_mode"
    FP16_MODE = "fp16_mode"

    # Check that args exist as expected in torch2trt_kwargs.
    trt_kwargs = ctx.torch2trt_kwargs
    assert INT8_MODE in trt_kwargs
    assert FP16_MODE in trt_kwargs

    is_int8 = trt_kwargs.get(INT8_MODE, False)
    is_fp16 = trt_kwargs.get(FP16_MODE, False)

    if is_int8:
        layer.precision = trt.int8
        layer.set_output_type(0, trt.int8)
    elif is_fp16:
        layer.precision = trt.float16
        layer.set_output_type(0, trt.float16)
        
@tensorrt_converter('torch.zeros')
def convert_zeros(ctx):
    tensor = ctx.method_return

    # Implementation copied from add_trt_constant.
    shape = tuple(tensor.shape[1:])
    array = tensor[0].detach().cpu().numpy()
    layer = ctx.network.add_constant(shape, array)

    _set_layer_precision(ctx, layer)

    tensor._trt = layer.get_output(0)
