
converting to TensorRT barely increases performance #3646

Closed
ninono12345 opened this issue Feb 1, 2024 · 9 comments
Assignees
Labels
triaged Issue has been triaged by maintainers

Comments

@ninono12345

ninono12345 commented Feb 1, 2024

Description

Hello everyone,

I am working on converting a PyTorch object tracking model to TensorRT for faster inference.

When running TensorRT inference with a single batch, the model is about 2x faster, but when adding batches it becomes SLOWER.

Batch of 1 inference time:
PyTorch - 40 ms
TensorRT - 20 ms

Batch of 8 inference time:
PyTorch - 160 ms
TensorRT - 100 ms

Shouldn't TensorRT be 5x faster? What can be done to improve this?

I have profiled the model with batch size 1 and batch size 4 in Nsight Systems:

1 batch inference test: https://drive.google.com/file/d/1achvISpSc1pvlV2RLfSNLxCLlRsZHcnT/view?usp=sharing
4 batch inference test: https://drive.google.com/file/d/1ZuHsO28LIlETNIcWk6lh7miv2Lovco9D/view?usp=sharing

Can anybody help explain why this speed regression is happening?

This is the code I use to build the engine:

```python
import pycuda.driver as cuda
import pycuda.autoinit

import numpy as np
import onnx
import tensorrt as trt
import torch

# Constants
ONNX_MODEL_PATH = 'new_full_explicit_batch4.onnx'
TENSORRT_ENGINE_PATH = 'new_full_explicit_batch4.engine'
MIN_BATCH_SIZE = 1
MAX_BATCH_SIZE = 16

# Set up the logger
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

# Create a TensorRT builder, runtime, and network
builder = trt.Builder(TRT_LOGGER)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
config = builder.create_builder_config()
parser = trt.OnnxParser(network, TRT_LOGGER)
parser.set_flag(trt.OnnxParserFlag.NATIVE_INSTANCENORM)

# Parse the ONNX model file
with open(ONNX_MODEL_PATH, 'rb') as model:
    if not parser.parse(model.read()):
        print('ERROR: Failed to parse the ONNX file.')
        for error in range(parser.num_errors):
            print(parser.get_error(error))
        exit(1)

# Define optimization profile for dynamic batch size
profile = builder.create_optimization_profile()
profile.set_shape('im_patches', (MIN_BATCH_SIZE, 3, 288, 288), (4, 3, 288, 288), (MAX_BATCH_SIZE, 3, 288, 288))
profile.set_shape('train_feat', (MIN_BATCH_SIZE, 256, 18, 18), (4, 256, 18, 18), (MAX_BATCH_SIZE, 256, 18, 18))
profile.set_shape('target_labels', (1, MIN_BATCH_SIZE, 18, 18), (1, 4, 18, 18), (1, MAX_BATCH_SIZE, 18, 18))
profile.set_shape('train_ltrb', (MIN_BATCH_SIZE, 4, 18, 18), (4, 4, 18, 18), (MAX_BATCH_SIZE, 4, 18, 18))
config.add_optimization_profile(profile)

# Build the engine
builder.max_batch_size = MAX_BATCH_SIZE
# config.max_workspace_size = 1 << 30  # 1GB of workspace size
engine = builder.build_engine(network, config)

# Save the engine
with open(TENSORRT_ENGINE_PATH, 'wb') as f:
    f.write(engine.serialize())

```
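
Note: with an EXPLICIT_BATCH network, builder.max_batch_size has no effect (the batch range comes from the optimization profile), and build_engine / max_workspace_size are deprecated in TensorRT 8.6. A minimal sketch of the same build step with the non-deprecated API, reusing the builder, config, and constants defined above:

```python
# Sketch only: equivalent build step without the deprecated APIs.
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)  # 1 GB workspace

serialized_engine = builder.build_serialized_network(network, config)
with open(TENSORRT_ENGINE_PATH, 'wb') as f:
    f.write(serialized_engine)
```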

and I use Polygraphy for inference:

```python
import tensorrt as trt
from polygraphy.backend.trt import EngineFromNetwork, NetworkFromOnnxPath, TrtRunner

with open("new_full_explicit_batch16.engine", "rb") as f:
    engine_data = f.read()
runtime1 = trt.Runtime(trt.Logger(trt.Logger.WARNING))
engine1 = runtime1.deserialize_cuda_engine(engine_data)

trt_engine1 = TrtRunner(engine1)
trt_engine1.activate()

input_data = {
   "im_patches": test_x_stack.cpu(),
   "train_feat": train_feat_stack.cpu(),
   "target_labels": target_labels_stack.cpu(),
   "train_ltrb": train_ltrb_stack.cpu(),
}
rez = trt_engine1.infer(input_data)
scores_raw = rez["scores_raw"].to("cuda")
bbox_preds = rez["bbox_preds"].to("cuda")

```

Perhaps I need to convert the engine differently?
Or maybe running inference with Polygraphy isn't a good idea?
Or perhaps the issue is elsewhere?

Anybody who has any idea, please let me know.

link to the onnx model: https://drive.google.com/file/d/1kxWaGbrk3M1slN1-v4C3524vtPeTMSCr/view?usp=sharing

Thank you

Environment

TensorRT Version: 8.6.1

NVIDIA GPU: GTX 1660 Ti

NVIDIA Driver Version: 546.01

CUDA Version: 12.1

CUDNN Version: 8.9.7

Operating System:

Python Version (if applicable): 3.10.13

PyTorch Version (if applicable): 2.1.2+cu121

@zerollzeng
Collaborator

Normally it's caused by how you measure the perf. Could you please try getting a perf summary using trtexec? The usage would be like `trtexec --onnx=model.onnx`; then check the GPU compute time.
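
For reference, a possible invocation for this model might look like the following sketch (the input names and dimensions are taken from the optimization profile in the build script above; batch 8 is only an example, and the flags may need adjusting):

```sh
trtexec --onnx=new_full_explicit_batch4.onnx \
    --shapes=im_patches:8x3x288x288,train_feat:8x256x18x18,target_labels:1x8x18x18,train_ltrb:8x4x18x18
```

The performance summary printed at the end includes the GPU compute time, which is the number to compare against the PyTorch timing.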

@zerollzeng
Collaborator

Also make sure you feed the same input shapes to TRT and PyTorch; disabling dynamic shapes is also fairer to TRT.
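
One way to apply this in the build script from the first comment is to make min, opt, and max identical in the optimization profile, which effectively pins the engine to a single static shape. This is only a sketch (batch 8 is an example, not the exact code used in this issue):

```python
# Sketch: pin every input to one shape by using identical min/opt/max,
# so the builder can specialize kernels for that shape (batch 8 as an example).
profile = builder.create_optimization_profile()
for name, shape in [
    ('im_patches',    (8, 3, 288, 288)),
    ('train_feat',    (8, 256, 18, 18)),
    ('target_labels', (1, 8, 18, 18)),
    ('train_ltrb',    (8, 4, 18, 18)),
]:
    profile.set_shape(name, shape, shape, shape)
config.add_optimization_profile(profile)
```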

@zerollzeng zerollzeng self-assigned this Feb 7, 2024
@zerollzeng zerollzeng added the triaged Issue has been triaged by maintainers label Feb 7, 2024
@ninono12345
Author

ninono12345 commented Feb 7, 2024

Hi @zerollzeng, I apologize: a batch of 8 in regular PyTorch takes about 160 ms. I used time.time() to measure, but I perhaps placed it in the wrong place...

Now for testing I have converted two TensorRT engines with static batch sizes of 1 and 8.

So now, updated values:
batch of 1: PyTorch 40 ms, TensorRT 16 ms
batch of 8: PyTorch 160 ms, TensorRT 100 ms

So I apologize; I will change the title, there is not a performance regression.
But still, at batch 1 TensorRT is 2.5 times faster than regular PyTorch, while at batch 8 it is not even 2 times faster than PyTorch.

TensorRT is advertised as being able to speed up inference by 5x, but I'm just not seeing those results...

So I would like to keep this issue open. Perhaps you know what should be done to make inference faster? Is it the engine conversion, or is special inference code needed to utilize the GPU as well as possible?
Thank you

@ninono12345
Author

@zerollzeng I understand that this might be a very difficult task, but perhaps you can share your experience on which direction I should be looking in? Should I look into learning CUDA programming? Should I learn to write custom plugins? How can I make this engine utilize more GPU resources? In my logic, since batches are independent of each other, we should be able to run them in parallel at the speed of a single batch?

@ninono12345 ninono12345 changed the title performance regression, inference with multiple batches is slower than regular pytorch converting to TensorRT barely increases performance Feb 8, 2024
@ninono12345
Author

@zerollzeng so now I am trying to run TensorRT inference with each batch in parallel. I am using a static single-batch engine that has an inference time of 16 ms. In this example I try to run inference on 10 batches in parallel, so I create 10 execution contexts and 10 streams; for each context I allocate input and output memory separately, launch execution on all of them, and synchronize only at the end.

This is the inference code:

```python
import time

import numpy as np
import torch
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit

# load_engine and TRT_ENGINE_PATH are defined elsewhere (not shown)
batch = 10

#assign random variables
im_patches = torch.randn(batch, 3, 288, 288)
train_feat = torch.randn(batch,256,18,18)
target_labels = torch.randn(1, batch, 18, 18)
train_ltrb = torch.randn(batch, 4, 18, 18)
inputss = [im_patches,train_feat,target_labels,train_ltrb]

scores_raw = torch.randn(1, 1, 18, 18)
bbox_preds = torch.randn(1, 1, 4, 18, 18)
outputs = [scores_raw, bbox_preds]

# our engine and streams
trt_runtime = trt.Runtime(trt.Logger(trt.Logger.WARNING))
engine = load_engine(trt_runtime, TRT_ENGINE_PATH)
print(engine.num_bindings)
contexts = [engine.create_execution_context() for _ in range(batch)]
streams = [cuda.Stream() for _ in range(batch)]

# split the batched tensors into per-context single-batch inputs and
# allocate device memory and page-locked host buffers for each context
# (this outer loop repeats the allocation 10 times; only the last pass is used below)
for it in range(10):
    ins=[]
    outs=[]
    bindings = []
    input_mem=[]
    output_mem=[]
    for b in range(batch):
        inputs = []
        inputs.append(np.ascontiguousarray(im_patches[b].unsqueeze(0).numpy()))
        inputs.append(np.ascontiguousarray(train_feat[b].unsqueeze(0).numpy()))
        inputs.append(np.ascontiguousarray(target_labels[0][b].unsqueeze(0).unsqueeze(0).numpy()))
        inputs.append(np.ascontiguousarray(train_ltrb[b].unsqueeze(0).numpy()))
        
        input_mem.append([cuda.mem_alloc(t.nbytes) for t in inputss])
        output_mem.append([cuda.mem_alloc(t.nbytes) for t in outputs])
        bindings1 = [int(im) for im in input_mem[b]]
        bindings2 = [int(om) for om in output_mem[b]]
        bindings.append(bindings1 + bindings2)
        outs.append([cuda.pagelocked_empty_like(out.numpy()) for out in outputs])
        ins.append([cuda.pagelocked_empty_like(ii) for ii in inputs])
        for ii in range(len(inputs)):
            ins[b][ii] = inputs[ii]

# do 10 inferences
for it in range(10):
    pt = time.time()
    for b in range(batch):
        stream = streams[b]
        [cuda.memcpy_htod_async(input_mem[b][inp], ins[b][inp], stream) for inp in range(len(ins[b]))]
        
        contexts[b].execute_async_v2(bindings=bindings[b], stream_handle=stream.handle)
        
        [cuda.memcpy_dtoh_async(outs[b][out], output_mem[b][out], stream) for out in range(len(outs[b]))]
    
    ntt = time.time()
    print(ntt-pt)
    for stream in streams:
        stream.synchronize()

    nt = time.time()
    print(nt-pt)
    for out in outs[0]:
        print(out.shape)

```

If you try to reproduce this, you should see that the part before synchronize completes quickly, but the synchronize calls together take about as long as normal sequential execution. So I cannot understand: are execute_async_v2 (enqueueV2) and the execution contexts still not running in parallel?
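
One measurement detail worth checking (a sketch only, reusing the variables from the snippet above): time.time() around the enqueue loop mostly measures how fast the host submits work, because execute_async_v2 returns before the GPU finishes. Per-stream CUDA events show how much GPU time each context actually takes and whether the streams overlap at all:

```python
# Sketch: time each stream's GPU work with CUDA events instead of time.time().
starts = [cuda.Event() for _ in range(batch)]
ends = [cuda.Event() for _ in range(batch)]

for b in range(batch):
    stream = streams[b]
    starts[b].record(stream)
    for inp in range(len(ins[b])):
        cuda.memcpy_htod_async(input_mem[b][inp], ins[b][inp], stream)
    contexts[b].execute_async_v2(bindings=bindings[b], stream_handle=stream.handle)
    for out in range(len(outs[b])):
        cuda.memcpy_dtoh_async(outs[b][out], output_mem[b][out], stream)
    ends[b].record(stream)

for b in range(batch):
    ends[b].synchronize()
    print("stream %d: %.2f ms of GPU time" % (b, starts[b].time_till(ends[b])))
```

If each stream still reports roughly the full single-batch latency and the total wall time barely shrinks, the kernels from one context are already keeping the GPU busy, so the extra streams have little to overlap with.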

What do I do to make my engine faster? When using it in a real-time app, tracking multiple objects really slows everything down.
What can be done to speed it up in any way?

Or does the inference need to be coded in C++, or will that also give the same results?

onnx model used here: https://drive.google.com/file/d/1U_djLvIbDYv-Fxh60_coB7H9twPfcmP7/view?usp=sharing

Code to convert from ONNX to TensorRT:

```python
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)

builder = trt.Builder(logger)

network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)
parser.set_flag(trt.OnnxParserFlag.NATIVE_INSTANCENORM)
success = parser.parse_from_file("new_full_implicit_batch1_sanitized.onnx")

config = builder.create_builder_config()
config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED

serialized_engine = builder.build_serialized_network(network, config)

with open("new_full_implicit_batch1_sanitized.engine", "wb") as f:
    f.write(serialized_engine)
```

@zerollzeng
Collaborator

> But still, at batch 1 TensorRT is 2.5 times faster than regular PyTorch, while at batch 8 it is not even 2 times faster than PyTorch.

That usually means the GPU is fully utilized.

> So I would like to keep this issue open. Perhaps you know what should be done to make inference faster? Is it the engine conversion, or is special inference code needed to utilize the GPU as well as possible?

Try FP16/INT8, e.g. run trtexec with --int8 --fp16.
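
In the Python build script from the first comment, the rough equivalent of trtexec's --fp16 is setting the FP16 builder flag before building the engine (a sketch; INT8 additionally needs a calibrator or a quantization-aware ONNX model, which is not shown here):

```python
# Sketch: enable reduced precision on the existing builder config.
config.set_flag(trt.BuilderFlag.FP16)
# config.set_flag(trt.BuilderFlag.INT8)  # also requires INT8 calibration data
```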

@xillee

xillee commented Mar 26, 2024

I tried FP16; only 7% of the GPU is utilized. In that case, how can I improve the performance?

@zerollzeng
Collaborator

Increase the batch size, or use multiple threads?

@ttyio
Collaborator

ttyio commented Jul 2, 2024

Closing since there has been no activity for more than 3 weeks. Please reopen if you still have questions. Thanks all!

@ttyio ttyio closed this as completed Jul 2, 2024