Internal Error failure of TensorRT 8.6.1 when running tensorrt.OnnxParser.parse_from_file on GPU NVIDIA GeForce RTX 3060 #3728

Closed
dspyz-matician opened this issue Mar 21, 2024 · 4 comments
Labels
triaged Issue has been triaged by maintainers

Comments

dspyz-matician commented Mar 21, 2024

Description

I saved a simple model to ONNX, and the OnnxParser then failed with an internal error:

[03/21/2024-12:12:47] [TRT] [E] ModelImporter.cpp:774: --- End node ---
[03/21/2024-12:12:47] [TRT] [E] ModelImporter.cpp:777: ERROR: ModelImporter.cpp:195 In function parseGraph:
[6] Invalid Node - /Where
[graphShapeAnalyzer.cpp::checkCalculationStatusSanity::1503] Error Code 2: Internal Error (Assertion !isPartialWork(p.second.outputExtents) failed. )

The code is below; the ONNX file is attached.

import torch
import torch.nn as nn
import torch.onnx

MAX_TRAV = 21


class Traversability(nn.Module):
    def forward(self, untraversable):
        max_trav_scalar = torch.tensor(
            MAX_TRAV, dtype=torch.float32, device=untraversable.device
        )
        sq_distances = torch.where(untraversable, 0, max_trav_scalar**2)
        extra_column = torch.full(
            (sq_distances.shape[0], 1),
            max_trav_scalar**2,
            device=sq_distances.device,
        )
        # horizontal pass
        for d in range(1, MAX_TRAV * 2 + 1, 2):
            d = torch.tensor(d, dtype=torch.int32, device=sq_distances.device)
            sq_distances = torch.minimum(
                sq_distances,
                torch.minimum(
                    torch.cat([sq_distances[:, 1:] + d, extra_column], 1),
                    torch.cat([extra_column, sq_distances[:, :-1] + d], 1),
                ),
            )
        extra_row = torch.full(
            (1, sq_distances.shape[1]),
            max_trav_scalar**2,
            device=sq_distances.device,
        )
        # vertical pass
        for d in range(1, MAX_TRAV * 2 + 1, 2):
            d = torch.tensor(d, dtype=torch.int32, device=sq_distances.device)
            sq_distances = torch.minimum(
                sq_distances,
                torch.minimum(
                    torch.cat([sq_distances[1:, :] + d, extra_row], 0),
                    torch.cat([extra_row, sq_distances[:-1, :] + d], 0),
                ),
            )
        return sq_distances.sqrt()
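Reading the two loops (this is my interpretation of the code, not something stated in the report), the model computes a truncated, separable squared Euclidean distance transform. At step t each pass adds d = 2t - 1, and because the first k odd numbers sum to k², a value that has moved k cells has accumulated exactly k². With K = MAX_TRAV = 21, U the set of untraversable cells, and f(r, c) = 0 on U and K² elsewhere (including the extra_column / extra_row padding), the horizontal pass yields h, the vertical pass yields g, and the model returns the square root:

```math
\begin{aligned}
\sum_{t=1}^{k} (2t - 1) &= k^2 \\
h(r, c) &= \min_{|i| \le K} \bigl( f(r, c + i) + i^2 \bigr) \\
g(r, c) &= \min_{|j| \le K} \bigl( h(r + j, c) + j^2 \bigr)
         = \min\Bigl( K^2,\ \min_{(u, v) \in U} \bigl[ (r - u)^2 + (c - v)^2 \bigr] \Bigr) \\
\mathrm{output}(r, c) &= \sqrt{g(r, c)} = \min\bigl( K,\ \operatorname{dist}\bigl((r, c), U\bigr) \bigr)
\end{aligned}
```

The second equality for g holds because untraversable cells outside the (2K + 1) × (2K + 1) window are more than K away, so they cannot beat the K² cap.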

Environment

TensorRT Version: 8.6.1

NVIDIA GPU: NVIDIA GeForce RTX 3060

NVIDIA Driver Version: 550.54.14

CUDA Version: 12.4

CUDNN Version: 8.9.7

Operating System: Debian 11

Python Version (if applicable): 3.9.2

PyTorch Version (if applicable): 2.2.1+cu121

Relevant Files

Model link:

traversability.zip

Steps To Reproduce

Commands or scripts:

Generating the ONNX model (uses the Traversability class defined above):

import torch
import torch.onnx

model = Traversability()

# Set the model to evaluation mode
model.eval()

# Dummy input: a 224x224 boolean grid (True = untraversable)
dummy_input = torch.randint(0, 2, (224, 224), dtype=torch.bool)

# Specify the output file path
onnx_file_path = "traversability.onnx"

# Convert the model to ONNX format
torch.onnx.export(
    model,
    dummy_input,
    onnx_file_path,
    input_names=["input"],
    output_names=["output"],
)

Converting the ONNX model:

import tensorrt as trt

# Create a TensorRT logger
logger = trt.Logger(trt.Logger.WARNING)

# Create a TensorRT builder
builder = trt.Builder(logger)

# Create an explicit-batch network (required by the ONNX parser)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)

# Parse the ONNX model; parse_from_file returns False on failure
parser = trt.OnnxParser(network, logger)
if not parser.parse_from_file("traversability.onnx"):
    # Print every error the parser recorded
    for i in range(parser.num_errors):
        print(parser.get_error(i))

Can this model run on other frameworks?:
It works fine when run directly in PyTorch:

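import torch
from matplotlib import pyplot as plt

# Traversability is the class defined above; this run uses a CUDA device
model = Traversability()

# Build a 224x224 grid with three untraversable rectangles
untrav = torch.zeros((224, 224), dtype=torch.bool, device=torch.device("cuda"))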
untrav[4:40, 70:75] = True
untrav[60:80, 90:95] = True
untrav[20:30, 120:130] = True

trav = model(untrav)

plt.figure()
plt.subplot(1, 2, 1)
plt.imshow(untrav.to(torch.device("cpu")), cmap="gray")
plt.subplot(1, 2, 2)
plt.imshow(trav.to(torch.device("cpu")), cmap="viridis")
plt.show()
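To compare a TensorRT or ONNX Runtime result against the eager PyTorch output without depending on either framework, a plain-Python reference can help. This is a minimal sketch, not part of the original report: `reference_traversability`, `brute_force`, and the underscore helpers are hypothetical names. The first mirrors the model's two loops step by step on a list-of-lists boolean grid (off-grid neighbours read MAX_TRAV², as the concatenated extra_column / extra_row do); the second recomputes the same values directly as the distance to the nearest untraversable cell, capped at MAX_TRAV. It is slow in pure Python, so it is meant for spot checks.

```python
import math

# Hypothetical helpers (not from the report) that mirror the Traversability
# model in plain Python, for cross-checking converted-engine outputs.
MAX_TRAV = 21


def _neighbour(grid, r, c, d, pad):
    """grid[r][c] + d when (r, c) is on the grid; otherwise the padding value,
    mirroring how extra_column / extra_row are concatenated without + d."""
    if 0 <= r < len(grid) and 0 <= c < len(grid[0]):
        return grid[r][c] + d
    return pad


def _directional_pass(sq, max_trav, dr, dc):
    """One loop of the model: for t = 1..max_trav, each cell becomes
    min(self, neighbour + (2t - 1)) over the neighbours at +(dr, dc) and -(dr, dc)."""
    pad = float(max_trav**2)
    for t in range(1, max_trav + 1):
        d = 2 * t - 1
        prev = sq  # a fresh grid is built below, so every read sees step t - 1
        sq = [
            [
                min(
                    prev[r][c],
                    _neighbour(prev, r + dr, c + dc, d, pad),
                    _neighbour(prev, r - dr, c - dc, d, pad),
                )
                for c in range(len(prev[0]))
            ]
            for r in range(len(prev))
        ]
    return sq


def reference_traversability(untraversable, max_trav=MAX_TRAV):
    """Same steps as the model: init, horizontal pass, vertical pass, sqrt."""
    sq = [
        [0.0 if cell else float(max_trav**2) for cell in row] for row in untraversable
    ]
    sq = _directional_pass(sq, max_trav, 0, 1)  # horizontal pass (along columns)
    sq = _directional_pass(sq, max_trav, 1, 0)  # vertical pass (along rows)
    return [[math.sqrt(v) for v in row] for row in sq]


def brute_force(untraversable, max_trav=MAX_TRAV):
    """Direct definition: min(max_trav, distance to nearest untraversable cell)."""
    blocked = [
        (r, c) for r, row in enumerate(untraversable) for c, v in enumerate(row) if v
    ]
    result = []
    for r, row in enumerate(untraversable):
        out_row = []
        for c in range(len(row)):
            best = float(max_trav**2)
            for br, bc in blocked:
                best = min(best, float((r - br) ** 2 + (c - bc) ** 2))
            out_row.append(math.sqrt(best))
        result.append(out_row)
    return result


# Small example with max_trav=4 instead of 21 to keep the grid tiny
grid = [[False] * 8 for _ in range(6)]
grid[2][3] = True
ref = reference_traversability(grid, max_trav=4)
assert ref == brute_force(grid, max_trav=4)
print(ref[2][3], ref[2][5], ref[4][5])  # 0.0 2.0 2.8284271247461903
```

The step-by-step version is the one to diff against engine outputs layer by layer; the brute-force version only confirms that the loops compute the capped Euclidean distance.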
dspyz-matician changed the title Intenral Error failure of TensorRT 8.6.1 when running tensorrt.OnnxParser.parse_from_file on GPU NVIDIA GeForce RTX 3060 Internal Error failure of TensorRT 8.6.1 when running tensorrt.OnnxParser.parse_from_file on GPU NVIDIA GeForce RTX 3060 Mar 21, 2024
zerollzeng (Collaborator) commented:

This looks like a torch-to-ONNX export problem: the exported model is invalid, and ONNX Runtime refuses to load it too:

$ polygraphy run traversability.onnx --onnxrt
[I] RUNNING | Command: /home/scratch.zeroz_sw/miniconda3/bin/polygraphy run traversability.onnx --onnxrt
[I] onnxrt-runner-N0-03/23/24-07:56:08  | Activating and starting inference
[I] Creating ONNX-Runtime Inference Session with providers: ['CPUExecutionProvider']
Traceback (most recent call last):
  File "/home/scratch.zeroz_sw/miniconda3/bin/polygraphy", line 8, in <module>
    sys.exit(main())
  File "/home/scratch.zeroz_sw/miniconda3/lib/python3.9/site-packages/polygraphy/tools/_main.py", line 70, in main
    status = selected_tool.run(args)
  File "/home/scratch.zeroz_sw/miniconda3/lib/python3.9/site-packages/polygraphy/tools/base/tool.py", line 171, in run
    status = self.run_impl(args)
  File "/home/scratch.zeroz_sw/miniconda3/lib/python3.9/site-packages/polygraphy/tools/run/run.py", line 226, in run_impl
    exec(str(script))
  File "<string>", line 21, in <module>
  File "/home/scratch.zeroz_sw/miniconda3/lib/python3.9/site-packages/polygraphy/comparator/comparator.py", line 211, in run
    run_results.append((runner.name, execute_runner(runner, loader_cache)))
  File "/home/scratch.zeroz_sw/miniconda3/lib/python3.9/site-packages/polygraphy/comparator/comparator.py", line 96, in execute_runner
    with runner as active_runner:
  File "/home/scratch.zeroz_sw/miniconda3/lib/python3.9/site-packages/polygraphy/backend/base/runner.py", line 62, in __enter__
    self.activate()
  File "/home/scratch.zeroz_sw/miniconda3/lib/python3.9/site-packages/polygraphy/backend/base/runner.py", line 97, in activate
    self.activate_impl()
  File "/home/scratch.zeroz_sw/miniconda3/lib/python3.9/site-packages/polygraphy/util/util.py", line 710, in wrapped
    return func(*args, **kwargs)
  File "/home/scratch.zeroz_sw/miniconda3/lib/python3.9/site-packages/polygraphy/backend/onnxrt/runner.py", line 43, in activate_impl
    self.sess, _ = util.invoke_if_callable(self._sess)
  File "/home/scratch.zeroz_sw/miniconda3/lib/python3.9/site-packages/polygraphy/util/util.py", line 678, in invoke_if_callable
    ret = func(*args, **kwargs)
  File "/home/scratch.zeroz_sw/miniconda3/lib/python3.9/site-packages/polygraphy/backend/base/loader.py", line 40, in __call__
    return self.call_impl(*args, **kwargs)
  File "/home/scratch.zeroz_sw/miniconda3/lib/python3.9/site-packages/polygraphy/util/util.py", line 710, in wrapped
    return func(*args, **kwargs)
  File "/home/scratch.zeroz_sw/miniconda3/lib/python3.9/site-packages/polygraphy/backend/onnxrt/loader.py", line 68, in call_impl
    return onnxrt.InferenceSession(model_bytes, providers=providers)
  File "/home/scratch.zeroz_sw/miniconda3/lib/python3.9/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 419, in __init__
    self._create_inference_session(providers, provider_options, disabled_optimizers)
  File "/home/scratch.zeroz_sw/miniconda3/lib/python3.9/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 472, in _create_inference_session
    sess = C.InferenceSession(session_options, self._model_path, True, self._read_config_from_model)
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from /home/scratch.zeroz_sw/github_bug/3728/traversability.onnx failed:Type Error: Type parameter (T) of Optype (Where) bound to different types (tensor(int64) and tensor(float) in node (/Where).
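Per the ONNX operator specification, `Where` takes a `condition` of type `tensor(bool)` and two value inputs `X` and `Y` that, together with the output, share one type parameter `T`. The error above says the exported `/Where` node violates that: one branch is `tensor(int64)` (presumably the literal `0` in `torch.where(untraversable, 0, max_trav_scalar**2)`) and the other is `tensor(float)`. The sketch below is a simplified, hypothetical model of that binding check, written only to illustrate the message; `check_type_binding`, `WHERE_SCHEMA`, and `ALLOWED` are made-up names, and this is not how ONNX Runtime is implemented.

```python
# Hypothetical, simplified model of how an ONNX op schema binds a type
# parameter. This is NOT ONNX Runtime's implementation.
# Where: condition is constrained to tensor(bool); X, Y (and the output) share T.
WHERE_SCHEMA = {"condition": "B", "X": "T", "Y": "T"}
ALLOWED = {"B": {"tensor(bool)"}}  # T is left unconstrained in this sketch


def check_type_binding(schema, input_types):
    """Bind each type parameter to the first type seen for it and return an
    error string on the first conflict, or None if the node type-checks."""
    bound = {}
    for name, param in schema.items():
        actual = input_types[name]
        if param in ALLOWED and actual not in ALLOWED[param]:
            return f"{name}: {actual} not allowed for {param}"
        if param in bound and bound[param] != actual:
            return (
                f"Type parameter ({param}) bound to different types "
                f"({bound[param]} and {actual})"
            )
        bound[param] = actual
    return None


# The exported /Where node: one int64 branch, one float branch
exported = {"condition": "tensor(bool)", "X": "tensor(int64)", "Y": "tensor(float)"}
print(check_type_binding(WHERE_SCHEMA, exported))
# Type parameter (T) bound to different types (tensor(int64) and tensor(float))

# Both branches float: the binding succeeds
matched = {"condition": "tensor(bool)", "X": "tensor(float)", "Y": "tensor(float)"}
print(check_type_binding(WHERE_SCHEMA, matched))  # None
```

If this reading is right, a possible (untested) model-side workaround is to give both branches the same dtype, for example a float zero tensor instead of the integer literal.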

zerollzeng self-assigned this Mar 23, 2024
zerollzeng added the triaged Issue has been triaged by maintainers label Mar 23, 2024
zerollzeng (Collaborator) commented:

Could you please file a bug against PyTorch? Thanks!

dspyz-matician (Author) commented:

Sorry it took me so long to get to this. Filed upstream: pytorch/pytorch#123353

ttyio (Collaborator) commented May 7, 2024

Closing since there has been no activity for more than 3 weeks, per our policy. Thanks, all!

ttyio closed this as completed May 7, 2024