[ONNX] Bernoulli operator implementation might be wrong (value mismatch in e2e testing when using function expansion) #3527

Open

andfau-amd (Contributor) opened this issue Jul 8, 2024

Issue I discovered while working on #3384.

Tested on main, commit 0b46d11.

Normally the ONNX Bernoulli operator gets imported as torch.operator "onnx.Bernoulli", but since ONNX provides a function for this operator, it's possible to make the importer pre-expand it (using the code added in #3409).

If we apply this patch

diff --git a/python/torch_mlir/extras/onnx_importer.py b/python/torch_mlir/extras/onnx_importer.py
index 9fe29212..396096da 100644
--- a/python/torch_mlir/extras/onnx_importer.py
+++ b/python/torch_mlir/extras/onnx_importer.py
@@ -104,6 +104,7 @@ class Config:
             # Default domain (ONNX built-in ops)
             "": {
                 "MeanVarianceNormalization",
+                "Bernoulli",
             }
         }
     )

then the new importer output becomes

module {
  func.func @main_graph(%arg0: !torch.vtensor<[?,?],f64>) -> !torch.vtensor<[?,?],f64> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
    %none = torch.constant.none
    %0 = call @"('Bernoulli', '', 20, [tensor_type {\0A  elem_type: 11\0A  shape {\0A    dim {\0A      dim_param: \22dim_0_0\22\0A    }\0A    dim {\0A      dim_param: \22dim_0_1\22\0A    }\0A  }\0A}\0A], [tensor_type {\0A  elem_type: 11\0A  shape {\0A    dim {\0A      dim_param: \22dim_0_0\22\0A    }\0A    dim {\0A      dim_param: \22dim_0_1\22\0A    }\0A  }\0A}\0A], input: \22input_0\22\0Aoutput: \221\22\0Aname: \22/Bernoulli\22\0Aop_type: \22Bernoulli\22\0A)"(%arg0) : (!torch.vtensor<[?,?],f64>) -> !torch.vtensor<[?,?],f64>
    return %0 : !torch.vtensor<[?,?],f64>
  }
  func.func private @"('Bernoulli', '', 20, [tensor_type {\0A  elem_type: 11\0A  shape {\0A    dim {\0A      dim_param: \22dim_0_0\22\0A    }\0A    dim {\0A      dim_param: \22dim_0_1\22\0A    }\0A  }\0A}\0A], [tensor_type {\0A  elem_type: 11\0A  shape {\0A    dim {\0A      dim_param: \22dim_0_0\22\0A    }\0A    dim {\0A      dim_param: \22dim_0_1\22\0A    }\0A  }\0A}\0A], input: \22input_0\22\0Aoutput: \221\22\0Aname: \22/Bernoulli\22\0Aop_type: \22Bernoulli\22\0A)"(%arg0: !torch.vtensor<[?,?],f64>) -> !torch.vtensor<[?,?],f64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
    %none = torch.constant.none
    %0 = torch.operator "onnx.RandomUniformLike"(%arg0) {torch.onnx.dtype = 11 : si64, torch.onnx.high = 1.000000e+00 : f32, torch.onnx.low = 0.000000e+00 : f32} : (!torch.vtensor<[?,?],f64>) -> !torch.vtensor<[?,?],f64>
    %1 = torch.operator "onnx.Greater"(%0, %arg0) : (!torch.vtensor<[?,?],f64>, !torch.vtensor<[?,?],f64>) -> !torch.vtensor<[?,?],i1>
    %2 = torch.operator "onnx.Cast"(%1) {torch.onnx.to = 11 : si64} : (!torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],f64>
    return %2 : !torch.vtensor<[?,?],f64>
  }
}

and the e2e tests for Bernoulli start failing:

$ python -m e2e_testing.main -f Bernoulli -c onnx

[...]

XFAIL - "BernoulliFloatModule_basic"
XFAIL - "BernoulliModule_basic"
FAIL - "BernoulliOnesModule_basic"
XFAIL - "BernoulliPModule_basic"
XFAIL - "BernoulliTensorModule_basic"
FAIL - "BernoulliZerosModule_basic"

Unexpected outcome summary: (onnx)

****** Failed tests - 2 tests
    FAIL - "BernoulliOnesModule_basic"
    FAIL - "BernoulliZerosModule_basic"

Summary:
    Failed: 2
    Expectedly Failed: 4

When I investigated why this happens, it turned out that the ONNX function interprets the input to the operator (let's call it p) in the opposite way to what these tests expect. p is always in [0,1], but the ONNX function behaves as if (1-p) had been passed. So, where an all-ones result is expected, it produces all zeros, and vice versa.

Looking at the importer output above, we can see that ONNX's definition is very simple: generate uniform random numbers (each in the range [0,1], I believe), compare them elementwise against p with greater-than, and cast the boolean comparison result (false or true) to the output dtype (0 or 1). To get the "expected" behavior, the greater-than would have to be replaced with a different comparison (perhaps less-than-or-equal).
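For concreteness, here is a small NumPy sketch of the two conventions (my own illustration, not code taken from either project). With u drawn uniformly from [0,1], u > p is true with probability 1-p, while u <= p is true with probability p, which is the semantics the e2e tests appear to expect:

import numpy as np

rng = np.random.default_rng(0)

def bernoulli_onnx_function(p):
    # Mirrors the expanded ONNX function above:
    # RandomUniformLike, Greater, Cast.
    # Each element ends up 1 with probability (1 - p).
    u = rng.uniform(0.0, 1.0, size=p.shape)
    return (u > p).astype(np.float64)

def bernoulli_expected(p):
    # Semantics the e2e tests expect:
    # each element is 1 with probability p.
    u = rng.uniform(0.0, 1.0, size=p.shape)
    return (u <= p).astype(np.float64)

p = np.ones((2, 3))                  # all-ones input, as in BernoulliOnesModule
print(bernoulli_onnx_function(p))    # all zeros: u > 1 is never true
print(bernoulli_expected(p))         # all ones:  u <= 1 is always true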

To me, this surely indicates a bug, but I'm not sure which implementation is "wrong" and which is "right".
