Skip to content

Commit

Permalink
pnnx convert torch cross and t (#4896)
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui authored Aug 1, 2023
1 parent c6b191c commit 759d55d
Show file tree
Hide file tree
Showing 6 changed files with 207 additions and 0 deletions.
2 changes: 2 additions & 0 deletions tools/pnnx/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ set(pnnx_pass_level2_SRCS
pass_level2/torch_clamp.cpp
pass_level2/torch_clone.cpp
pass_level2/torch_complex.cpp
pass_level2/torch_cross.cpp
pass_level2/torch_cumsum.cpp
pass_level2/torch_dequantize.cpp
pass_level2/torch_einsum.cpp
Expand Down Expand Up @@ -255,6 +256,7 @@ set(pnnx_pass_level2_SRCS
pass_level2/torch_std.cpp
pass_level2/torch_sum.cpp
pass_level2/torch_permute.cpp
pass_level2/torch_t.cpp
pass_level2/torch_tensor_split.cpp
pass_level2/torch_topk.cpp
pass_level2/torch_transpose.cpp
Expand Down
42 changes: 42 additions & 0 deletions tools/pnnx/src/pass_level2/torch_cross.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level2.h"

namespace pnnx {

class torch_cross : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
5 4
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 other
pnnx.Input input_2 0 1 dim
aten::cross op_0 3 1 input other dim out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "torch.cross";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_cross, 20)

} // namespace pnnx
40 changes: 40 additions & 0 deletions tools/pnnx/src/pass_level2/torch_t.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level2.h"

namespace pnnx {

class torch_t : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
3 2
pnnx.Input input_0 0 1 input
aten::t op_0 1 1 input out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "torch.t";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_t, 20)

} // namespace pnnx
2 changes: 2 additions & 0 deletions tools/pnnx/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ pnnx_add_test(torch_cat)
pnnx_add_test(torch_chunk)
pnnx_add_test(torch_clone)
pnnx_add_test(torch_complex)
pnnx_add_test(torch_cross)
pnnx_add_test(torch_cumsum)
pnnx_add_test(torch_einsum)
pnnx_add_test(torch_eq)
Expand Down Expand Up @@ -224,6 +225,7 @@ pnnx_add_test(torch_split)
pnnx_add_test(torch_squeeze)
pnnx_add_test(torch_stack)
pnnx_add_test(torch_std)
pnnx_add_test(torch_t)
pnnx_add_test(torch_tensor_split)
pnnx_add_test(torch_topk)
pnnx_add_test(torch_transpose)
Expand Down
62 changes: 62 additions & 0 deletions tools/pnnx/tests/test_torch_cross.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, y, z, w):
out0 = torch.cross(x, y)
out1 = torch.cross(x, y, dim=1)
out2 = torch.cross(z, w)
return out0, out1, out2

def test():
net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(3, 3)
y = torch.rand(3, 3)
z = torch.rand(5, 3)
w = torch.rand(5, 3)

a = net(x, y, z, w)

# export torchscript
mod = torch.jit.trace(net, (x, y, z, w))
mod.save("test_torch_cross.pt")

# torchscript to pnnx
import os
os.system("../src/pnnx test_torch_cross.pt inputshape=[3,3],[3,3],[5,3],[5,3]")

# pnnx inference
import test_torch_cross_pnnx
b = test_torch_cross_pnnx.test_inference()

for a0, b0 in zip(a, b):
if not torch.equal(a0, b0):
return False
return True

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)
59 changes: 59 additions & 0 deletions tools/pnnx/tests/test_torch_t.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, y):
x = torch.t(x)
y = torch.t(y)
return x, y

def test():
net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(3)
y = torch.rand(5, 9)

a = net(x, y)

# export torchscript
mod = torch.jit.trace(net, (x, y))
mod.save("test_torch_t.pt")

# torchscript to pnnx
import os
os.system("../src/pnnx test_torch_t.pt inputshape=[3],[5,9]")

# pnnx inference
import test_torch_t_pnnx
b = test_torch_t_pnnx.test_inference()

for a0, b0 in zip(a, b):
if not torch.equal(a0, b0):
return False
return True

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)

0 comments on commit 759d55d

Please sign in to comment.