Skip to content

Commit

Permalink
pnnx add missing Tensor.to pattern (#4908)
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui authored Aug 3, 2023
1 parent 759d55d commit e13fbe2
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 7 deletions.
37 changes: 31 additions & 6 deletions tools/pnnx/src/pass_level2/Tensor_to.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,15 @@ pnnx.Output output 1 0 out

op->params["copy"] = captured_params.at("copy");

if (captured_params.at("memory_format").i == 0)
op->params["memory_format"] = "torch.contiguous_format";
if (captured_params.at("memory_format").i == 1)
op->params["memory_format"] = "torch.preserve_format";
if (captured_params.at("memory_format").i == 2)
op->params["memory_format"] = "torch.channels_last";
if (captured_params.at("memory_format").type == 2)
{
if (captured_params.at("memory_format").i == 0)
op->params["memory_format"] = "torch.contiguous_format";
if (captured_params.at("memory_format").i == 1)
op->params["memory_format"] = "torch.preserve_format";
if (captured_params.at("memory_format").i == 2)
op->params["memory_format"] = "torch.channels_last";
}
}
};

Expand All @@ -83,7 +86,29 @@ pnnx.Output output 1 0 out
}
};

class Tensor_to_2 : public Tensor_to
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
10 9
pnnx.Input input_0 0 1 input
prim::Constant op_0 0 1 dtype value=%dtype
prim::Constant op_1 0 1 layout value=*
prim::Constant op_2 0 1 device value=*
prim::Constant op_3 0 1 pin_memory value=*
prim::Constant op_4 0 1 non_blocking value=*
prim::Constant op_5 0 1 copy value=%copy
prim::Constant op_6 0 1 memory_format value=%memory_format
aten::to op_7 8 1 input dtype layout device pin_memory non_blocking copy memory_format out
pnnx.Output output 1 0 out
)PNNXIR";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_to, 20)
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_to_1, 20)
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_to_2, 20)

} // namespace pnnx
3 changes: 2 additions & 1 deletion tools/pnnx/tests/test_Tensor_to.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ def forward(self, x, y):
x = x.to(device='cpu', dtype=torch.int, copy=True)
x = x + 1
y = y - 2
return x, y
z = x.to(y.device)
return x, y, z

def test():
net = Model()
Expand Down

0 comments on commit e13fbe2

Please sign in to comment.