Skip to content

Commit

Permalink
pnnx convert onnx expand/permute/repeat/reshape/select/slice/cat/ceil…
Browse files Browse the repository at this point in the history
…/chunk/flatten/floor/maximum/minimum/split/squeeze/stack/transpose/unbind/unsqueeze (Tencent#5583)
  • Loading branch information
nihui authored Jul 15, 2024
1 parent e7cae68 commit 569617f
Show file tree
Hide file tree
Showing 28 changed files with 1,399 additions and 61 deletions.
48 changes: 48 additions & 0 deletions tools/pnnx/src/pass_level2/Tensor_expand.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,52 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_expand_1, 20)

class Tensor_expand_onnx : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
Expand op_0 1 1 input out %*=%*
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "Tensor.expand";
}

bool match(const std::map<std::string, Parameter>& captured_params) const
{
if (captured_params.find("op_0.shape") == captured_params.end())
return false;

return true;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
if (captured_params.at("op_0.shape").type == 5)
{
op->params["shape"] = captured_params.at("op_0.shape");
}
else // if (captured_params.at("op_0.shape").type == 2)
{
op->params["shape"] = std::vector<int>{captured_params.at("op_0.shape").i};
}

// onnx set expand shape 1 for not changing the size of that dimension while torch uses -1
for (size_t i = 0; i < op->params["shape"].ai.size(); i++)
{
if (op->params["shape"].ai[i] == 1)
op->params["shape"].ai[i] = -1;
}
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_expand_onnx, 20)

} // namespace pnnx
77 changes: 25 additions & 52 deletions tools/pnnx/src/pass_level2/Tensor_reshape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class Tensor_reshape_onnx : public GraphRewriterPass
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 shape
aten::cat op_0 1 1 shape cat dim=0
Reshape op_1 2 1 input cat out allowzero=*
Reshape op_1 2 1 input cat out %*=%*
pnnx.Output output 1 0 out
)PNNXIR";
}
Expand All @@ -57,46 +57,15 @@ pnnx.Output output 1 0 out
{
return "Tensor.reshape";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_reshape_onnx, 19)

class Tensor_reshape_onnx_1 : public Tensor_reshape_onnx
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
5 4
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 shape
aten::cat op_0 1 1 shape cat dim=0
Reshape op_1 2 1 input cat out
pnnx.Output output 1 0 out
)PNNXIR";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_reshape_onnx_1, 19)

class Tensor_reshape_onnx_2 : public Tensor_reshape_onnx
{
public:
const char* match_pattern_graph() const
void write(Operator* /*op*/, const std::map<std::string, Parameter>& /*captured_params*/) const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 shape
Reshape op_1 2 1 input shape out allowzero=*
pnnx.Output output 1 0 out
)PNNXIR";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_reshape_onnx_2, 20)
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_reshape_onnx, 19)

class Tensor_reshape_onnx_3 : public Tensor_reshape_onnx
class Tensor_reshape_onnx_1 : public Tensor_reshape_onnx
{
public:
const char* match_pattern_graph() const
Expand All @@ -105,23 +74,23 @@ class Tensor_reshape_onnx_3 : public Tensor_reshape_onnx
4 3
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 shape
Reshape op_1 2 1 input shape out
Reshape op_0 2 1 input shape out %*=%*
pnnx.Output output 1 0 out
)PNNXIR";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_reshape_onnx_3, 20)
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_reshape_onnx_1, 20)

class Tensor_reshape_onnx_4 : public GraphRewriterPass
class Tensor_reshape_onnx_2 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
Reshape op_1 1 1 input out shape=%shape allowzero=*
Reshape op_0 1 1 input out %*=%*
pnnx.Output output 1 0 out
)PNNXIR";
}
Expand All @@ -130,24 +99,28 @@ pnnx.Output output 1 0 out
{
return "Tensor.reshape";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_reshape_onnx_4, 20)
bool match(const std::map<std::string, Parameter>& captured_params) const
{
if (captured_params.find("op_0.shape") == captured_params.end())
return false;

return true;
}

class Tensor_reshape_onnx_5 : public Tensor_reshape_onnx_4
{
public:
const char* match_pattern_graph() const
void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
Reshape op_1 1 1 input out shape=%shape
pnnx.Output output 1 0 out
)PNNXIR";
if (captured_params.at("op_0.shape").type == 5)
{
op->params["shape"] = captured_params.at("op_0.shape");
}
else // if (captured_params.at("op_0.shape").type == 2)
{
op->params["shape"] = std::vector<int>{captured_params.at("op_0.shape").i};
}
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_reshape_onnx_5, 20)
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_reshape_onnx_2, 20)

} // namespace pnnx
17 changes: 10 additions & 7 deletions tools/pnnx/src/pass_level2/torch_squeeze.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,20 +110,23 @@ class torch_squeeze_onnx_1 : public torch_squeeze_onnx
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
Squeeze op_0 1 1 input out axes=%axes
Squeeze op_0 1 1 input out %*=%*
pnnx.Output output 1 0 out
)PNNXIR";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
if (captured_params.at("axes").type == 5 && captured_params.at("axes").ai.size() == 1)
if (captured_params.find("op_0.axes") != captured_params.end())
{
op->params["dim"] = captured_params.at("axes").ai[0];
}
else
{
op->params["dim"] = captured_params.at("axes");
if (captured_params.at("op_0.axes").type == 5 && captured_params.at("op_0.axes").ai.size() == 1)
{
op->params["dim"] = captured_params.at("op_0.axes").ai[0];
}
else
{
op->params["dim"] = captured_params.at("op_0.axes");
}
}
}
};
Expand Down
41 changes: 41 additions & 0 deletions tools/pnnx/src/pass_level2/torch_tile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,45 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_tile_onnx, 20)

