ScatterElements with Reduction (opset 16) Not Fully Supported #3650

Closed
anthony-correia opened this issue Feb 1, 2024 · 6 comments

@anthony-correia

anthony-correia commented Feb 1, 2024

Short Description

Converting an ONNX model that contains a ScatterElements operation with a reduction such as "sum" (opset 16) to TensorRT with trtexec fails when the number of indices in the operation exceeds the number of outputs.

Successful conversion requires n_indices <= n_outputs.

Long Description

Consider the following PyTorch model snippet:

import torch
import torch_scatter

n_indices: int = ...
dim_size: int = ...
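n_outputs: int = ...

# `device` was not defined in the original snippet; any valid torch device works here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")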

e_dummy = torch.randn(size=(n_indices, dim_size), device=device)
index_dummy = torch.randint(high=n_outputs, size=(n_indices,), device=device)

class ScatterModule(torch.nn.Module):
    def forward(self, e: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
        return torch_scatter.scatter(
            src=e,
            # broadcasting (should be done automatically anyway)
            index=index.unsqueeze(-1).expand(-1, e.shape[1]),
            dim=0,
            reduce="sum",
        )

Converting the corresponding ONNX model with trtexec triggers an assertion error:

Assertion failed: indicesDims.d[i] <= dataDims.d[i] && "Indices dimensions must be less than data dimensions!"

This error likely originates from this line of the ONNX-TensorRT code.
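To make the failing condition concrete, here is a minimal pure-Python sketch of the comparison described by the assertion message. It paraphrases the message only and is not the parser's source code. The function names are hypothetical, and the shapes assume the data tensor is (n_outputs, dim_size) and the indices tensor is (n_indices, dim_size). The relaxed variant is an illustration of what this use case needs, not a statement of how TensorRT 10.0 behaves.

```python
from typing import Sequence


def strict_check(indices_shape: Sequence[int], data_shape: Sequence[int]) -> bool:
    """Hypothetical paraphrase of the assertion: every indices dimension
    must not exceed the matching data dimension."""
    return all(i <= d for i, d in zip(indices_shape, data_shape))


def relaxed_check(
    indices_shape: Sequence[int], data_shape: Sequence[int], axis: int
) -> bool:
    """Hypothetical variant that skips the scatter axis, since a reduction
    can fold many updates into one output slot along that axis."""
    return all(
        i <= d
        for k, (i, d) in enumerate(zip(indices_shape, data_shape))
        if k != axis
    )


# Shapes assumed from the reproduction: (n_indices, dim_size) vs (n_outputs, dim_size).
failing = ((1000, 3), (100, 3))
passing = ((100, 3), (100, 3))

print(strict_check(*failing))           # False: 1000 > 100 on axis 0
print(strict_check(*passing))           # True
print(relaxed_check(*failing, axis=0))  # True: only axis 1 is compared
```

With the strict comparison, any graph with more edges than nodes is rejected before the reduction is ever considered.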

In the scenarios I've encountered within Graph Neural Networks, the number of indices (n_indices, corresponding to the edges in the graph) is significantly larger than the number of outputs (n_outputs, corresponding to the nodes in the graph).
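As an illustration of why n_indices > n_outputs is a well-defined case, the following is a minimal pure-Python sketch of a scatter with "sum" reduction along dimension 0. It is not the torch_scatter or TensorRT implementation; the function name `scatter_sum_axis0` and the toy edge and node values are made up for this example.

```python
from typing import List


def scatter_sum_axis0(
    updates: List[List[float]], index: List[int], n_outputs: int
) -> List[List[float]]:
    """Add row i of `updates` into output row index[i] (sum reduction)."""
    if len(updates) != len(index):
        raise ValueError("updates and index must have the same number of rows")
    dim_size = len(updates[0]) if updates else 0
    out = [[0.0] * dim_size for _ in range(n_outputs)]
    for row, target in zip(updates, index):
        if not 0 <= target < n_outputs:
            raise IndexError(f"index {target} is out of range for {n_outputs} outputs")
        for j, value in enumerate(row):
            out[target][j] += value
    return out


# Five edges (n_indices = 5) aggregated onto two nodes (n_outputs = 2).
edge_features = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0], [5.0, 50.0]]
edge_to_node = [0, 1, 0, 1, 0]

node_features = scatter_sum_axis0(edge_features, edge_to_node, n_outputs=2)
print(node_features)  # [[9.0, 90.0], [6.0, 60.0]]
```

Each node receives the sum of the features of its incoming edges, so the number of updates can exceed the number of output rows without any ambiguity.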

Environment

TensorRT Version: 8.6.1.6-1+cuda11.8. I've also tried the TensorRT release 9.2.
GPU Type: NVIDIA RTX A2000 (laptop)
Nvidia Driver Version: 520.61.05
CUDA Version: 11.8.0-1
CUDNN Version: 8.7.0.84-1+cuda11.8
Operating System + Version: Ubuntu 22.04.1 LTS
PyTorch version: 2.1.2

Relevant Files

I've created a repository to reproduce the issue: anthony-correia/scatter_onnx2tensorrt.

The ONNX models are stored with the naming convention onnx/{n_indices}_{dim_size}_{n_outputs}_{seed}.onnx.
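The naming convention can be decoded mechanically. The sketch below is a small helper (hypothetical, not part of the linked repository) that parses a model path into its four fields and reports whether the model falls into the n_indices > n_outputs case described in this issue.

```python
from pathlib import PurePosixPath
from typing import NamedTuple


class ModelSpec(NamedTuple):
    n_indices: int
    dim_size: int
    n_outputs: int
    seed: int


def parse_model_path(path: str) -> ModelSpec:
    """Parse onnx/{n_indices}_{dim_size}_{n_outputs}_{seed}.onnx into its fields."""
    parts = PurePosixPath(path).stem.split("_")
    if len(parts) != 4:
        raise ValueError(f"unexpected model file name: {path}")
    return ModelSpec(*(int(part) for part in parts))


spec = parse_model_path("onnx/1000_3_100_0.onnx")
print(spec)
# True means the model has more indices than outputs, the case reported here.
print(spec.n_indices > spec.n_outputs)
```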

To replicate the issue, execute the following commands:

# This command fails when `n_outputs = 100` and `n_indices = 1000`.
trtexec --onnx="onnx/1000_3_100_0.onnx"

# This command succeeds when `n_outputs` equals `n_indices` (both are 100).
trtexec --onnx="onnx/100_3_100_0.onnx"

Cross-posted in onnx/onnx-tensorrt#953; this repository looks more relevant.

@zerollzeng
Collaborator

@ttyio Is it in our plan? Thanks!

@zerollzeng added the triaged label Feb 7, 2024
@ttyio
Collaborator

ttyio commented Mar 26, 2024

FYI, we will have scatterElements plugin with reduction support in 10.0.

@anthony-correia
Author

anthony-correia commented Mar 26, 2024

Great to hear that! Thanks a lot for letting me/us know!

@Liupei1101

later. now do you have method to

FYI, we will have scatterElements plugin with reduction support in 10.0.

when?

@lix19937

lix19937 commented Apr 4, 2024

Now https://github.com/NVIDIA/TensorRT/blob/release/10.0/CHANGELOG.md

Parser changes
Added a new class IParserRefitter that can be used to refit a TensorRT engine with the weights of an ONNX model.
kNATIVE_INSTANCENORM is now set to ON by default.
Added support for IPluginV3 interfaces from TensorRT.
Added support for INT4 quantization.
Added support for the reduction attribute in ScatterElements.
Added support for wrap padding mode in Pad

@ttyio
Collaborator

ttyio commented Apr 16, 2024

Closing, thanks all!

@ttyio closed this as completed Apr 16, 2024