add file cache for tensorrt engine (#1699)
* add disk cache for tensorrt engine

* improve docstring
hnyu authored Sep 13, 2024
1 parent a6e8fd9 commit 0283006
Showing 2 changed files with 116 additions and 24 deletions.
97 changes: 73 additions & 24 deletions alf/utils/tensorrt_utils.py
@@ -321,6 +321,8 @@ def __init__(self,
                 fp16: bool = False,
                 example_args: Tuple[Any] = (),
                 example_kwargs: Dict[str, Any] = {},
                 engine_file: Optional[str] = None,
                 force_build_engine: bool = False,
                 validate_args: bool = False):
"""Class for converting a torch.nn.Module to TensorRT engine for fast
inference, via ONNX model as the intermediate representation.
@@ -336,6 +338,8 @@ def __init__(self,
            method: A method in ``module`` to be converted.
            onnx_file: The path to the onnx model. If None, no external file will
                be created. Instead, the onnx model will be created in memory.
                Note that an onnx model is exported only when the engine is not
                loaded from a file.
            onnx_verbose: If True, the onnx model exporting process will output
                verbose information.
            memory_limit_gb: The memory limit in GBs for tensorRT for inference.
@@ -344,6 +348,17 @@ def __init__(self,
                a tuple of args.
            example_kwargs: The example kwargs to be used for ``method``. Should
                be a dict of kwargs.
            engine_file: if provided, this class will first check whether the
                file exists. If it does, the engine is loaded from the file and
                the build process is skipped. If it does not, the built engine
                is saved to this file. When an existing engine file is loaded,
                it is the *user's responsibility* to ensure that it will work
                correctly in the current context; no check is performed by this
                class.
                NOTE: This option is only intended as an inter-process cache,
                e.g., storing the engine to disk and later loading it for reuse.
            force_build_engine: if True, the engine will always be built and
                will overwrite the engine file (if it exists). This flag only
                takes effect when ``engine_file`` is provided.
            validate_args: if True, every call of the engine will first check
                if the args are consistent with the example args that were used
                to build the engine. If None, useful debugging info will be
@@ -363,35 +378,57 @@ def __init__(self,
        flat_all_args = tuple(alf.nest.flatten([example_args, example_kwargs]))
        self._inputs = flat_all_args
        self._outputs = alf.nest.flatten(self._onnx_wrapper.example_output)
        input_names = [f'input-{i}' for i in range(len(self._inputs))]
        output_names = [f'output-{i}' for i in range(len(self._outputs))]

        if onnx_file is None:
            onnx_io = io.BytesIO()
        else:
            onnx_io = onnx_file
        # 'args' must be a tuple of tensors
        torch.onnx.export(
            self._onnx_wrapper,
            args=self._inputs,
            input_names=input_names,
            output_names=output_names,
            f=onnx_io,
            # Don't modify the version easily! Other versions might
            # have weird errors.
            opset_version=12,
            verbose=onnx_verbose)
        if isinstance(onnx_io, io.BytesIO):
            onnx_io.seek(0)
            model_content = onnx_io.getvalue()
        else:
            with open(onnx_io, 'rb') as f:
                model_content = f.read()
        engine = self._load_or_build_engine(onnx_file, onnx_verbose,
                                            engine_file, force_build_engine,
                                            fp16, memory_limit_gb)

        engine = self._build_engine(model_content, fp16, memory_limit_gb)
        self._prepare_io(engine)
        self._engine = engine

    def _load_or_build_engine(self, onnx_file: Optional[str],
                              onnx_verbose: bool, engine_file: Optional[str],
                              force_build_engine: bool, fp16: bool,
                              memory_limit_gb: float):
        if (not force_build_engine and engine_file is not None
                and os.path.isfile(engine_file)):
            runtime = trt.Runtime(trt.Logger(trt.Logger.WARNING))
            with open(engine_file, "rb") as f:
                engine_data = f.read()
            engine = runtime.deserialize_cuda_engine(engine_data)
        else:
            if onnx_file is None:
                onnx_io = io.BytesIO()
            else:
                onnx_io = onnx_file
            input_names = [f'input-{i}' for i in range(len(self._inputs))]
            output_names = [f'output-{i}' for i in range(len(self._outputs))]
            # 'args' must be a tuple of tensors
            torch.onnx.export(
                self._onnx_wrapper,
                args=self._inputs,
                input_names=input_names,
                output_names=output_names,
                f=onnx_io,
                # Don't modify the version easily! Other versions might
                # have weird errors.
                opset_version=12,
                verbose=onnx_verbose)
            if isinstance(onnx_io, io.BytesIO):
                onnx_io.seek(0)
                model_content = onnx_io.getvalue()
            else:
                with open(onnx_io, 'rb') as f:
                    model_content = f.read()
            engine = self._build_engine(model_content, fp16, memory_limit_gb)

            if engine_file is not None:
                # Write the engine to file
                with open(engine_file, "wb") as f:
                    f.write(engine.serialize())

        return engine
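For illustration, a direct construction with the new caching arguments might look like the following minimal sketch. The class name ``TensorRTEngine`` and the ``module``/``method`` parameters are assumptions inferred from the surrounding docstring (they are not shown in this hunk), and the file path is a placeholder:

    import torch

    policy_net = torch.nn.Linear(8, 2)  # stand-in torch.nn.Module
    example_obs = torch.zeros(1, 8)

    # The first construction exports ONNX, builds the engine, and caches it
    # at engine_file; a later construction with the same file deserializes
    # the cached engine and skips the build entirely.
    engine = TensorRTEngine(
        module=policy_net,
        method=policy_net.forward,
        example_args=(example_obs, ),
        engine_file='/tmp/policy.trt',  # placeholder cache path
        force_build_engine=False)  # True rebuilds and overwrites the cache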

    def _build_engine(self, model_content, fp16, memory_limit_gb):
        # Create a TensorRT logger
        TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
@@ -635,6 +672,8 @@ def wrapped(module_to_wrap, *args, **kwargs):
@alf.configurable
def get_tensorrt_engine_class(memory_limit_gb: float = 1.,
                              fp16: bool = False,
                              engine_file: Optional[str] = None,
                              force_build_engine: bool = False,
                              validate_args: bool = False):
"""Get the proper tensorrt engine class depending on the available ``tensorrt``
version.
@@ -644,6 +683,14 @@ def get_tensorrt_engine_class(memory_limit_gb: float = 1.,
    Args:
        memory_limit_gb: The memory limit in GBs for tensorRT for inference.
        fp16: If True, the model will do inference in fp16.
        engine_file: if provided, the returned engine class will first check
            whether the file exists. If it does, the engine is loaded from the
            file and the build process is skipped. If it does not, the built
            engine is saved to this file. When an existing engine file is
            loaded, it is the *user's responsibility* to ensure that it will
            work correctly in the current context; no check is performed.
        force_build_engine: if True, the engine will always be built and will
            overwrite the engine file (if it exists).
        validate_args: if True, every call of the engine will first check
            if the args are consistent with the example args that were used
            to build the engine. If None, useful debugging info will be
@@ -662,6 +709,8 @@ def get_tensorrt_engine_class(memory_limit_gb: float = 1.,
        cls,
        memory_limit_gb=memory_limit_gb,
        fp16=fp16,
        engine_file=engine_file,
        force_build_engine=force_build_engine,
        validate_args=validate_args)
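As a usage sketch (mirroring the new test below), the cache can be exercised through ``compile_method``; ``alg``, ``timestep``, and ``state`` stand in for a real algorithm and its inputs, and the file path is a placeholder:

    # First process: no cache file yet, so the engine is built and then
    # serialized to the given path.
    compile_method(
        alg, 'predict_step',
        get_tensorrt_engine_class(engine_file='/tmp/sac_predict.trt'))
    alg.predict_step(timestep, state=state)  # builds and caches the engine

    # A later process with an identical module and inputs: the engine is
    # deserialized from the file instead of being rebuilt.
    compile_method(
        alg, 'predict_step',
        get_tensorrt_engine_class(engine_file='/tmp/sac_predict.trt'))
    alg.predict_step(timestep, state=state)  # loads the cached engine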


43 changes: 43 additions & 0 deletions alf/utils/tensorrt_utils_test.py
@@ -15,6 +15,7 @@
from absl.testing import parameterized
from functools import partial
import time
import tempfile
import torch
import torchvision.models as models
import unittest
@@ -147,6 +148,48 @@ def test_tensorrt_engine(self):

        self.assertTensorClose(trt_alg_step.output, alg_step.output)

    @unittest.skipIf(not is_tensorrt_available(), "tensorrt is unavailable")
    def test_tensorrt_engine_cache(self):
        engine_file = tempfile.mktemp(suffix='.trt')
        alg, timestep, state = create_sac_and_inputs()
        compile_method(
            alg, 'predict_step',
            get_tensorrt_engine_class(
                validate_args=True, engine_file=engine_file))
        start_time = time.time()
        alg.predict_step(timestep, state=state)  # build engine
        self.assertGreater(time.time() - start_time,
                           1)  # takes more than 1 second

        alg, timestep, state = create_sac_and_inputs()
        # Now if we compile again with engine file, the engine should be directly
        # loaded from disk, even though the alg has been recreated
        compile_method(
            alg, 'predict_step',
            get_tensorrt_engine_class(
                validate_args=True, engine_file=engine_file))
        start_time = time.time()
        alg.predict_step(timestep, state=state)  # load engine
        self.assertLess(time.time() - start_time, 0.1)

        alg, timestep, state = create_sac_and_inputs()
        # Now we compile again and force building the engine
        compile_method(
            alg, 'predict_step',
            get_tensorrt_engine_class(
                validate_args=True,
                engine_file=engine_file,
                force_build_engine=True))
        start_time = time.time()
        alg.predict_step(timestep, state=state)  # build engine
        self.assertGreater(time.time() - start_time, 1)

        alg, timestep, state = create_sac_and_inputs()
        compile_method(alg, 'predict_step')
        start_time = time.time()
        alg.predict_step(timestep, state=state)  # build engine
        self.assertGreater(time.time() - start_time, 1)

    @unittest.skipIf(not is_tensorrt_available()
                     and not is_onnxruntime_available(),
                     "tensorrt and onnxruntime are unavailable")
