add support for tensorrt 8 (#1698)
* add support for tensorrt 8

* address comments

* add test for fp16 and add warmup

(cherry picked from commit 5277872)

* fix format

* add validate_args flag

---------

Co-authored-by: Le Horizon <le-horizon@github>
hnyu and Le Horizon authored Aug 29, 2024
1 parent d508d5b commit 6d56e66
Showing 2 changed files with 224 additions and 51 deletions.
237 changes: 198 additions & 39 deletions alf/utils/tensorrt_utils.py
@@ -43,7 +43,10 @@
# ```bash
# pip install onnx>=1.16.2 protobuf==3.20.2
#
-# pip install tensorrt>=10.0
+# # https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html#installing-pip
+# pip install tensorrt==8.6.1
+# # To install a different version of tensorrt, first make sure to ``rm -rf`` all dirs
+# # under virtual env ``site-packages`` with the prefix ``tensorrt``.

# For cuda 11.x,
# pip install onnxruntime-gpu
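
Since these utils now target two TensorRT majors, a quick check of the active version can save an install/debug cycle. A minimal sketch (illustrative only, not part of the diff):

import tensorrt as trt

# Print the active TensorRT version and its major number.
print(trt.__version__)                      # e.g. '8.6.1' or '10.0.1'
major = int(trt.__version__.split('.')[0])
assert major in (8, 10), 'these utils only handle TensorRT 8 and 10'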
@@ -165,8 +168,8 @@ def _strip_optimizers(module):
def recover_module_output(self, forward_output):
"""``forward_output`` is a direct return of ``self.forward()``.
"""
-# remove the dummy output as the second one
-forward_output = list(forward_output)[:-1]
+# remove the dummy output as the last one
+forward_output = forward_output[:-1]
output_nest = alf.nest.py_pack_sequence_as(self._output_params_spec,
forward_output)
output = dist_utils.params_to_distributions(output_nest,
@@ -214,7 +217,8 @@ def forward(self, *flat_all_args):

# We want to use ALF's flatten to avoid ONNX's defined flattening order
output_params = alf.nest.flatten(output_params)
-return output_params, dummy_output
+output_params.append(dummy_output)
+return output_params
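
For orientation, a toy round trip of the dummy-output convention after this change (plain Python, illustrative only): ``forward`` appends the dummy as the last element, and ``recover_module_output`` strips it symmetrically.

output_params = ['p0', 'p1']      # stand-ins for the flattened output params
output_params.append('dummy')     # what forward() now returns last
assert output_params[:-1] == ['p0', 'p1']   # what recovery keeps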


@alf.configurable(whitelist=['device'])
@@ -232,8 +236,12 @@ def __init__(self,
NOTE: if ``tensorrt`` lib is not installed, this backend will fall back
to using CUDA. If GPU is not available, this backend will fall back to CPU.
So the class name might not be accurate. But since its main purpose is
using tensorRT for inference, we keep the name as it is.
+To exclude certain providers, set the env var ``ORT_ONNX_BACKEND_EXCLUDE_PROVIDERS``.
+For example,
+.. code-block:: bash
+    ORT_ONNX_BACKEND_EXCLUDE_PROVIDERS='TensorrtExecutionProvider,CPUExecutionProvider'
This class is mainly responsible for:
@@ -303,7 +311,6 @@ def __call__(self, *args, **kwargs):
return self._onnx_wrapper.recover_module_output(outputs)
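
The ``ORT_ONNX_BACKEND_EXCLUDE_PROVIDERS`` mechanism documented above could be implemented roughly as follows. This is a sketch of the common onnxruntime pattern, not the class's actual (elided) code:

import os
import onnxruntime as ort

# Drop any provider named in the env var from onnxruntime's preference list.
excluded = set(
    os.environ.get('ORT_ONNX_BACKEND_EXCLUDE_PROVIDERS', '').split(','))
providers = [p for p in ort.get_available_providers() if p not in excluded]
# session = ort.InferenceSession(onnx_model_bytes, providers=providers)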


-@alf.configurable(whitelist=['memory_limit_gb', 'fp16'])
class TensorRTEngine(object):
def __init__(self,
module: torch.nn.Module,
@@ -313,7 +320,8 @@ def __init__(self,
memory_limit_gb: float = 1.,
fp16: bool = False,
example_args: Tuple[Any] = (),
-example_kwargs: Dict[str, Any] = {}):
+example_kwargs: Dict[str, Any] = {},
+validate_args: bool = False):
"""Class for converting a torch.nn.Module to TensorRT engine for fast
inference, via ONNX model as the intermediate representation.
@@ -336,17 +344,27 @@ def __init__(self,
a tuple of args.
example_kwargs: The example kwargs to be used for ``method``. Should
be a dict of kwargs.
+validate_args: if True, every call of the engine will first check
+if the args are consistent with the example args that were used
+to build the engine. If the check fails, useful debugging info will be
+printed. Defaults to False. Use this flag if you have some memcpy
+issue for an input.
"""
assert torch.cuda.is_available(
), 'This engine can only be used on GPU!'

example_args = _dtype_conversions(example_args)
example_kwargs = _dtype_conversions(example_kwargs)
+self._validate_args = validate_args
+self._example_args = example_args
+self._example_kwargs = example_kwargs
self._onnx_wrapper = _OnnxWrapper(module, method, example_args,
example_kwargs)
flat_all_args = tuple(alf.nest.flatten([example_args, example_kwargs]))
self._inputs = flat_all_args
self._outputs = alf.nest.flatten(self._onnx_wrapper.example_output)
+input_names = [f'input-{i}' for i in range(len(self._inputs))]
+output_names = [f'output-{i}' for i in range(len(self._outputs))]

if onnx_file is None:
onnx_io = io.BytesIO()
@@ -356,6 +374,8 @@ def __init__(self,
torch.onnx.export(
self._onnx_wrapper,
args=self._inputs,
+input_names=input_names,
+output_names=output_names,
f=onnx_io,
# Don't modify the version easily! Other versions might
# have weird errors.
@@ -368,37 +388,57 @@ def __init__(self,
with open(onnx_io, 'rb') as f:
model_content = f.read()

+engine = self._build_engine(model_content, fp16, memory_limit_gb)
+self._prepare_io(engine)
+self._engine = engine

+def _build_engine(self, model_content, fp16, memory_limit_gb):
# Create a TensorRT logger
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
-# Create a builder and network
-builder = trt.Builder(TRT_LOGGER)
-network = builder.create_network()
-parser = trt.OnnxParser(network, TRT_LOGGER)
-parser.parse(model_content)

-# Create a builder configuration
-config = builder.create_builder_config()
-config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE,
-int((1 << 30) * memory_limit_gb))
-if fp16:
-config.set_flag(trt.BuilderFlag.FP16)

-# Build the engine
-serialized_engine = builder.build_serialized_network(network, config)
-# Create a runtime to deserialize the engine
-runtime = trt.Runtime(TRT_LOGGER)
-# Deserialize the engine
-self._engine = runtime.deserialize_cuda_engine(serialized_engine)
-self._prepare_io()
+with trt.Builder(TRT_LOGGER) as builder, \
+builder.create_network() as network, \
+trt.OnnxParser(network, TRT_LOGGER) as parser:
+parser.parse(model_content)
+config = builder.create_builder_config()
+config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE,
+int((1 << 30) * memory_limit_gb))
+if fp16:
+config.set_flag(trt.BuilderFlag.FP16)
+# Build the engine
+serialized_engine = builder.build_serialized_network(
+network, config)
+# Create a runtime to deserialize the engine
+runtime = trt.Runtime(TRT_LOGGER)
+# Deserialize the engine
+return runtime.deserialize_cuda_engine(serialized_engine)
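
Two small notes on ``_build_engine``: ``(1 << 30)`` is one GiB, so ``memory_limit_gb=1.5`` caps the workspace at ``int(1.5 * 2**30) = 1610612736`` bytes; and ``parser.parse`` returns False on failure, so a defensive variant (an assumption on our part, not in this diff) would surface parser errors explicitly:

# Defensive parse (sketch): TensorRT's OnnxParser records errors that can
# be printed before bailing out.
if not parser.parse(model_content):
    for i in range(parser.num_errors):
        print(parser.get_error(i))
    raise RuntimeError('failed to parse the ONNX model')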

@staticmethod
def _get_bytes(tensor):
"""Get a tensor's size in bytes.
"""
return tensor.element_size() * tensor.nelement()

-def _prepare_io(self):
-self._context = self._engine.create_execution_context()
+def _check_args(self, args, kwargs):
+alf.nest.assert_same_structure(args, self._example_args)
+alf.nest.assert_same_structure(kwargs, self._example_kwargs)

+def _check_tensor_shape_and_dtype(path, x, y):
+if (not isinstance(x, torch.Tensor)
+or not isinstance(y, torch.Tensor)):
+assert type(x) == type(y), (
+f"'{path}' has different types: {type(x)} vs {type(y)}")
+return
+assert x.shape == y.shape, (
+f"'{path}' has different shapes: {x.shape} vs {y.shape}")
+assert x.dtype == y.dtype, (
+f"'{path}' has different dtypes: {x.dtype} vs {y.dtype}")

+alf.nest.py_map_structure_with_path(
+_check_tensor_shape_and_dtype, (args, kwargs),
+(self._example_args, self._example_kwargs))
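
A usage sketch of the new flag (hypothetical module and shapes; the full ``__init__`` signature is partly elided above): the engine is specialized to the example args' exact structure, shapes, and dtypes, so ``validate_args=True`` turns a silent memcpy corruption into an early assertion.

import torch

net = torch.nn.Linear(4, 2).cuda()
engine = TensorRTEngine(
    net, net.forward,
    example_args=(torch.zeros(1, 4, device='cuda'),),
    validate_args=True)
engine(torch.zeros(1, 4, device='cuda'))   # passes _check_args
engine(torch.zeros(2, 4, device='cuda'))   # AssertionError: different shapes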

+def _prepare_io(self, engine):
+self._context = engine.create_execution_context()

# allocate device memory (bytes)
self._input_mem = [
@@ -411,10 +451,9 @@ def _prepare_io(self):
# Set the IO tensor addresses
bindings = list(map(int, self._input_mem)) + list(
map(int, self._output_mem))
-for i in range(self._engine.num_io_tensors):
+for i in range(engine.num_io_tensors):
self._context.set_tensor_address(
-self._engine.get_tensor_name(i), bindings[i])
-
+engine.get_tensor_name(i), bindings[i])
# create stream
self._stream = cuda.Stream()
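
As a worked example of the sizes being allocated (see ``_get_bytes`` above): a float32 tensor of shape ``(1, 84, 84, 3)`` occupies 4 × 21168 = 84672 bytes, and that is exactly what ``cuda.mem_alloc`` reserves for it.

import torch

x = torch.zeros(1, 84, 84, 3)               # hypothetical input
nbytes = x.element_size() * x.nelement()    # 4 * 21168
assert nbytes == 84672
# cuda.mem_alloc(nbytes) would back this tensor on the device.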

@@ -424,6 +463,9 @@ def __call__(self, *args, **kwargs):
The arguments must be GPU tensors; otherwise invalid memory addresses will
be reported.
"""
+if self._validate_args:
+self._check_args(args, kwargs)

flat_all_args = _dtype_conversions(alf.nest.flatten([args, kwargs]))

for im, i in zip(self._input_mem, flat_all_args):
@@ -432,8 +474,8 @@ def __call__(self, *args, **kwargs):
self._get_bytes(i), self._stream)

# For some reason, we have to manually synchronize the stream here before
-# executing the engine. Otherwise the inference will be much slower. Probably
-# a pycuda bug because in theory this synchronization is not needed.
+# executing the engine. Otherwise the inference will be much slower sometimes.
+# Probably a pycuda bug because in theory this synchronization is not needed.
self._stream.synchronize()

self._context.execute_async_v3(stream_handle=self._stream.handle)
@@ -450,6 +492,87 @@ def __call__(self, *args, **kwargs):
return self._onnx_wrapper.recover_module_output(outputs)
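
Before the TensorRT 8 subclass below, a side-by-side of the two execution APIs used in this file (both calls exist in the respective TensorRT Python bindings):

# TensorRT 10 path (above): IO addresses are registered once, by tensor name.
#     context.set_tensor_address(engine.get_tensor_name(i), ptr)
#     context.execute_async_v3(stream_handle=stream.handle)
#
# TensorRT 8 path (below): addresses are passed positionally on every call.
#     context.execute_async_v2(bindings=[int(ptr), ...],
#                              stream_handle=stream.handle)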


class TensorRT8Engine(TensorRTEngine):
"""A big trouble of TensorRT 8 is that its input/output args order might not be
consistent with that of the ONNX model! So we need to manually keep track of
the correspondence when memcopying between host/device.
Also there is a slight API difference when creating the engine.
"""

def _build_engine(self, model_content, fp16, memory_limit_gb):
# Create a TensorRT logger
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
with trt.Builder(TRT_LOGGER) as builder, \
builder.create_network(1 << int(
trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) as network, \
trt.OnnxParser(network, TRT_LOGGER) as parser:
# Create a builder and network
parser.parse(model_content)
config = builder.create_builder_config()
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE,
int((1 << 30) * memory_limit_gb))
if fp16:
config.set_flag(trt.BuilderFlag.FP16)
return builder.build_engine(network, config)

def _prepare_io(self, engine):
self._context = engine.create_execution_context()
self._input_mem = []
self._input_idx = []
self._output_mem = []
self._output_idx = []
self._bindings = []
# TRT8: This order might be different from the order of the onnx model!!
for i in range(engine.num_io_tensors):
name = engine.get_tensor_name(i)
idx = int(name.split('-')[1])
if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
mem = cuda.mem_alloc(self._get_bytes(self._inputs[idx]))
self._input_mem.append(mem)
self._input_idx.append(idx)
else:
mem = cuda.mem_alloc(self._get_bytes(self._outputs[idx]))
self._output_mem.append(mem)
self._output_idx.append(idx)
self._bindings.append(int(mem))
self._stream = cuda.Stream()
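
The reason the ONNX export above assigns names like ``input-0``/``output-1`` becomes clear here: TensorRT 8 may enumerate bindings in its own order, and parsing the index back out of the name restores the ONNX-order correspondence. A toy illustration:

# Recover the ONNX-order index from an exported IO tensor name.
name = 'input-3'
idx = int(name.split('-')[1])
assert idx == 3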

def __call__(self, *args, **kwargs):
if self._validate_args:
self._check_args(args, kwargs)

flat_all_args = _dtype_conversions(alf.nest.flatten([args, kwargs]))

for i in range(len(flat_all_args)):
im = self._input_mem[i]
arg = flat_all_args[self._input_idx[i]]
cuda.memcpy_dtod_async(im,
arg.contiguous().data_ptr(),
self._get_bytes(arg), self._stream)

# For some reason, we have to manually synchronize the stream here before
# executing the engine. Otherwise the inference will be much slower sometimes.
# Probably a pycuda bug because in theory this synchronization is not needed.
self._stream.synchronize()
self._context.execute_async_v2(
bindings=self._bindings, stream_handle=self._stream.handle)

outputs = [
torch.empty_like(o, memory_format=torch.contiguous_format)
for o in self._outputs
]

for i in range(len(outputs)):
om = self._output_mem[i]
out = outputs[self._output_idx[i]]
cuda.memcpy_dtod_async(out.data_ptr(), om, self._get_bytes(out),
self._stream)

self._stream.synchronize()
return self._onnx_wrapper.recover_module_output(outputs)
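
The commit message mentions adding warmup in the tests; a minimal warmup sketch (our assumption of the usual pattern, not code from this diff) simply runs a few throwaway inferences before any timing:

import torch

def warmup(engine, example_args, n: int = 5):
    """Run a few dummy inferences so lazy CUDA/TensorRT initialization
    doesn't pollute latency measurements."""
    for _ in range(n):
        engine(*example_args)
    torch.cuda.synchronize()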


def compile_for_inference_if(cond: bool = True,
engine_class: Callable = TensorRTEngine):
"""A decorator to compile a method as a onnxruntime/tensorRT engine for inference,
@@ -509,9 +632,40 @@ def wrapped(module_to_wrap, *args, **kwargs):
_compiled_methods = {}


-def compile_method(module,
-method_name,
-engine_class: Callable = TensorRTEngine):
@alf.configurable
def get_tensorrt_engine_class(memory_limit_gb: float = 1.,
fp16: bool = False,
validate_args: bool = False):
"""Get the proper tensorrt engine class depending on the available ``tensorrt``
version.
Currently we only support tensorrt 8 and 10.
Args:
memory_limit_gb: The memory limit in GBs for tensorRT for inference.
fp16: If True, the model will do inference in fp16.
validate_args: if True, every call of the engine will first check
if the args are consistent with the example args that were used
to build the engine. If the check fails, useful debugging info will be
printed. Defaults to False. Use this flag if you have some memcpy
issue for an input.
"""
assert is_tensorrt_available()
trt_major_ver = trt.__version__.split('.')[0]
# On some edge devices like Jetson, only tensorrt 8 is supported
if trt_major_ver == '8':
cls = TensorRT8Engine
else:
assert trt_major_ver == '10'
cls = TensorRTEngine
return functools.partial(
cls,
memory_limit_gb=memory_limit_gb,
fp16=fp16,
validate_args=validate_args)
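
Since ``get_tensorrt_engine_class`` is ``alf.configurable``, its knobs can be set through ALF's config system; a usage sketch (assuming the standard ``alf.config`` mechanics):

import alf

alf.config('get_tensorrt_engine_class', fp16=True, memory_limit_gb=2.)
engine_cls = get_tensorrt_engine_class()
# engine_cls is a functools.partial over TensorRTEngine or TensorRT8Engine,
# depending on trt.__version__.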


+def compile_method(module, method_name, engine_class: Callable = None):
"""Convert a module method to use OnnxRuntime or TensorRT inference on the fly.
For example,
@@ -560,8 +714,13 @@ def compile_method(module,
Args:
module: a torch.nn.Module
method_name: the method name of the module
-engine_class: should be either ``TensorRTEngine`` or ``OnnxRuntimeEngine``
+engine_class: should be either ``TensorRTEngine``, ``TensorRT8Engine``,
+or ``OnnxRuntimeEngine``. If None, will use ``get_tensorrt_engine_class()``
+to choose a tensorrt engine.
"""
+if engine_class is None:
+engine_class = get_tensorrt_engine_class()

global _compiled_methods
key = (module, method_name)
# Here we check if a previous ``compile_method`` has already been called