ROOT CAUSE? InstanceNorm3D with tensorRT result consistency with nn.Instance_norm3d #938

Open
Yuxiang1990 opened this issue Jun 21, 2024 · 0 comments

@Yuxiang1990

Hi,

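  • With the `tensorrt_converter` for InstanceNorm3d, the TensorRT result does not match `nn.InstanceNorm3d`.
  • The pytest verification passes, but only because it uses instance norm with the default initialization (weight = 1, bias = 0). With randomly initialized affine parameters, the comparison fails.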
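To see why a default-initialized test cannot catch a dropped affine step, here is a minimal pure-Python sketch (not torch2trt code). Assumptions: the D×H×W spatial values are flattened into one list per (sample, channel), eps is 1e-5 (the `nn.InstanceNorm3d` default), and `instance_norm`, `affine` and `max_abs_diff` are hypothetical helpers written for illustration. With weight = 1 and bias = 0, skipping the affine step changes nothing; with random weight and bias it produces a clearly different result.

```python
import math
import random

EPS = 1e-5  # nn.InstanceNorm3d default eps


def instance_norm(x):
    """Normalize each (sample, channel) slice over its flattened spatial values.

    x[n][c] is a list of D*H*W floats; uses the biased variance, as PyTorch does.
    """
    out = []
    for sample in x:
        out_sample = []
        for values in sample:
            mean = sum(values) / len(values)
            var = sum((v - mean) ** 2 for v in values) / len(values)
            denom = math.sqrt(var + EPS)
            out_sample.append([(v - mean) / denom for v in values])
        out.append(out_sample)
    return out


def affine(x, weight, bias):
    """Per-channel y = weight[c] * x + bias[c] (the step done by add_scale)."""
    return [[[weight[c] * v + bias[c] for v in values]
             for c, values in enumerate(sample)]
            for sample in x]


def max_abs_diff(a, b):
    return max(abs(p - q)
               for sa, sb in zip(a, b)
               for ca, cb in zip(sa, sb)
               for p, q in zip(ca, cb))


random.seed(0)
N, C, SPATIAL = 2, 3, 4 * 4 * 4  # D = H = W = 4, flattened
x = [[[random.gauss(0.0, 1.0) for _ in range(SPATIAL)] for _ in range(C)]
     for _ in range(N)]
normed = instance_norm(x)

# Default affine init (weight = 1, bias = 0): dropping the affine step is invisible.
default_ref = affine(normed, [1.0] * C, [0.0] * C)
print("default init, diff if affine is dropped:", max_abs_diff(default_ref, normed))  # 0.0

# Random affine parameters: dropping the affine step gives a wrong result.
w = [random.uniform(0.5, 2.0) for _ in range(C)]
b = [random.uniform(-1.0, 1.0) for _ in range(C)]
random_ref = affine(normed, w, b)
print("random init, diff if affine is dropped:", max_abs_diff(random_ref, normed))
```

This is why a regression test for this converter should randomize `weight` and `bias` before comparing against PyTorch.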

https://github.com/NVIDIA-AI-IOT/torch2trt/blob/master/torch2trt/converters/native_converters.py?plain=1#L902


```python
def _add_scale_1d2d3d(network, x_trt, mode, offset, scale, power):
    ndim = len(x_trt.shape)
    y_trt = x_trt
    # shape to 2D
    if ndim != 4:
        layer = network.add_shuffle(y_trt)
        layer.reshape_dims = (x_trt.shape[0], x_trt.shape[1], x_trt.shape[2], -1)  # NCH -> NCHW
        y_trt = layer.get_output(0)

    y_trt = network.add_scale(y_trt, mode, offset, scale, power).get_output(0)

    # shape to original dimension
    if ndim != 4:
        # BUG: `layer` is still the first shuffle, so this reshapes the un-scaled
        # input and discards the add_scale result. It should be:
        #     layer = network.add_shuffle(y_trt)
        layer = network.add_shuffle(layer.get_output(0))
        layer.reshape_dims = tuple(x_trt.shape)
        y_trt = layer.get_output(0)

    return y_trt
```
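  • Proposed fix: replace `layer = network.add_shuffle(layer.get_output(0))` with `layer = network.add_shuffle(y_trt)`. Since `y_trt` is already the output tensor of `add_scale`, it is passed directly (a tensor has no `get_output` method).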
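The dataflow bug can be reproduced without TensorRT. The sketch below is a hypothetical mock: `MockTensor`, `MockLayer`, `MockShuffleLayer`, `MockNetwork` and the `fixed` flag are invented for illustration and only imitate the calls used in `_add_scale_1d2d3d` (`add_shuffle`, `reshape_dims`, `add_scale`, `get_output`), with CHANNEL-mode scale computed as `(x * scale + shift) ** power`. For a 5D input, the buggy path returns the input unchanged while the fixed path applies the per-channel transform.

```python
import math


class MockTensor:
    """Hypothetical stand-in for a TensorRT tensor: a shape plus flat row-major data."""

    def __init__(self, shape, values):
        self.shape = tuple(shape)
        self.values = list(values)


class MockLayer:
    """Hypothetical stand-in for a layer with a single output."""

    def __init__(self, output):
        self._output = output

    def get_output(self, index):
        return self._output


class MockShuffleLayer(MockLayer):
    """Reshape-only shuffle: element order is kept, only the shape changes."""

    def __init__(self, tensor):
        super().__init__(MockTensor(tensor.shape, tensor.values))
        self._input = tensor

    @property
    def reshape_dims(self):
        return self._output.shape

    @reshape_dims.setter
    def reshape_dims(self, dims):
        # Resolve a single -1 from the total element count, as a reshape does.
        known = math.prod(d for d in dims if d != -1)
        total = len(self._input.values)
        shape = tuple(total // known if d == -1 else d for d in dims)
        self._output = MockTensor(shape, self._input.values)


class MockNetwork:
    def add_shuffle(self, tensor):
        return MockShuffleLayer(tensor)

    def add_scale(self, tensor, mode, shift, scale, power):
        # CHANNEL mode on a 4D (N, C, H, W) tensor; `mode` is ignored by this mock.
        _, c, h, w = tensor.shape
        out = []
        for i, v in enumerate(tensor.values):
            ch = (i // (h * w)) % c
            out.append((v * scale[ch] + shift[ch]) ** power[ch])
        return MockLayer(MockTensor(tensor.shape, out))


def add_scale_1d2d3d(network, x_trt, mode, offset, scale, power, fixed=True):
    ndim = len(x_trt.shape)
    y_trt = x_trt
    if ndim != 4:
        layer = network.add_shuffle(y_trt)
        layer.reshape_dims = (x_trt.shape[0], x_trt.shape[1], x_trt.shape[2], -1)
        y_trt = layer.get_output(0)
    y_trt = network.add_scale(y_trt, mode, offset, scale, power).get_output(0)
    if ndim != 4:
        # fixed=False mirrors the current code: it reshapes the pre-scale tensor.
        src = y_trt if fixed else layer.get_output(0)
        layer = network.add_shuffle(src)
        layer.reshape_dims = tuple(x_trt.shape)
        y_trt = layer.get_output(0)
    return y_trt


net = MockNetwork()
x = MockTensor((1, 2, 2, 2, 2), [float(i) for i in range(16)])  # N, C, D, H, W
shift, scale, power = [1.0, -1.0], [2.0, 3.0], [1.0, 1.0]

buggy = add_scale_1d2d3d(net, x, "CHANNEL", shift, scale, power, fixed=False)
fixed = add_scale_1d2d3d(net, x, "CHANNEL", shift, scale, power, fixed=True)
print(buggy.values == x.values)  # True: the scale result was discarded
print(fixed.values[0], fixed.values[8])  # 1.0 23.0
```

For 4D inputs neither shuffle is added, so only the 1D and 3D paths are affected, which matches the InstanceNorm3d symptom.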
Yuxiang1990 changed the title InstanceNorm3D with tensorRT result consistency with nn.Instance_norm3d, root cause?? ROOT CAUSE? InstanceNorm3D with tensorRT result consistency with nn.Instance_norm3d Jun 21, 2024