
Optimize Dynamic Shape Inference for TTS Model with HiFi-GAN Vocoder #4230

Open
UmerrAhsan opened this issue Oct 30, 2024 · 1 comment
Labels
triaged Issue has been triaged by maintainers

Comments

@UmerrAhsan

UmerrAhsan commented Oct 30, 2024

Description:

I converted the decoder of a TTS model (with a HiFi-GAN vocoder) from PyTorch to ONNX and then to a TensorRT engine. During inference, both the input and output shapes are dynamic and change with each call. Currently I allocate and deallocate memory on every inference run, but I am unsure whether this is the best approach.

System Details
TensorRT: 10.5.0
CUDA: 12.1
OS: Ubuntu 20.04
GPU: A100
Problem
Dynamic shape handling: Is my approach of allocating/deallocating memory on every inference call reasonable, and is my overall inference code correct? (A pre-allocation alternative is sketched just below.)
Output shape: My code does not handle dynamic output shapes correctly; the output it returns always has shape (1,).
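
One pattern I am considering for the allocation question is to allocate the pinned-host and device buffers once, sized for the maximum shapes in the optimization profile, and to reuse them on every call, copying only the valid portion of each input. A rough sketch of that idea (the allocate_max_buffers helper is just illustrative, not in my current code; it assumes an engine and execution context built as in the scripts below):

```python
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit  # initializes the CUDA driver and context


def allocate_max_buffers(engine, context, profile_index=0):
    """Illustrative helper: allocate one pinned-host/device buffer pair per
    I/O tensor, sized for the largest shapes the optimization profile allows,
    so that no per-call allocation is needed."""
    # First pass: pin every input to its maximum profile shape so the context
    # can report the largest possible output shapes.
    for i in range(engine.num_io_tensors):
        name = engine.get_tensor_name(i)
        if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
            min_s, opt_s, max_s = engine.get_tensor_profile_shape(name, profile_index)
            context.set_input_shape(name, tuple(max_s))

    # Second pass: allocate buffers at those maximum sizes and register the
    # device addresses with the context once.
    buffers = {}
    for i in range(engine.num_io_tensors):
        name = engine.get_tensor_name(i)
        dtype = trt.nptype(engine.get_tensor_dtype(name))
        max_size = trt.volume(context.get_tensor_shape(name))
        host_mem = cuda.pagelocked_empty(max_size, dtype)
        device_mem = cuda.mem_alloc(host_mem.nbytes)
        buffers[name] = (host_mem, device_mem)
        context.set_tensor_address(name, int(device_mem))
    return buffers
```

Per call, the real input shapes would then be set with context.set_input_shape, only the first trt.volume(shape) elements of each input buffer filled and copied, and after execution only the first trt.volume(context.get_tensor_shape("output")) elements of the output buffer read back.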

ONNX to TensorRT engine conversion:

```python
import tensorrt as trt
import numpy as np
import pycuda.autoinit

# Convert ONNX to TensorRT engine
def build_engine(onnx_file_path, min_shape, opt_shape, max_shape):
    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    EXPLICIT_BATCH = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)  # explicit batch dimension enabled
    network = builder.create_network(EXPLICIT_BATCH) 
    parser = trt.OnnxParser(network, logger)

    success = parser.parse_from_file(onnx_file_path)
    for idx in range(parser.num_errors):
        print(parser.get_error(idx))

    if not success:
        return None
    
    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)  # 1 GiB workspace

    
    # Set dynamic shapes
    profile = builder.create_optimization_profile()
    profile.set_shape("asr", min_shape['asr'], opt_shape['asr'], max_shape['asr'])
    profile.set_shape("f0", min_shape['f0'], opt_shape['f0'], max_shape['f0'])
    profile.set_shape("n", min_shape['n'], opt_shape['n'], max_shape['n'])
    profile.set_shape("ref", min_shape['ref'], opt_shape['ref'], max_shape['ref'])
    config.add_optimization_profile(profile)
    config.default_device_type = trt.DeviceType.GPU
    
    #engine = builder.build_engine(network, config)
    serialized_engine = builder.build_serialized_network(network, config)

    if serialized_engine is None:
        print("Failed to build engine")
        return None

    with open("sample.engine", "wb") as f:
        f.write(serialized_engine)

    return serialized_engine

# Main execution
def main():
    onnx_file_path = "new_decoder.onnx"
    
    # Define shapes
    hidden_dim, style_dim = 512, 128
    min_time_dim, max_time_dim, opt_time_dim = 28, 1106, 56
    
    min_shape = {
        'asr': (1, hidden_dim, min_time_dim),
        'f0': (1, min_time_dim * 2),
        'n': (1, min_time_dim * 2),
        'ref': (1, style_dim)
    }
    opt_shape = {
        'asr': (1, hidden_dim, opt_time_dim),
        'f0': (1, opt_time_dim * 2),
        'n': (1, opt_time_dim * 2),
        'ref': (1, style_dim)
    }
    max_shape = {
        'asr': (1, hidden_dim, max_time_dim),
        'f0': (1, max_time_dim * 2),
        'n': (1, max_time_dim * 2),
        'ref': (1, style_dim)
    }
    
    # Build TensorRT engine
    engine = build_engine(onnx_file_path, min_shape, opt_shape, max_shape)
    
    if engine is None:
        print("Failed to build engine")
        return

if __name__ == "__main__":
    main()
```
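
For a quick sanity check after building, the engine can be deserialized and each input's profile range printed; a small sketch (assuming the sample.engine written above and a single optimization profile):

```python
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
with open("sample.engine", "rb") as f:
    engine = trt.Runtime(logger).deserialize_cuda_engine(f.read())

# Print the min/opt/max range of each input in profile 0 and the
# (possibly dynamic, -1) shape reported for each output.
for i in range(engine.num_io_tensors):
    name = engine.get_tensor_name(i)
    if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
        min_s, opt_s, max_s = engine.get_tensor_profile_shape(name, 0)
        print(f"{name}: min={min_s}, opt={opt_s}, max={max_s}")
    else:
        print(f"{name}: shape={engine.get_tensor_shape(name)}")
```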

inference.py:

```python
import numpy as np
import torch
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit  # initializes CUDA driver and context
import time

class HostDeviceMem(object):
    '''
    Helper class to record host-device memory pointer pairs
    '''
    def __init__(self, host_mem, device_mem):
        self.host = host_mem
        self.device = device_mem


# Define constants for input dimensions
hidden_dim, style_dim = 512, 128
min_time_dim, max_time_dim, opt_time_dim = 28, 1106, 56

# Define dynamic shapes for the inputs
min_shape = {
    'asr': (1, hidden_dim, min_time_dim),
    'f0': (1, min_time_dim * 2),
    'n': (1, min_time_dim * 2),
    'ref': (1, style_dim)
}
opt_shape = {
    'asr': (1, hidden_dim, opt_time_dim),
    'f0': (1, opt_time_dim * 2),
    'n': (1, opt_time_dim * 2),
    'ref': (1, style_dim)
}
max_shape = {
    'asr': (1, hidden_dim, max_time_dim),
    'f0': (1, max_time_dim * 2),
    'n': (1, max_time_dim * 2),
    'ref': (1, style_dim)
}

# # Create random example inputs matching optimal shape
# input_asr = torch.randn(max_shape['asr']).numpy()
# input_f0 = torch.randn(max_shape['f0']).numpy()
# input_n = torch.randn(max_shape['n']).numpy()
# input_ref = torch.randn(max_shape['ref']).numpy()

asr = torch.randn(opt_shape['asr']).numpy()
f0 = torch.randn(opt_shape['f0']).numpy()
n = torch.randn(opt_shape['n']).numpy()
ref = torch.randn(opt_shape['ref']).numpy()

logger = trt.Logger(trt.Logger.WARNING)
runtime = trt.Runtime(logger)

with open("sample.engine", "rb") as f:
    serialized_engine = f.read()

engine = runtime.deserialize_cuda_engine(serialized_engine)
context = engine.create_execution_context()
inputs, outputs, bindings = [], [], []
stream = cuda.Stream()

for i in range(engine.num_io_tensors):
    tensor_name = engine.get_tensor_name(i)
    print("Tensor:", tensor_name, "Shape:", engine.get_tensor_shape(tensor_name))


def infer(asr, f0, n, ref):
    # Actual shapes of the inputs
    input_shapes = [asr.shape, f0.shape, n.shape, ref.shape]

    inputs = []
    outputs = []
    bindings = []
    context.set_input_shape("asr", asr.shape)
    context.set_input_shape("f0", f0.shape)
    context.set_input_shape("n", n.shape)
    context.set_input_shape("ref", ref.shape)

    for i in range(engine.num_io_tensors):
        tensor_name = engine.get_tensor_name(i)
        dtype = trt.nptype(engine.get_tensor_dtype(tensor_name))

        # Check if it's an input or output tensor
        if engine.get_tensor_mode(tensor_name) == trt.TensorIOMode.INPUT:
            shape = input_shapes.pop(0)  # Get the shape from the input shapes
            size = trt.volume(shape)
            host_mem = cuda.pagelocked_empty(size, dtype)
            device_mem = cuda.mem_alloc(host_mem.nbytes)
            inputs.append(HostDeviceMem(host_mem, device_mem))
            bindings.append(int(device_mem))
            np.copyto(inputs[-1].host, locals()[tensor_name].ravel())  # the function arguments share the engine's tensor names (asr, f0, n, ref)
        else:
            temp_shape = (1,)  # Placeholder, adjust if necessary
            size = trt.volume(temp_shape)
            host_mem = cuda.pagelocked_empty(size, dtype)
            device_mem = cuda.mem_alloc(host_mem.nbytes)
            outputs.append(HostDeviceMem(host_mem, device_mem))
            bindings.append(int(device_mem))

    # Transfer inputs to device
    for i in range(len(inputs)):
        cuda.memcpy_htod_async(inputs[i].device, inputs[i].host, stream)

    # Set tensor address for each input/output
    for i in range(engine.num_io_tensors):
        context.set_tensor_address(engine.get_tensor_name(i), bindings[i])
 
    # Run inference (asynchronously on the stream)
    context.execute_async_v3(stream.handle)

    # Transfer predictions back
    cuda.memcpy_dtoh_async(outputs[0].host, outputs[0].device, stream)

    # Synchronize the stream
    stream.synchronize()

    return outputs[0].host

def cleanup():
    for input_mem in inputs:
        input_mem.device.free()  # Free device memory for each input
    for output_mem in outputs:
        output_mem.device.free()  # Free device memory for each output



# Run inference
start_time = time.time()
output = infer(asr, f0, n, ref)
end_time = time.time()
# print time in milliseconds
print("Time taken:", (end_time - start_time) * 1000, "ms")
print("Output shape:", output.shape)

# Clean up memory after inference
cleanup()

```

Output:
Tensor: asr Shape: (1, 512, -1)
Tensor: f0 Shape: (1, -1)
Tensor: n Shape: (1, -1)
Tensor: ref Shape: (1, 128)
Tensor: output Shape: (1, 1, -1) # the engine reports the expected shapes; -1 marks a dynamic dimension
Time taken: 1.0862350463867188 ms
Output shape: (1,) # the output always comes back with this shape; this is what needs fixing

Are there more efficient ways to manage dynamic shapes for this setup? Any help or guidance would be greatly appreciated!

@AntixK

AntixK commented Nov 3, 2024

Hi,

I was able to resolve the output issue by querying the resolved output shape from the execution context instead of using the hard-coded placeholder:

temp_shape = context.get_tensor_shape(tensor_name)
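
Wrapped up as a small helper for the output branch of infer() above (allocate_output is just an illustration, using the same TensorRT calls):

```python
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit  # initializes the CUDA driver and context


def allocate_output(engine, context, tensor_name):
    """Illustrative helper: size the output buffer from the shape the
    execution context reports after all input shapes have been set."""
    temp_shape = context.get_tensor_shape(tensor_name)  # e.g. (1, 1, N)
    dtype = trt.nptype(engine.get_tensor_dtype(tensor_name))
    host_mem = cuda.pagelocked_empty(trt.volume(temp_shape), dtype)
    device_mem = cuda.mem_alloc(host_mem.nbytes)
    return host_mem, device_mem, temp_shape
```

The flat host buffer returned from infer() can then be reshaped with .reshape(temp_shape) to recover the (1, 1, N) waveform.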
