From bfbf54d933c9e7260411a80470bcaf1d6ddd880b Mon Sep 17 00:00:00 2001 From: nihuini Date: Tue, 1 Aug 2023 10:43:35 +0800 Subject: [PATCH 1/2] pnnx convert torch cross --- tools/pnnx/src/CMakeLists.txt | 1 + tools/pnnx/src/pass_level2/torch_cross.cpp | 42 +++++++++++++++ tools/pnnx/tests/CMakeLists.txt | 1 + tools/pnnx/tests/test_torch_cross.py | 62 ++++++++++++++++++++++ 4 files changed, 106 insertions(+) create mode 100644 tools/pnnx/src/pass_level2/torch_cross.cpp create mode 100644 tools/pnnx/tests/test_torch_cross.py diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index c2bc0306f1e3..195dcf3a4af5 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -213,6 +213,7 @@ set(pnnx_pass_level2_SRCS pass_level2/torch_clamp.cpp pass_level2/torch_clone.cpp pass_level2/torch_complex.cpp + pass_level2/torch_cross.cpp pass_level2/torch_cumsum.cpp pass_level2/torch_dequantize.cpp pass_level2/torch_einsum.cpp diff --git a/tools/pnnx/src/pass_level2/torch_cross.cpp b/tools/pnnx/src/pass_level2/torch_cross.cpp new file mode 100644 index 000000000000..16326391b803 --- /dev/null +++ b/tools/pnnx/src/pass_level2/torch_cross.cpp @@ -0,0 +1,42 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class torch_cross : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 other +pnnx.Input input_2 0 1 dim +aten::cross op_0 3 1 input other dim out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.cross"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_cross, 20) + +} // namespace pnnx diff --git a/tools/pnnx/tests/CMakeLists.txt b/tools/pnnx/tests/CMakeLists.txt index 626d549991df..6cbe932d8f3f 100644 --- a/tools/pnnx/tests/CMakeLists.txt +++ b/tools/pnnx/tests/CMakeLists.txt @@ -193,6 +193,7 @@ pnnx_add_test(torch_cat) pnnx_add_test(torch_chunk) pnnx_add_test(torch_clone) pnnx_add_test(torch_complex) +pnnx_add_test(torch_cross) pnnx_add_test(torch_cumsum) pnnx_add_test(torch_einsum) pnnx_add_test(torch_eq) diff --git a/tools/pnnx/tests/test_torch_cross.py b/tools/pnnx/tests/test_torch_cross.py new file mode 100644 index 000000000000..a3f4880786ee --- /dev/null +++ b/tools/pnnx/tests/test_torch_cross.py @@ -0,0 +1,62 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + out0 = torch.cross(x, y) + out1 = torch.cross(x, y, dim=1) + out2 = torch.cross(z, w) + return out0, out1, out2 + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(3, 3) + y = torch.rand(3, 3) + z = torch.rand(5, 3) + w = torch.rand(5, 3) + + a = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_torch_cross.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_torch_cross.pt inputshape=[3,3],[3,3],[5,3],[5,3]") + + # pnnx inference + import test_torch_cross_pnnx + b = test_torch_cross_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) From c185e4893a8002b0b5ad1afa4b41f9d7cde55b4e Mon Sep 17 00:00:00 2001 From: nihuini Date: Tue, 1 Aug 2023 10:51:05 +0800 Subject: [PATCH 2/2] convert torch.t --- tools/pnnx/src/CMakeLists.txt | 1 + tools/pnnx/src/pass_level2/torch_t.cpp | 40 +++++++++++++++++ tools/pnnx/tests/CMakeLists.txt | 1 + tools/pnnx/tests/test_torch_t.py | 59 ++++++++++++++++++++++++++ 4 files changed, 101 insertions(+) create mode 100644 tools/pnnx/src/pass_level2/torch_t.cpp create mode 100644 tools/pnnx/tests/test_torch_t.py diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index 195dcf3a4af5..4c5a1e15cd63 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -256,6 +256,7 @@ set(pnnx_pass_level2_SRCS pass_level2/torch_std.cpp pass_level2/torch_sum.cpp pass_level2/torch_permute.cpp + pass_level2/torch_t.cpp pass_level2/torch_tensor_split.cpp pass_level2/torch_topk.cpp pass_level2/torch_transpose.cpp diff --git a/tools/pnnx/src/pass_level2/torch_t.cpp b/tools/pnnx/src/pass_level2/torch_t.cpp new file mode 100644 index 000000000000..0dd48920aa9c --- /dev/null +++ b/tools/pnnx/src/pass_level2/torch_t.cpp @@ -0,0 +1,40 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class torch_t : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input_0 0 1 input +aten::t op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.t"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_t, 20) + +} // namespace pnnx diff --git a/tools/pnnx/tests/CMakeLists.txt b/tools/pnnx/tests/CMakeLists.txt index 6cbe932d8f3f..26ea2005d285 100644 --- a/tools/pnnx/tests/CMakeLists.txt +++ b/tools/pnnx/tests/CMakeLists.txt @@ -225,6 +225,7 @@ pnnx_add_test(torch_split) pnnx_add_test(torch_squeeze) pnnx_add_test(torch_stack) pnnx_add_test(torch_std) +pnnx_add_test(torch_t) pnnx_add_test(torch_tensor_split) pnnx_add_test(torch_topk) pnnx_add_test(torch_transpose) diff --git a/tools/pnnx/tests/test_torch_t.py b/tools/pnnx/tests/test_torch_t.py new file mode 100644 index 000000000000..a953ae03aabc --- /dev/null +++ b/tools/pnnx/tests/test_torch_t.py @@ -0,0 +1,59 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y): + x = torch.t(x) + y = torch.t(y) + return x, y + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(3) + y = torch.rand(5, 9) + + a = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_torch_t.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_torch_t.pt inputshape=[3],[5,9]") + + # pnnx inference + import test_torch_t_pnnx + b = test_torch_t_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1)