class torch_tile_onnx_1 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
Tile op_0 1 1 input out %*=%*
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "torch.tile";
}

bool match(const std::map<std::string, Parameter>& captured_params) const
{
if (captured_params.find("op_0.repeats") == captured_params.end())
return false;

return true;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
if (captured_params.at("op_0.repeats").type == 5)
{
op->params["dims"] = captured_params.at("op_0.repeats");
}
else // if (captured_params.at("op_0.repeats").type == 2)
{
op->params["dims"] = std::vector<int>{captured_params.at("op_0.repeats").i};
}
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_tile_onnx_1, 20)

} // namespace pnnx
2 changes: 2 additions & 0 deletions tools/pnnx/src/pass_onnx/fuse_constant_as_attribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ struct constant_as_attribute
};

static constant_as_attribute caas[] = {
{"Expand", 1, "shape"},
{"Gather", 1, "indices"},
{"If", 0, "cond"},
{"Pad", 1, "pads"},
Expand All @@ -49,6 +50,7 @@ static constant_as_attribute caas[] = {
{"Slice", 3, "axes"},
{"Slice", 4, "steps"},
{"Squeeze", 1, "axes"},
{"Tile", 1, "repeats"},
{"Unsqueeze", 1, "axes"},
{"Upsample", 1, "scales"},
};
Expand Down
3 changes: 2 additions & 1 deletion tools/pnnx/tests/ncnn/test_torch_unbind.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def forward(self, x, y):

x0 = F.relu(x0)
x1 = F.relu(x1)
x2 = F.relu(x2)
y0 = F.relu(y0)
y1 = F.relu(y1)
y2 = F.relu(y2)
Expand All @@ -35,7 +36,7 @@ def forward(self, x, y):
y6 = F.relu(y6)
y7 = F.relu(y7)
y8 = F.relu(y8)
return x0, x1, y0, y1, y2, y3, y4, y5, y6, y7, y8
return x0, x1, x2, y0, y1, y2, y3, y4, y5, y6, y7, y8

def test():
net = Model()
Expand Down
21 changes: 21 additions & 0 deletions tools/pnnx/tests/onnx/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,29 @@ pnnx_onnx_add_test(squeezenet1_1)
pnnx_onnx_add_test(swin_t)
pnnx_onnx_add_test(vit_b_32)

pnnx_onnx_add_test(Tensor_expand)
pnnx_onnx_add_test(Tensor_permute)
pnnx_onnx_add_test(Tensor_repeat)
pnnx_onnx_add_test(Tensor_reshape)
pnnx_onnx_add_test(Tensor_select)
pnnx_onnx_add_test(Tensor_slice)
pnnx_onnx_add_test(Tensor_view)

pnnx_onnx_add_test(torch_cat)
pnnx_onnx_add_test(torch_ceil)
pnnx_onnx_add_test(torch_chunk)
pnnx_onnx_add_test(torch_flatten)
pnnx_onnx_add_test(torch_floor)
pnnx_onnx_add_test(torch_max)
pnnx_onnx_add_test(torch_maximum)
pnnx_onnx_add_test(torch_mean)
pnnx_onnx_add_test(torch_min)
pnnx_onnx_add_test(torch_minimum)
pnnx_onnx_add_test(torch_prod)
pnnx_onnx_add_test(torch_split)
pnnx_onnx_add_test(torch_squeeze)
pnnx_onnx_add_test(torch_stack)
pnnx_onnx_add_test(torch_sum)
pnnx_onnx_add_test(torch_transpose)
pnnx_onnx_add_test(torch_unbind)
pnnx_onnx_add_test(torch_unsqueeze)
60 changes: 60 additions & 0 deletions tools/pnnx/tests/onnx/test_Tensor_expand.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    # Exercises Tensor.expand with a scalar size, -1 "keep" dims,
    # and a mix of broadcast and kept dimensions.
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, x, y, z):
        out0 = x.expand(24)
        out1 = y.expand(-1, 11, -1)
        out2 = z.expand(2, 8, 3, -1, 4)
        return out0, out1, out2

def test():
    # End-to-end round trip: torch -> onnx -> pnnx -> python, then compare
    # pnnx inference results against the reference torch outputs.
    net = Model()
    net.eval()

    torch.manual_seed(0)
    inputs = (
        torch.rand(1),
        torch.rand(3, 1, 1),
        torch.rand(1, 8, 1, 9, 1),
    )

    expected = net(*inputs)

    # export onnx
    torch.onnx.export(net, inputs, "test_Tensor_expand.onnx")

    # onnx to pnnx
    import os
    os.system("../../src/pnnx test_Tensor_expand.onnx inputshape=[1],[3,1,1],[1,8,1,9,1]")

    # pnnx inference
    import test_Tensor_expand_pnnx
    actual = test_Tensor_expand_pnnx.test_inference()

    # every output must match exactly
    return all(torch.equal(e, o) for e, o in zip(expected, actual))

if __name__ == "__main__":
    # exit code 0 on success, 1 on failure
    exit(0 if test() else 1)
Loading

0 comments on commit 569617f

Please sign in to comment.