Skip to content

Commit

Permalink
Add layer_norm e2e test with stash_type (#399)
Browse files Browse the repository at this point in the history
`python run.py --mode=cl-onnx-iree -v --torchtolinalg -t
layer_norm_test`
  • Loading branch information
jinchen62 authored Nov 26, 2024
1 parent 8774ed3 commit 5a5092f
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 0 deletions.
34 changes: 34 additions & 0 deletions alt_e2eshark/onnx_tests/operators/layer_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from onnx import TensorProto
from onnx.helper import make_node, make_tensor_value_info

from ..helper_classes import BuildAModel
from e2e_testing.registry import register_test

class LayerNormalizationModel(BuildAModel):
def construct_i_o_value_info(self):
X = make_tensor_value_info("X", TensorProto.FLOAT16, [2, 3, 5])
Scale = make_tensor_value_info("Scale", TensorProto.FLOAT16, [5])
B = make_tensor_value_info("B", TensorProto.FLOAT16, [5])
Y = make_tensor_value_info("Y", TensorProto.FLOAT16, [2, 3, 5])
Mean = make_tensor_value_info("Mean", TensorProto.FLOAT, [2, 3, 1])
InvStdDev = make_tensor_value_info("InvStdDev", TensorProto.FLOAT, [2, 3, 1])
self.input_vi = [X, Scale, B]
self.output_vi = [Y, Mean, InvStdDev]

def construct_nodes(self):
layer_norm_node = make_node(
op_type="LayerNormalization",
inputs=["X", "Scale", "B"],
outputs=["Y", "Mean", "InvStdDev"],
axis=-1,
epsilon=1e-05,
stash_type=1,
)
self.node_list = [layer_norm_node]

register_test(LayerNormalizationModel, "layer_norm_test")
1 change: 1 addition & 0 deletions alt_e2eshark/onnx_tests/operators/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@
from .conv import *
from .convtranspose import *
from .tfidf_vectorizer import *
from .layer_norm import *

0 comments on commit 5a5092f

Please sign in to comment.