diff --git a/CMakeLists.txt b/CMakeLists.txt index db722c925..b82a29609 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -19,7 +19,7 @@ endif() # Set max opset version for onnx if you build from other version of onnx this # should be modified. -add_definitions(-DMAX_ONNX_OPSET_VERSION=19) +add_definitions(-DMAX_ONNX_OPSET_VERSION=23) add_definitions(-DPADDLE2ONNX_LIB) # Internal flags for convert.h.in @@ -95,8 +95,7 @@ else() message(STATUS "Python site-packages directory: ${PYTHON_SITE_PACKAGES}") # set(PADDLE_LIB ${PYTHON_SITE_PACKAGES}/paddle/base/libpaddle.so) - set(PADDLE_LIB - "${PYTHON_SITE_PACKAGES}/paddle/base/libpaddle.so") + set(PADDLE_LIB "${PYTHON_SITE_PACKAGES}/paddle/base/libpaddle.so") if(EXISTS ${PADDLE_LIB}) message(STATUS "found libpaddle.so : ${PADDLE_LIB}") else() diff --git a/paddle2onnx/mapper/activation/activation.cc b/paddle2onnx/mapper/activation/activation.cc index 0740c3b51..12bc578e2 100644 --- a/paddle2onnx/mapper/activation/activation.cc +++ b/paddle2onnx/mapper/activation/activation.cc @@ -23,6 +23,7 @@ REGISTER_MAPPER(asin, ActivationMapper) REGISTER_MAPPER(atan, ActivationMapper) REGISTER_MAPPER(brelu, BReluMapper) REGISTER_MAPPER(ceil, ActivationMapper) +REGISTER_PIR_MAPPER(ceil, ActivationMapper) REGISTER_MAPPER(cos, ActivationMapper) REGISTER_PIR_MAPPER(cos, ActivationMapper) REGISTER_MAPPER(elu, EluMapper) @@ -47,6 +48,7 @@ REGISTER_MAPPER(logsigmoid, LogSigmoidMapper) REGISTER_MAPPER(log_softmax, LogSoftmaxMapper) REGISTER_MAPPER(mish, MishMapper) REGISTER_MAPPER(prelu, PReluMapper) +REGISTER_PIR_MAPPER(prelu, PReluMapper) REGISTER_MAPPER(reciprocal, ActivationMapper) REGISTER_MAPPER(relu, ActivationMapper) REGISTER_PIR_MAPPER(relu, ActivationMapper) @@ -70,6 +72,7 @@ REGISTER_PIR_MAPPER(sqrt, ActivationMapper) REGISTER_MAPPER(square, SquareMapper) REGISTER_MAPPER(tan, ActivationMapper) REGISTER_MAPPER(tanh, ActivationMapper) +REGISTER_PIR_MAPPER(tanh, ActivationMapper) REGISTER_MAPPER(tanh_shrink, TanhShrinkMapper) REGISTER_MAPPER(thresholded_relu, ThresholdedReluMapper) diff --git a/paddle2onnx/mapper/activation/activation.h b/paddle2onnx/mapper/activation/activation.h index a0484da9a..9a9e65ad3 100644 --- a/paddle2onnx/mapper/activation/activation.h +++ b/paddle2onnx/mapper/activation/activation.h @@ -90,6 +90,12 @@ class PReluMapper : public Mapper { int64_t op_id) : Mapper(p, helper, block_id, op_id) {} + PReluMapper(const PaddlePirParser& p, + OnnxHelper* helper, + int64_t op_id, + bool if_in_cf_block) + : Mapper(p, helper, op_id, if_in_cf_block) {} + int32_t GetMinOpsetVersion(bool verbose) override; void Opset7() override; }; diff --git a/paddle2onnx/mapper/detection/multiclass_nms.cc b/paddle2onnx/mapper/detection/multiclass_nms.cc index c897838a4..39dafc476 100644 --- a/paddle2onnx/mapper/detection/multiclass_nms.cc +++ b/paddle2onnx/mapper/detection/multiclass_nms.cc @@ -17,6 +17,7 @@ namespace paddle2onnx { REGISTER_MAPPER(multiclass_nms3, NMSMapper); +REGISTER_PIR_MAPPER(multiclass_nms3, NMSMapper); int32_t NMSMapper::GetMinOpsetVersion(bool verbose) { auto boxes_info = GetInput("BBoxes"); @@ -133,9 +134,18 @@ void NMSMapper::KeepTopK(const std::string& selected_indices) { helper_->Constant({1}, ONNX_NAMESPACE::TensorProto::INT64, keep_top_k_); auto ensemble_value = helper_->MakeNode("Concat", {num_of_boxes, top_k}); AddAttribute(ensemble_value, "axis", int64_t(0)); - auto new_top_k = - helper_->MakeNode("ReduceMin", {ensemble_value->output(0)}); - AddAttribute(new_top_k, "axes", std::vector(1, 0)); + + std::shared_ptr new_top_k; + if (OnnxHelper::GetOpsetVersion() > 13) { + std::string reduce_min_axis = helper_->Constant( + {1}, ONNX_NAMESPACE::TensorProto::INT64, static_cast(0)); + new_top_k = helper_->MakeNode( + "ReduceMin", {ensemble_value->output(0), reduce_min_axis}); + + } else { + new_top_k = helper_->MakeNode("ReduceMin", {ensemble_value->output(0)}); + AddAttribute(new_top_k, "axes", std::vector(1, 0)); + } AddAttribute(new_top_k, "keepdims", int64_t(1)); // the output is topk_scores, topk_score_indices diff --git a/paddle2onnx/mapper/detection/multiclass_nms.h b/paddle2onnx/mapper/detection/multiclass_nms.h index 692275504..ffab0458d 100644 --- a/paddle2onnx/mapper/detection/multiclass_nms.h +++ b/paddle2onnx/mapper/detection/multiclass_nms.h @@ -43,6 +43,27 @@ class NMSMapper : public Mapper { GetAttr("keep_top_k", &keep_top_k_); } + NMSMapper(const PaddlePirParser& p, OnnxHelper* helper, int64_t op_id, + bool if_in_cf_block) + : Mapper(p, helper, op_id, if_in_cf_block) { + // NMS is a post process operators for object detection + // We have found there're difference between `multi_class_nms3` in + // PaddlePaddle and `NonMaxSuppresion` in ONNX + MarkAsExperimentalOp(); + GetAttr("normalized", &normalized_); + GetAttr("nms_threshold", &nms_threshold_); + GetAttr("score_threshold", &score_threshold_); + GetAttr("nms_eta", &nms_eta_); + // The `nms_top_k` in Paddle and `max_output_boxes_per_class` in ONNX share + // the same meaning But the filter process may not be same Since NMS is just + // a post process for Detection, we are not going to export it with exactly + // same result. We will make a precision performance in COCO or Pascal VOC + // data later. + GetAttr("nms_top_k", &nms_top_k_); + GetAttr("background_label", &background_label_); + GetAttr("keep_top_k", &keep_top_k_); + } + int32_t GetMinOpsetVersion(bool verbose) override; void KeepTopK(const std::string& selected_indices); void Opset10() override; diff --git a/paddle2onnx/mapper/detection/yolo_box.cc b/paddle2onnx/mapper/detection/yolo_box.cc index 88733b919..1d81380a6 100644 --- a/paddle2onnx/mapper/detection/yolo_box.cc +++ b/paddle2onnx/mapper/detection/yolo_box.cc @@ -17,6 +17,7 @@ namespace paddle2onnx { REGISTER_MAPPER(yolo_box, YoloBoxMapper) +REGISTER_PIR_MAPPER(yolo_box, YoloBoxMapper) int32_t YoloBoxMapper::GetMinOpsetVersion(bool verbose) { Logger(verbose, 11) << RequireOpset(11) << std::endl; @@ -29,8 +30,8 @@ void YoloBoxMapper::Opset11() { // handle the float64 input auto x_info = x_info_ori; if (x_info_ori[0].dtype != P2ODataType::FP32) { - x_info[0].name = helper_->AutoCast(x_info_ori[0].name, x_info_ori[0].dtype, - P2ODataType::FP32); + x_info[0].name = helper_->AutoCast( + x_info_ori[0].name, x_info_ori[0].dtype, P2ODataType::FP32); x_info[0].dtype = P2ODataType::FP32; } @@ -58,7 +59,9 @@ void YoloBoxMapper::Opset11() { // ends This is a standared definition in ONNX However not sure all the // inference engines implements `Slice` this way Let's handle this issue // later - x_name = helper_->Slice(x_name, {0, 1, 2, 3}, {0, 0, 0, 0}, + x_name = helper_->Slice(x_name, + {0, 1, 2, 3}, + {0, 0, 0, 0}, {max_int, anchor_num, max_int, max_int}); } @@ -76,10 +79,10 @@ void YoloBoxMapper::Opset11() { // grid_x = np.tile(np.arange(w).reshape((1, w)), (h, 1)) // grid_y = np.tile(np.arange(h).reshape((h, 1)), (1, w)) - auto float_value_0 = - helper_->Constant({}, GetOnnxDtype(x_info[0].dtype), float(0.0)); - auto float_value_1 = - helper_->Constant({}, GetOnnxDtype(x_info[0].dtype), float(1.0)); + auto float_value_0 = helper_->Constant( + {}, GetOnnxDtype(x_info[0].dtype), static_cast(0.0)); + auto float_value_1 = helper_->Constant( + {}, GetOnnxDtype(x_info[0].dtype), static_cast(1.0)); auto scalar_float_w = helper_->Squeeze(float_w, {}); auto scalar_float_h = helper_->Squeeze(float_h, {}); auto grid_x_0 = helper_->MakeNode( @@ -90,8 +93,8 @@ void YoloBoxMapper::Opset11() { "Tile", {grid_x_0->output(0), nchw[2]}); // shape is [w*h] auto grid_y_1 = helper_->MakeNode( "Tile", {grid_y_0->output(0), nchw[3]}); // shape is [h*w] - auto int_value_1 = - helper_->Constant({1}, ONNX_NAMESPACE::TensorProto::INT64, float(1.0)); + auto int_value_1 = helper_->Constant( + {1}, ONNX_NAMESPACE::TensorProto::INT64, static_cast(1.0)); auto grid_shape_x = helper_->MakeNode("Concat", {nchw[2], nchw[3], int_value_1}); auto grid_shape_y = @@ -115,9 +118,10 @@ void YoloBoxMapper::Opset11() { // pred_box[:, :, :, :, 0] = (grid_x + sigmoid(pred_box[:, :, :, :, 0]) * // scale_x_y + bias_x_y) / w pred_box[:, :, :, :, 1] = (grid_y + // sigmoid(pred_box[:, :, :, :, 1]) * scale_x_y + bias_x_y) / h - auto pred_box_xy = - helper_->Slice(transposed_x->output(0), {0, 1, 2, 3, 4}, {0, 0, 0, 0, 0}, - {max_int, max_int, max_int, max_int, 2}); + auto pred_box_xy = helper_->Slice(transposed_x->output(0), + {0, 1, 2, 3, 4}, + {0, 0, 0, 0, 0}, + {max_int, max_int, max_int, max_int, 2}); auto scale_x_y = helper_->Constant({1}, GetOnnxDtype(x_info[0].dtype), scale_x_y_); float bias_x_y_value = (1.0 - scale_x_y_) / 2.0; @@ -157,9 +161,10 @@ void YoloBoxMapper::Opset11() { // anchor_w pred_box[:, :, :, :, 3] = np.exp(pred_box[:, :, :, :, 3]) * // anchor_h anchors = helper_->Reshape(anchors, {1, anchor_num, 1, 1, 2}); - auto pred_box_wh = - helper_->Slice(transposed_x->output(0), {0, 1, 2, 3, 4}, {0, 0, 0, 0, 2}, - {max_int, max_int, max_int, max_int, 4}); + auto pred_box_wh = helper_->Slice(transposed_x->output(0), + {0, 1, 2, 3, 4}, + {0, 0, 0, 0, 2}, + {max_int, max_int, max_int, max_int, 4}); pred_box_wh = helper_->MakeNode("Exp", {pred_box_wh})->output(0); pred_box_wh = helper_->MakeNode("Mul", {pred_box_wh, anchors})->output(0); @@ -168,20 +173,23 @@ void YoloBoxMapper::Opset11() { // 1 - iou_aware_factor) * sigmoid(ioup)**iou_aware_factor // else: // pred_conf = sigmoid(x[:, :, :, :, 4:5]) - auto confidence = - helper_->Slice(transposed_x->output(0), {0, 1, 2, 3, 4}, {0, 0, 0, 0, 4}, - {max_int, max_int, max_int, max_int, 5}); + auto confidence = helper_->Slice(transposed_x->output(0), + {0, 1, 2, 3, 4}, + {0, 0, 0, 0, 4}, + {max_int, max_int, max_int, max_int, 5}); std::string pred_conf = helper_->MakeNode("Sigmoid", {confidence})->output(0); if (iou_aware_) { - auto ioup = helper_->Slice(x_info[0].name, {0, 1, 2, 3}, {0, 0, 0, 0}, + auto ioup = helper_->Slice(x_info[0].name, + {0, 1, 2, 3}, + {0, 0, 0, 0}, {max_int, anchor_num, max_int, max_int}); ioup = helper_->Unsqueeze(ioup, {4}); ioup = helper_->MakeNode("Sigmoid", {ioup})->output(0); float power_value_0 = 1 - iou_aware_factor_; auto power_0 = helper_->Constant({1}, GetOnnxDtype(x_info[0].dtype), power_value_0); - auto power_1 = helper_->Constant({1}, GetOnnxDtype(x_info[0].dtype), - iou_aware_factor_); + auto power_1 = helper_->Constant( + {1}, GetOnnxDtype(x_info[0].dtype), iou_aware_factor_); ioup = helper_->MakeNode("Pow", {ioup, power_1})->output(0); pred_conf = helper_->MakeNode("Pow", {pred_conf, power_0})->output(0); pred_conf = helper_->MakeNode("Mul", {pred_conf, ioup})->output(0); @@ -190,8 +198,8 @@ void YoloBoxMapper::Opset11() { // pred_conf[pred_conf < conf_thresh] = 0. // pred_score = sigmoid(x[:, :, :, :, 5:]) * pred_conf // pred_box = pred_box * (pred_conf > 0.).astype('float32') - auto value_2 = - helper_->Constant({1}, GetOnnxDtype(x_info[0].dtype), float(2.0)); + auto value_2 = helper_->Constant( + {1}, GetOnnxDtype(x_info[0].dtype), static_cast(2.0)); auto center = helper_->MakeNode("Div", {pred_box_wh, value_2})->output(0); auto min_xy = helper_->MakeNode("Sub", {pred_box_xy, center})->output(0); auto max_xy = helper_->MakeNode("Add", {pred_box_xy, center})->output(0); @@ -203,7 +211,9 @@ void YoloBoxMapper::Opset11() { filter = helper_->AutoCast(filter, P2ODataType::BOOL, x_info[0].dtype); pred_conf = helper_->MakeNode("Mul", {pred_conf, filter})->output(0); auto pred_score = - helper_->Slice(transposed_x->output(0), {0, 1, 2, 3, 4}, {0, 0, 0, 0, 5}, + helper_->Slice(transposed_x->output(0), + {0, 1, 2, 3, 4}, + {0, 0, 0, 0, 5}, {max_int, max_int, max_int, max_int, max_int}); pred_score = helper_->MakeNode("Sigmoid", {pred_score})->output(0); pred_score = helper_->MakeNode("Mul", {pred_score, pred_conf})->output(0); @@ -226,8 +236,8 @@ void YoloBoxMapper::Opset11() { if (!clip_bbox_) { auto out = helper_->MakeNode("Mul", {pred_box, im_whwh})->output(0); - helper_->AutoCast(out, boxes_info[0].name, x_info[0].dtype, - boxes_info[0].dtype); + helper_->AutoCast( + out, boxes_info[0].name, x_info[0].dtype, boxes_info[0].dtype); } else { pred_box = helper_->MakeNode("Mul", {pred_box, im_whwh})->output(0); auto im_wh = helper_->Concat({split_im_hw[1], split_im_hw[0]}, 2); @@ -238,8 +248,8 @@ void YoloBoxMapper::Opset11() { pred_box_xymin_xymax[1] = helper_->MakeNode("Min", {pred_box_xymin_xymax[1], im_wh})->output(0); auto out = helper_->Concat(pred_box_xymin_xymax, 2); - helper_->AutoCast(out, boxes_info[0].name, x_info[0].dtype, - boxes_info[0].dtype); + helper_->AutoCast( + out, boxes_info[0].name, x_info[0].dtype, boxes_info[0].dtype); } auto class_num = @@ -248,7 +258,7 @@ void YoloBoxMapper::Opset11() { helper_->Concat({nchw[0], value_neg_1, class_num}, int64_t(0)); auto score_out = helper_->MakeNode("Reshape", {pred_score, score_out_shape})->output(0); - helper_->AutoCast(score_out, scores_info[0].name, x_info[0].dtype, - scores_info[0].dtype); + helper_->AutoCast( + score_out, scores_info[0].name, x_info[0].dtype, scores_info[0].dtype); } } // namespace paddle2onnx diff --git a/paddle2onnx/mapper/detection/yolo_box.h b/paddle2onnx/mapper/detection/yolo_box.h index 79781814a..383a8c4fc 100644 --- a/paddle2onnx/mapper/detection/yolo_box.h +++ b/paddle2onnx/mapper/detection/yolo_box.h @@ -36,6 +36,20 @@ class YoloBoxMapper : public Mapper { GetAttr("anchors", &anchors_); } + YoloBoxMapper(const PaddlePirParser& p, OnnxHelper* helper, int64_t op_id, + bool if_in_cf_block) + : Mapper(p, helper, op_id, if_in_cf_block) { + MarkAsExperimentalOp(); + GetAttr("clip_bbox", &clip_bbox_); + GetAttr("iou_aware", &iou_aware_); + GetAttr("conf_thresh", &conf_thresh_); + GetAttr("iou_aware_factor", &iou_aware_factor_); + GetAttr("class_num", &class_num_); + GetAttr("downsample_ratio", &downsample_ratio_); + GetAttr("scale_x_y", &scale_x_y_); + GetAttr("anchors", &anchors_); + } + int32_t GetMinOpsetVersion(bool verbose) override; void Opset11() override; diff --git a/paddle2onnx/mapper/exporter.cc b/paddle2onnx/mapper/exporter.cc index e9aab7e58..179246400 100644 --- a/paddle2onnx/mapper/exporter.cc +++ b/paddle2onnx/mapper/exporter.cc @@ -203,9 +203,9 @@ void ModelExporter::SetOpsetVersion(const PaddlePirParser& pir_parser, bool opset_is_legal = true; // here int32_t min_opset = GetMinOpsetVersion(pir_parser); - if (min_opset < 7 || min_opset >= MAX_ONNX_OPSET_VERSION) { + if (min_opset < 7 || min_opset > MAX_ONNX_OPSET_VERSION) { P2OLogger(verbose_) << "The Opset Version must be between 7 and " - << MAX_ONNX_OPSET_VERSION - 1 << std::endl; + << MAX_ONNX_OPSET_VERSION << std::endl; opset_is_legal = false; } if (!auto_upgrade_opset) { @@ -387,11 +387,18 @@ ONNX_NAMESPACE::GraphProto ModelExporter::ExportIfBlock( // get cf.yeild op input pir::Operation* cf_yield_op = pir_parser.sub_blocks_ops.back(); // std::vector sub_block_outpus; - for (auto oprand : cf_yield_op->operands()) { - pir::Value value = oprand.source(); + for(int32_t idx = 0; idx < cf_yield_op->num_operands(); ++idx) { + pir::Value value = cf_yield_op->operand(idx).source(); auto cond_info = pir_parser.GetSubBlockValueTensorInfo(value); // sub_block_outpus.push_back(cond_info[0].name); temp_outputs.push_back(std::move(MakeValueInfo(cond_info[0]))); + if (value.defining_op() == nullptr) { + value = + pir::Value(pir_parser.while_op_input_value_map[&(*(value.impl()))]); + } + if(value.defining_op()->GetParent() != &block) { + temp_inputs.push_back(std::move(MakeValueInfo(cond_info[0]))); + } } } else { // sub_blocks_ops is empty @@ -524,6 +531,20 @@ ONNX_NAMESPACE::GraphProto ModelExporter::ExportBlock( if_in_subblock, verbose_); } + if(if_in_subblock && !is_while_block) { + for (auto& input_item : inputs) { + for(int32_t idx = 0; idx < outputs.size(); ++idx) { + auto output_item = outputs[idx]; + if(output_item->name() == input_item->name()) { + output_item->set_name(pir_parser.GenOpInputOutputName("yeild")); + temp_helper.MakeNode("Identity", {input_item->name()}, + {output_item->name()}); + outputs[idx] = std::move(output_item); + } + } + } + inputs.clear(); + } for (auto& item : parameters) { *(graph.add_node()) = *(item.get()); } diff --git a/paddle2onnx/mapper/mapper.h b/paddle2onnx/mapper/mapper.h index 6fc3ddcbb..d24f24492 100644 --- a/paddle2onnx/mapper/mapper.h +++ b/paddle2onnx/mapper/mapper.h @@ -77,24 +77,38 @@ class Mapper { } P2OLogger Error() { - auto& op = parser_->GetOpDesc(block_idx_, op_idx_); std::string output_name = ""; - if (op.outputs(0).arguments_size() > 0) { - output_name = op.outputs(0).arguments(0); + std::string op_type = ""; + if (in_pir_mode) { + auto& op = if_in_cf_block ? pir_parser_->sub_blocks_ops[pir_op_idx_] + : pir_parser_->global_blocks_ops[pir_op_idx_]; + op_type = op->name(); + } else { + auto& op = parser_->GetOpDesc(block_idx_, op_idx_); + if (op.outputs(0).arguments_size() > 0) { + output_name = op.outputs(0).arguments(0); + } + op_type = op.type(); } - std::string op_type = op.type(); std::string prefix = "[ERROR][Paddle2ONNX] [" + op_type + ": " + output_name + "]"; return P2OLogger(true, prefix); } P2OLogger Warn() { - auto& op = parser_->GetOpDesc(block_idx_, op_idx_); std::string output_name = ""; - if (op.outputs(0).arguments_size() > 0) { - output_name = op.outputs(0).arguments(0); + std::string op_type = ""; + if (in_pir_mode) { + auto& op = if_in_cf_block ? pir_parser_->sub_blocks_ops[pir_op_idx_] + : pir_parser_->global_blocks_ops[pir_op_idx_]; + op_type = op->name(); + } else { + auto& op = parser_->GetOpDesc(block_idx_, op_idx_); + if (op.outputs(0).arguments_size() > 0) { + output_name = op.outputs(0).arguments(0); + } + op_type = op.type(); } - std::string op_type = op.type(); std::string prefix = "[WARN][Paddle2ONNX] [" + op_type + ": " + output_name + "]"; return P2OLogger(true, prefix); @@ -116,7 +130,15 @@ class Mapper { "[Paddle2ONNX] Only support opset_version in range of [7, " + std::to_string(MAX_ONNX_OPSET_VERSION) + "]."); - if (opset_version == 19) { + if (opset_version == 23) { + Opset23(); + } else if (opset_version == 22) { + Opset22(); + } else if (opset_version == 21) { + Opset21(); + } else if (opset_version == 20) { + Opset20(); + } else if (opset_version == 19) { Opset19(); } else if (opset_version == 18) { Opset18(); @@ -145,6 +167,10 @@ class Mapper { } } + virtual void Opset23() { Opset22(); } + virtual void Opset22() { Opset21(); } + virtual void Opset21() { Opset20(); } + virtual void Opset20() { Opset19(); } virtual void Opset19() { Opset18(); } virtual void Opset18() { Opset17(); } virtual void Opset17() { Opset16(); } @@ -426,7 +452,7 @@ class Mapper { pir_op_idx_, pir_parser_->GetOpInputOutputName2Idx( pir_op_idx_, input_key, true, if_in_cf_block), - data); + data, if_in_cf_block); } else { auto input_info = GetInput(input_key); return parser_->TryGetTensorValue(block_idx_, input_info[0].name, data); @@ -440,7 +466,7 @@ class Mapper { pir_op_idx_, pir_parser_->GetOpInputOutputName2Idx( pir_op_idx_, input_key, true, if_in_cf_block), - data); + data, if_in_cf_block); } else { Assert(false, "Not support in old IR."); } diff --git a/paddle2onnx/mapper/nn/batch_norm.cc b/paddle2onnx/mapper/nn/batch_norm.cc index a40e88635..3235b30d6 100644 --- a/paddle2onnx/mapper/nn/batch_norm.cc +++ b/paddle2onnx/mapper/nn/batch_norm.cc @@ -23,17 +23,38 @@ REGISTER_PIR_MAPPER(batch_norm, BatchNormMapper) void BatchNormMapper::Opset7() { auto input_info = GetInput("X"); - auto scale_info = GetInput("Scale"); - auto bias_info = GetInput("Bias"); - auto mean_info = GetInput("Mean"); + auto mean_info = GetInput("Mean"); auto variance_info = GetInput("Variance"); auto output_info = GetOutput("Y"); - auto node = helper_->MakeNode( - "BatchNormalization", - {input_info[0].name, scale_info[0].name, bias_info[0].name, - mean_info[0].name, variance_info[0].name}, - {output_info[0].name}); + std::string scale_name, bias_name; + int64_t numel = 1; + for (auto s : mean_info[0].shape) { + numel *= s; + } + if (HasInput("Scale")) { + scale_name = GetInput("Scale")[0].name; + } else { + std::vector values(numel, 1); + scale_name = helper_->Constant( + mean_info[0].shape, GetOnnxDtype(mean_info[0].dtype), values); + } + + if (HasInput("Bias")) { + bias_name = GetInput("Bias")[0].name; + } else { + std::vector values(numel, 0); + bias_name = helper_->Constant( + mean_info[0].shape, GetOnnxDtype(mean_info[0].dtype), values); + } + + auto node = helper_->MakeNode("BatchNormalization", + {input_info[0].name, + scale_name, + bias_name, + mean_info[0].name, + variance_info[0].name}, + {output_info[0].name}); if (helper_->GetOpsetVersion() < 9) { int64_t spatial = 1; AddAttribute(node, "spatial", spatial); diff --git a/paddle2onnx/mapper/nn/deform_conv2d.cc b/paddle2onnx/mapper/nn/deform_conv2d.cc index 4c340a660..484720e1a 100644 --- a/paddle2onnx/mapper/nn/deform_conv2d.cc +++ b/paddle2onnx/mapper/nn/deform_conv2d.cc @@ -19,8 +19,10 @@ namespace paddle2onnx { REGISTER_MAPPER(deformable_conv, DeformConv2dMapper) +REGISTER_PIR_MAPPER(deformable_conv, DeformConv2dMapper) int32_t DeformConv2dMapper::GetMinOpsetVersion(bool verbose) { + Logger(verbose, 19) << RequireOpset(19) << std::endl; return 19; } @@ -30,11 +32,16 @@ void DeformConv2dMapper::Opset19() { auto offset_info = GetInput("Offset"); auto mask_info = GetInput("Mask"); auto output_info = GetOutput("Output"); - std::string bias_name = helper_->Constant({kernel_info[0].shape[0]}, GetOnnxDtype(input_info[0].dtype), static_cast(0.0)); - auto node = helper_->MakeNode( - "DeformConv", - {input_info[0].name, kernel_info[0].name, offset_info[0].name, bias_name, mask_info[0].name}, - {output_info[0].name}); + std::string bias_name = helper_->Constant({kernel_info[0].shape[0]}, + GetOnnxDtype(input_info[0].dtype), + static_cast(0.0)); + auto node = helper_->MakeNode("DeformConv", + {input_info[0].name, + kernel_info[0].name, + offset_info[0].name, + bias_name, + mask_info[0].name}, + {output_info[0].name}); AddAttribute(node, "dilations", dilations_); AddAttribute(node, "group", groups_); @@ -55,4 +62,4 @@ void DeformConv2dMapper::Opset19() { AddAttribute(node, "pads", paddings); AddAttribute(node, "strides", strides_); } -} // namespace paddle2onnx +} // namespace paddle2onnx diff --git a/paddle2onnx/mapper/nn/deform_conv2d.h b/paddle2onnx/mapper/nn/deform_conv2d.h index 468b704a6..277b0afca 100644 --- a/paddle2onnx/mapper/nn/deform_conv2d.h +++ b/paddle2onnx/mapper/nn/deform_conv2d.h @@ -18,34 +18,46 @@ #include "paddle2onnx/mapper/mapper.h" -namespace paddle2onnx -{ - - class DeformConv2dMapper : public Mapper - { - public: - DeformConv2dMapper(const PaddleParser &p, OnnxHelper *helper, int64_t block_id, - int64_t op_id) - : Mapper(p, helper, block_id, op_id) - { - GetAttr("deformable_groups", &deformable_groups_); - GetAttr("strides", &strides_); - GetAttr("paddings", &paddings_); - GetAttr("dilations", &dilations_); - GetAttr("groups", &groups_); - GetAttr("im2col_step", &im2col_step_); - } - - int32_t GetMinOpsetVersion(bool verbose) override; - void Opset19() override; - - private: - std::vector strides_; - std::vector paddings_; - std::vector dilations_; - int64_t deformable_groups_; - int64_t groups_; - int64_t im2col_step_; - }; - -} // namespace paddle2onnx +namespace paddle2onnx { + +class DeformConv2dMapper : public Mapper { + public: + DeformConv2dMapper(const PaddleParser &p, + OnnxHelper *helper, + int64_t block_id, + int64_t op_id) + : Mapper(p, helper, block_id, op_id) { + GetAttr("deformable_groups", &deformable_groups_); + GetAttr("strides", &strides_); + GetAttr("paddings", &paddings_); + GetAttr("dilations", &dilations_); + GetAttr("groups", &groups_); + GetAttr("im2col_step", &im2col_step_); + } + + DeformConv2dMapper(const PaddlePirParser &p, + OnnxHelper *helper, + int64_t op_id, + bool if_in_cf_block) + : Mapper(p, helper, op_id, if_in_cf_block) { + GetAttr("deformable_groups", &deformable_groups_); + GetAttr("strides", &strides_); + GetAttr("paddings", &paddings_); + GetAttr("dilations", &dilations_); + GetAttr("groups", &groups_); + GetAttr("im2col_step", &im2col_step_); + } + + int32_t GetMinOpsetVersion(bool verbose) override; + void Opset19() override; + + private: + std::vector strides_; + std::vector paddings_; + std::vector dilations_; + int64_t deformable_groups_; + int64_t groups_; + int64_t im2col_step_; +}; + +} // namespace paddle2onnx diff --git a/paddle2onnx/mapper/nn/pad3d.cc b/paddle2onnx/mapper/nn/pad3d.cc index f88c6c2a6..eea10e847 100644 --- a/paddle2onnx/mapper/nn/pad3d.cc +++ b/paddle2onnx/mapper/nn/pad3d.cc @@ -23,16 +23,12 @@ int32_t Pad3DMapper::GetMinOpsetVersion(bool verbose) { Error() << "NDHWC format is not supported." << std::endl; return -1; } - if (mode_ == "circular") { - Error() << "Padding mode `circular` is not supported." << std::endl; - return -1; - } if (HasInput("Paddings")) { if (!IsConstantInput("Paddings")) { Logger(verbose, 11) << "While Paddings is input and it's not a constant tensor, " << RequireOpset(11) << std::endl; - + if (mode_ == "circular") return 19; return 11; } std::vector paddings; @@ -40,6 +36,7 @@ int32_t Pad3DMapper::GetMinOpsetVersion(bool verbose) { Logger(verbose, 11) << "Cannot get constant value from input of Paddings, " << RequireOpset(11) << std::endl; + if (mode_ == "circular") return 19; return 11; } else { if (paddings.size() != 6) { @@ -48,6 +45,7 @@ int32_t Pad3DMapper::GetMinOpsetVersion(bool verbose) { << paddings.size() << std::endl; return -1; } + if (mode_ == "circular") return 19; } } else { if (paddings_.size() != 6) { @@ -106,7 +104,50 @@ void Pad3DMapper::Opset11() { std::string paddings = ""; if (HasInput("Paddings")) { std::vector paddings_value; - if (TryGetInputValue("Paddings", &paddings_value)) { + if (!in_pir_mode && TryGetInputValue("Paddings", &paddings_value)) { + std::vector new_paddings = + ConvertPaddingParameter(paddings_value); + paddings = helper_->Constant(ONNX_NAMESPACE::TensorProto::INT64, + new_paddings); + } else { + auto pad_info = GetInput("Paddings"); + auto cast_pad = helper_->AutoCast(pad_info[0].name, + pad_info[0].dtype, + P2ODataType::INT64); + auto split_pads = helper_->Split(cast_pad, std::vector(6, 1), 0); + auto zero = helper_->Constant({1}, + ONNX_NAMESPACE::TensorProto::INT64, + int64_t(0)); + paddings = helper_->Concat({zero, zero, split_pads[4], split_pads[2], + split_pads[0], zero, zero, split_pads[5], + split_pads[3], split_pads[1]}, 0); + } + } else { + std::vector new_paddings = ConvertPaddingParameter(paddings_); + paddings = helper_->Constant(ONNX_NAMESPACE::TensorProto::INT64, + new_paddings); + } + auto value = helper_->Constant({}, GetOnnxDtype(input_info[0].dtype), value_); + auto node = helper_->MakeNode("Pad", + {input_info[0].name, paddings, value}, + {output_info[0].name}); + AddAttribute(node, "mode", mode); +} + +void Pad3DMapper::Opset19() { + auto input_info = GetInput("X"); + auto output_info = GetOutput("Out"); + auto mode = mode_; + if (mode == "replicate") { + mode = "edge"; + } else if (mode == "circular") { + mode = "wrap"; + } + + std::string paddings = ""; + if (HasInput("Paddings")) { + std::vector paddings_value; + if (!in_pir_mode && TryGetInputValue("Paddings", &paddings_value)) { std::vector new_paddings = ConvertPaddingParameter(paddings_value); paddings = helper_->Constant(ONNX_NAMESPACE::TensorProto::INT64, diff --git a/paddle2onnx/mapper/nn/pad3d.h b/paddle2onnx/mapper/nn/pad3d.h index 0750f81c9..f4f78114f 100644 --- a/paddle2onnx/mapper/nn/pad3d.h +++ b/paddle2onnx/mapper/nn/pad3d.h @@ -43,6 +43,7 @@ class Pad3DMapper : public Mapper { int32_t GetMinOpsetVersion(bool verbose) override; void Opset7() override; void Opset11() override; + void Opset19() override; private: std::vector diff --git a/paddle2onnx/mapper/nn/pool2d.cc b/paddle2onnx/mapper/nn/pool2d.cc index 85ad45e5b..5ab6accd9 100755 --- a/paddle2onnx/mapper/nn/pool2d.cc +++ b/paddle2onnx/mapper/nn/pool2d.cc @@ -59,16 +59,17 @@ void Pool2dMapper::AdaptivePool(const std::vector& input_info, onnx_pool_type = iter->second[0]; std::shared_ptr node(nullptr); - if (kNoNeedCastTypesOpSet7.find(input_info[0].dtype) != kNoNeedCastTypesOpSet7.end()) - { - node = helper_->MakeNode(onnx_pool_type, {input_info[0].name}, {output_info[0].name}); - } - else - { - auto input = helper_->AutoCast(input_info[0].name, input_info[0].dtype, - P2ODataType::FP32); + if (kNoNeedCastTypesOpSet7.find(input_info[0].dtype) != + kNoNeedCastTypesOpSet7.end()) { + node = helper_->MakeNode( + onnx_pool_type, {input_info[0].name}, {output_info[0].name}); + } else { + auto input = helper_->AutoCast( + input_info[0].name, input_info[0].dtype, P2ODataType::FP32); node = helper_->MakeNode(onnx_pool_type, {input}); - helper_->AutoCast(node->output(0), output_info[0].name, P2ODataType::FP32, + helper_->AutoCast(node->output(0), + output_info[0].name, + P2ODataType::FP32, output_info[0].dtype); } @@ -118,14 +119,14 @@ void Pool2dMapper::NoAdaptivePool(const std::vector& input_info, int64_t max_ksize = *std::max_element(std::begin(k_size_), std::end(k_size_)); int64_t max_pads = *std::max_element(std::begin(pads_), std::end(pads_)); std::string input_x = input_info[0].name; - if (kNoNeedCastTypesOpSet7.find(input_info[0].dtype) == kNoNeedCastTypesOpSet7.end()) - { - input_x = helper_->AutoCast(input_info[0].name, input_info[0].dtype, - P2ODataType::FP32); + if (kNoNeedCastTypesOpSet7.find(input_info[0].dtype) == + kNoNeedCastTypesOpSet7.end()) { + input_x = helper_->AutoCast( + input_info[0].name, input_info[0].dtype, P2ODataType::FP32); } if (max_ksize <= max_pads) { - std::vector onnx_paddings = {0, 0, pads_[0], pads_[1], - 0, 0, pads_[2], pads_[3]}; + std::vector onnx_paddings = { + 0, 0, pads_[0], pads_[1], 0, 0, pads_[2], pads_[3]}; std::vector inputs_names = {input_x}; if (helper_->GetOpsetVersion() >= 11) { std::string paddings_node = @@ -159,14 +160,14 @@ void Pool2dMapper::NoAdaptivePool(const std::vector& input_info, Assert(iter != op_mapper_.end(), "Pooling not found"); onnx_pool_type = iter->second[0]; std::shared_ptr node(nullptr); - if (kNoNeedCastTypesOpSet7.find(input_info[0].dtype) != kNoNeedCastTypesOpSet7.end()) - { + if (kNoNeedCastTypesOpSet7.find(input_info[0].dtype) != + kNoNeedCastTypesOpSet7.end()) { node = helper_->MakeNode(onnx_pool_type, {input_x}, {output_info[0].name}); - } - else - { + } else { node = helper_->MakeNode(onnx_pool_type, {input_x}); - helper_->AutoCast(node->output(0), output_info[0].name, P2ODataType::FP32, + helper_->AutoCast(node->output(0), + output_info[0].name, + P2ODataType::FP32, output_info[0].dtype); } @@ -182,12 +183,14 @@ void Pool2dMapper::NoAdaptivePool(const std::vector& input_info, } else { AddAttribute(node, "pads", pads_); } - // TODO: Need double check - // if (OpType() != "max_pool2d_with_index" && helper_->GetOpsetVersion() >= 10) { + // TODO(qinzhongyu): Need double check + // if (OpType() != "max_pool2d_with_index" && helper_->GetOpsetVersion() >= + // 10) { // AddAttribute(node, "ceil_mode", static_cast(ceil_mode_)); // } // if (OpType() != "max_pool2d_with_index" && pooling_type_ == "avg") { - // AddAttribute(node, "count_include_pad", static_cast(exclusive_)); + // AddAttribute(node, "count_include_pad", + // static_cast(exclusive_)); // } if (helper_->GetOpsetVersion() >= 10) { AddAttribute(node, "ceil_mode", static_cast(ceil_mode_)); @@ -206,15 +209,17 @@ int32_t Pool2dMapper::GetMinOpsetVersion(bool verbose) { auto input_info = GetInput("X"); auto output_info = GetOutput("Out"); if (in_pir_mode) { - // TODO: For PIR, kernel size is in inputs + // TODO(qinzhongyu): For PIR, kernel size is in inputs auto ksize = GetInput("ksize")[0]; - for (auto i = 0; i < ksize.shape.size(); ++ i) { - k_size_.push_back(ksize.shape[i]); - } + Assert(IsConstantInput("ksize"), "ksize's type is not constant."); + // for (auto i = 0; i < ksize.shape.size(); ++ i) { + // k_size_.push_back(ksize.shape[i]); + // } + TryGetInputValue("ksize", &k_size_); } else { if (IsAttrVar("ksize")) { Error() << "While Attribute(ksize)'s type is Tensor, it's not " - "supported." + "supported." << std::endl; return -1; } else { @@ -244,7 +249,8 @@ int32_t Pool2dMapper::GetMinOpsetVersion(bool verbose) { int64_t input_w = input_info[0].shape[3]; int64_t output_h = output_info[0].shape[2]; int64_t output_w = output_info[0].shape[3]; - if (output_h == -1 || output_w == -1 || !IsSameSpan(input_h, output_h) || !IsSameSpan(input_w, output_w)) { + if (output_h == -1 || output_w == -1 || !IsSameSpan(input_h, output_h) || + !IsSameSpan(input_w, output_w)) { Error() << "Cannot convert adaptive pool with input_size: " << input_h << " " << input_h << " output_size: " << output_h << " " << output_w << std::endl; @@ -282,11 +288,10 @@ void Pool2dMapper::Opset7() { */ // k_size_ = GetInputAttrVar("ksize", "value"); TryGetInputValue("ksize", &k_size_); - } else{ + } else { GetAttr("ksize", &k_size_); } - bool is_1x1_kernel = true; for (auto i : k_size_) { if (i != 1) { @@ -302,17 +307,16 @@ void Pool2dMapper::Opset7() { auto iter = op_mapper_.find(pooling_type_); onnx_pool_type = iter->second[1]; } - if (kNoNeedCastTypesOpSet7.find(input_info[0].dtype) != kNoNeedCastTypesOpSet7.end()) - { - auto output = helper_->MakeNode(onnx_pool_type, {input_info[0].name}, {output_info[0].name}); - } - else - { - auto input = helper_->AutoCast(input_info[0].name, input_info[0].dtype, - P2ODataType::FP32); + if (kNoNeedCastTypesOpSet7.find(input_info[0].dtype) != + kNoNeedCastTypesOpSet7.end()) { + auto output = helper_->MakeNode( + onnx_pool_type, {input_info[0].name}, {output_info[0].name}); + } else { + auto input = helper_->AutoCast( + input_info[0].name, input_info[0].dtype, P2ODataType::FP32); auto output = helper_->MakeNode(onnx_pool_type, {input})->output(0); - helper_->AutoCast(output, output_info[0].name, P2ODataType::FP32, - output_info[0].dtype); + helper_->AutoCast( + output, output_info[0].name, P2ODataType::FP32, output_info[0].dtype); } } else if (adaptive_) { AdaptivePool(input_info, output_info); diff --git a/paddle2onnx/mapper/tensor/argsort.cc b/paddle2onnx/mapper/tensor/argsort.cc index 3d9f2870f..0d0ce8e02 100644 --- a/paddle2onnx/mapper/tensor/argsort.cc +++ b/paddle2onnx/mapper/tensor/argsort.cc @@ -16,6 +16,7 @@ namespace paddle2onnx { REGISTER_MAPPER(argsort, ArgsortMapper) +REGISTER_PIR_MAPPER(argsort, ArgsortMapper) int32_t ArgsortMapper::GetMinOpsetVersion(bool verbose) { if (!descending_) { diff --git a/paddle2onnx/mapper/tensor/argsort.h b/paddle2onnx/mapper/tensor/argsort.h index b8ff34a66..30de46d89 100644 --- a/paddle2onnx/mapper/tensor/argsort.h +++ b/paddle2onnx/mapper/tensor/argsort.h @@ -25,6 +25,13 @@ class ArgsortMapper : public Mapper { GetAttr("descending", &descending_); GetAttr("axis", &axis_); } + + ArgsortMapper(const PaddlePirParser& p, OnnxHelper* helper, int64_t op_id, + bool if_in_cf_block) + : Mapper(p, helper, op_id, if_in_cf_block) { + GetAttr("descending", &descending_); + GetAttr("axis", &axis_); + } int32_t GetMinOpsetVersion(bool verbose) override; void Opset10() override; void Opset7() override; diff --git a/paddle2onnx/mapper/tensor/bitwise.cc b/paddle2onnx/mapper/tensor/bitwise.cc index febe313dc..58551090f 100644 --- a/paddle2onnx/mapper/tensor/bitwise.cc +++ b/paddle2onnx/mapper/tensor/bitwise.cc @@ -20,36 +20,46 @@ REGISTER_MAPPER(bitwise_and, BitWiseMapper) REGISTER_MAPPER(bitwise_not, BitWiseMapper) REGISTER_MAPPER(bitwise_or, BitWiseMapper) REGISTER_MAPPER(bitwise_xor, BitWiseMapper) +REGISTER_PIR_MAPPER(bitwise_and, BitWiseMapper) +REGISTER_PIR_MAPPER(bitwise_not, BitWiseMapper) +REGISTER_PIR_MAPPER(bitwise_or, BitWiseMapper) +REGISTER_PIR_MAPPER(bitwise_xor, BitWiseMapper) int32_t BitWiseMapper::GetMinOpsetVersion(bool verbose) { auto x_info = GetInput("X"); - if(x_info[0].dtype == P2ODataType::BOOL){ + if (x_info[0].dtype == P2ODataType::BOOL) { Logger(verbose, 7) << RequireOpset(7) << std::endl; return 7; } Logger(verbose, 18) << RequireOpset(18) << std::endl; return 18; } -void BitWiseMapper::Opset7() { +void BitWiseMapper::Opset7() { auto x_info = GetInput("X"); auto out_info = GetOutput("Out"); - if (paddle_type_ == "bitwise_not"){ - helper_->MakeNode(onnx_elemwise_type_, {x_info[0].name}, {out_info[0].name}); - } else{ + if (paddle_type_ == "bitwise_not") { + helper_->MakeNode( + onnx_elemwise_type_, {x_info[0].name}, {out_info[0].name}); + } else { auto y_info = GetInput("Y"); - helper_->MakeNode(onnx_elemwise_type_, {x_info[0].name, y_info[0].name}, {out_info[0].name}); + helper_->MakeNode(onnx_elemwise_type_, + {x_info[0].name, y_info[0].name}, + {out_info[0].name}); } } void BitWiseMapper::Opset18() { auto x_info = GetInput("X"); auto out_info = GetOutput("Out"); - std::string node_name = x_info[0].dtype == P2ODataType::BOOL? onnx_elemwise_type_: onnx_bitwise_type_; - if(paddle_type_ == "bitwise_not"){ + std::string node_name = x_info[0].dtype == P2ODataType::BOOL + ? onnx_elemwise_type_ + : onnx_bitwise_type_; + if (paddle_type_ == "bitwise_not") { helper_->MakeNode(node_name, {x_info[0].name}, {out_info[0].name}); - } else{ + } else { auto y_info = GetInput("Y"); - helper_->MakeNode(node_name, {x_info[0].name, y_info[0].name},{out_info[0].name}); + helper_->MakeNode( + node_name, {x_info[0].name, y_info[0].name}, {out_info[0].name}); } } -}// namespace paddle2onnx \ No newline at end of file +} // namespace paddle2onnx diff --git a/paddle2onnx/mapper/tensor/bitwise.h b/paddle2onnx/mapper/tensor/bitwise.h index 81819e4b2..b6f269bb9 100644 --- a/paddle2onnx/mapper/tensor/bitwise.h +++ b/paddle2onnx/mapper/tensor/bitwise.h @@ -13,36 +13,52 @@ // limitations under the License. #pragma once +#include #include #include -#include +#include "paddle2onnx/mapper/exporter.h" #include "paddle2onnx/mapper/mapper.h" -namespace paddle2onnx -{ +namespace paddle2onnx { + +class BitWiseMapper : public Mapper { + public: + BitWiseMapper(const PaddleParser &p, + OnnxHelper *helper, + int64_t block_id, + int64_t op_id) + : Mapper(p, helper, block_id, op_id) { + op_mapper_["bitwise_and"] = "BitwiseAnd"; + op_mapper_["bitwise_not"] = "BitwiseNot"; + op_mapper_["bitwise_or"] = "BitwiseOr"; + op_mapper_["bitwise_xor"] = "BitwiseXor"; + paddle_type_ = OpType(); + onnx_bitwise_type_ = op_mapper_.find(paddle_type_)->second; + onnx_elemwise_type_ = onnx_bitwise_type_.substr(7); + } - class BitWiseMapper : public Mapper { - public: - BitWiseMapper(const PaddleParser &p, OnnxHelper *helper, int64_t block_id, - int64_t op_id) - : Mapper(p, helper, block_id, op_id) { - op_mapper_["bitwise_and"] = "BitwiseAnd"; - op_mapper_["bitwise_not"] = "BitwiseNot"; - op_mapper_["bitwise_or"] = "BitwiseOr"; - op_mapper_["bitwise_xor"] = "BitwiseXor"; - paddle_type_ = OpType(); - onnx_bitwise_type_ = op_mapper_.find(paddle_type_)->second; - onnx_elemwise_type_ = onnx_bitwise_type_.substr(7); - } - int32_t GetMinOpsetVersion(bool verbose) override; - void Opset7() override; - void Opset18() override; + BitWiseMapper(const PaddlePirParser &p, + OnnxHelper *helper, + int64_t op_id, + bool if_in_cf_block) + : Mapper(p, helper, op_id, if_in_cf_block) { + op_mapper_["bitwise_and"] = "BitwiseAnd"; + op_mapper_["bitwise_not"] = "BitwiseNot"; + op_mapper_["bitwise_or"] = "BitwiseOr"; + op_mapper_["bitwise_xor"] = "BitwiseXor"; + paddle_type_ = convert_pir_op_name(OpType()); + onnx_bitwise_type_ = op_mapper_.find(paddle_type_)->second; + onnx_elemwise_type_ = onnx_bitwise_type_.substr(7); + } + int32_t GetMinOpsetVersion(bool verbose) override; + void Opset7() override; + void Opset18() override; - private: - std::map op_mapper_; - std::string onnx_bitwise_type_; - std::string onnx_elemwise_type_; - std::string paddle_type_; - }; + private: + std::map op_mapper_; + std::string onnx_bitwise_type_; + std::string onnx_elemwise_type_; + std::string paddle_type_; +}; -} // namespace paddle2onnx +} // namespace paddle2onnx diff --git a/paddle2onnx/mapper/tensor/builtin_combine.cc b/paddle2onnx/mapper/tensor/builtin_combine.cc index 4101edf23..e3ada212a 100644 --- a/paddle2onnx/mapper/tensor/builtin_combine.cc +++ b/paddle2onnx/mapper/tensor/builtin_combine.cc @@ -21,30 +21,32 @@ namespace paddle2onnx { REGISTER_PIR_MAPPER(builtin_combine, BuiltinCombineMapper) int64_t BuiltinCombineMapper::GetInputNum() { - auto& op = pir_parser_->global_blocks_ops[pir_op_idx_]; - PADDLE_ENFORCE_EQ( - op->isa(), - true, - common::errors::InvalidArgument( - "The operator type must be builtin.combine, but the actual operator type is %s.", - op->name())); - return op->dyn_cast().inputs().size(); + auto& op = if_in_cf_block ? pir_parser_->sub_blocks_ops[pir_op_idx_] + : pir_parser_->global_blocks_ops[pir_op_idx_]; + PADDLE_ENFORCE_EQ(op->isa(), + true, + common::errors::InvalidArgument( + "The operator type must be builtin.combine, but the " + "actual operator type is %s.", + op->name())); + return op->dyn_cast().inputs().size(); } void BuiltinCombineMapper::Opset7() { - auto output_info = GetOutput(0); - int64_t input_num = GetInputNum(); - PADDLE_ENFORCE_EQ( - input_num == output_info.size(), - true, - common::errors::InvalidArgument( - "The number of inputs and outputs must be the same, but the actual " - "input number is %d and output number is %d.", - input_num, output_info.size())); - for(int64_t i = 0; i < input_num; ++i) { - auto input_info = GetInput(i); - helper_->MakeNode("Identity", {input_info[0].name}, {output_info[i].name}); - } + auto output_info = GetOutput(0); + int64_t input_num = GetInputNum(); + PADDLE_ENFORCE_EQ( + input_num == output_info.size(), + true, + common::errors::InvalidArgument( + "The number of inputs and outputs must be the same, but the actual " + "input number is %d and output number is %d.", + input_num, + output_info.size())); + for (int64_t i = 0; i < input_num; ++i) { + auto input_info = GetInput(i); + helper_->MakeNode("Identity", {input_info[0].name}, {output_info[i].name}); + } } } // namespace paddle2onnx diff --git a/paddle2onnx/mapper/tensor/builtin_split.cc b/paddle2onnx/mapper/tensor/builtin_split.cc index 666754c3d..1305dae26 100644 --- a/paddle2onnx/mapper/tensor/builtin_split.cc +++ b/paddle2onnx/mapper/tensor/builtin_split.cc @@ -18,30 +18,43 @@ namespace paddle2onnx { REGISTER_PIR_MAPPER(builtin_split, BuiltinSplitMapper) int64_t BuiltinSplitMapper::GetOutputNum() { - auto& op = pir_parser_->global_blocks_ops[pir_op_idx_]; - PADDLE_ENFORCE_EQ( - op->isa(), - true, - common::errors::InvalidArgument( - "The operator type must be builtin.split, but the actual operator type is %s.", - op->name())); - return op->dyn_cast().outputs().size(); + auto& op = if_in_cf_block ? pir_parser_->sub_blocks_ops[pir_op_idx_] + : pir_parser_->global_blocks_ops[pir_op_idx_]; + return op->dyn_cast().outputs().size(); +} + +bool BuiltinSplitMapper::IsEinsumOut() { + auto& op = if_in_cf_block ? pir_parser_->sub_blocks_ops[pir_op_idx_] + : pir_parser_->global_blocks_ops[pir_op_idx_]; + PADDLE_ENFORCE_EQ(op->isa(), + true, + common::errors::InvalidArgument( + "The operator type must be builtin.split, but the " + "actual operator type is %s.", + op->name())); + if (op->operand_source(0).defining_op()->name() == "pd_op.einsum") { + Warn() << "Skip builtin.split." << std::endl; + return true; + } + return false; } void BuiltinSplitMapper::Opset7() { - auto input_info = GetInput(0); - int64_t output_num = GetOutputNum(); - PADDLE_ENFORCE_EQ( - output_num == input_info.size(), - true, - common::errors::InvalidArgument( - "The number of inputs and outputs must be the same, but the actual " - "input number is %d and output number is %d.", - input_info.size(), output_num)); - for(int64_t i = 0; i < output_num; ++i) { - auto output_info = GetOutput(i); - helper_->MakeNode("Identity", {input_info[i].name}, {output_info[0].name}); - } + if (IsEinsumOut()) return; + auto input_info = GetInput(0); + int64_t output_num = GetOutputNum(); + PADDLE_ENFORCE_EQ( + output_num == input_info.size(), + true, + common::errors::InvalidArgument( + "The number of inputs and outputs must be the same, but the actual " + "input number is %d and output number is %d.", + input_info.size(), + output_num)); + for (int64_t i = 0; i < output_num; ++i) { + auto output_info = GetOutput(i); + helper_->MakeNode("Identity", {input_info[i].name}, {output_info[0].name}); + } } } // namespace paddle2onnx diff --git a/paddle2onnx/mapper/tensor/builtin_split.h b/paddle2onnx/mapper/tensor/builtin_split.h index 6aea704df..3e4ea26f7 100644 --- a/paddle2onnx/mapper/tensor/builtin_split.h +++ b/paddle2onnx/mapper/tensor/builtin_split.h @@ -33,6 +33,7 @@ class BuiltinSplitMapper : public Mapper { private: int64_t GetOutputNum(); + bool IsEinsumOut(); }; } // namespace paddle2onnx diff --git a/paddle2onnx/mapper/tensor/greater_than.h b/paddle2onnx/mapper/tensor/greater_than.h index 50bb5b32a..1b30bbe32 100644 --- a/paddle2onnx/mapper/tensor/greater_than.h +++ b/paddle2onnx/mapper/tensor/greater_than.h @@ -26,7 +26,7 @@ class GreaterThanMapper : public Mapper { OnnxHelper* helper, int64_t op_id, bool c) - : Mapper(p, helper, op_id) { + : Mapper(p, helper, op_id, c) { in_pir_mode = true; } void Opset7() override; diff --git a/paddle2onnx/mapper/tensor/lookup_table.cc b/paddle2onnx/mapper/tensor/lookup_table.cc index 70a158dc6..e17c9c0c5 100644 --- a/paddle2onnx/mapper/tensor/lookup_table.cc +++ b/paddle2onnx/mapper/tensor/lookup_table.cc @@ -21,6 +21,7 @@ namespace paddle2onnx { REGISTER_MAPPER(lookup_table, LookupTableMapper) REGISTER_MAPPER(lookup_table_v2, LookupTableMapper) +REGISTER_PIR_MAPPER(embedding, LookupTableMapper) int32_t LookupTableMapper::GetMinOpsetVersion(bool verbose) { auto input_w_info = GetInput("W"); @@ -67,11 +68,11 @@ void LookupTableMapper::Opset7() { input_shape, GetOnnxDtype(input_w_info[0].dtype), data); auto weight_node = helper_->MakeNode("Mul", {input_w_info[0].name, constant}); - helper_->MakeNode("Gather", {weight_node->output(0), ids_node}, - {output_info[0].name}); + helper_->MakeNode( + "Gather", {weight_node->output(0), ids_node}, {output_info[0].name}); } else { - helper_->MakeNode("Gather", {input_w_info[0].name, ids_node}, - {output_info[0].name}); + helper_->MakeNode( + "Gather", {input_w_info[0].name, ids_node}, {output_info[0].name}); } } @@ -109,8 +110,8 @@ void LookupTableMapper::Opset11() { {1}, ONNX_NAMESPACE::TensorProto::INT64, padding_idx_); auto scatter_node = helper_->MakeNode( "ScatterND", {input_w_info[0].name, index, replace_data}); - helper_->MakeNode("Gather", {scatter_node->output(0), ids_node}, - {output_info[0].name}); + helper_->MakeNode( + "Gather", {scatter_node->output(0), ids_node}, {output_info[0].name}); } else { std::vector data(sum_val, 1); for (auto i = 0; i < interval; i++) { @@ -120,12 +121,12 @@ void LookupTableMapper::Opset11() { input_shape, GetOnnxDtype(input_w_info[0].dtype), data); auto weight_node = helper_->MakeNode("Mul", {input_w_info[0].name, constant}); - helper_->MakeNode("Gather", {weight_node->output(0), ids_node}, - {output_info[0].name}); + helper_->MakeNode( + "Gather", {weight_node->output(0), ids_node}, {output_info[0].name}); } } else { - helper_->MakeNode("Gather", {input_w_info[0].name, ids_node}, - {output_info[0].name}); + helper_->MakeNode( + "Gather", {input_w_info[0].name, ids_node}, {output_info[0].name}); } } diff --git a/paddle2onnx/mapper/tensor/lookup_table.h b/paddle2onnx/mapper/tensor/lookup_table.h index 3b964e55e..60412b04e 100644 --- a/paddle2onnx/mapper/tensor/lookup_table.h +++ b/paddle2onnx/mapper/tensor/lookup_table.h @@ -22,12 +22,21 @@ namespace paddle2onnx { class LookupTableMapper : public Mapper { public: - LookupTableMapper(const PaddleParser& p, OnnxHelper* helper, int64_t block_id, + LookupTableMapper(const PaddleParser& p, + OnnxHelper* helper, + int64_t block_id, int64_t op_id) : Mapper(p, helper, block_id, op_id) { GetAttr("padding_idx", &padding_idx_); } + LookupTableMapper(const PaddlePirParser& p, + OnnxHelper* helper, + int64_t op_id, + bool if_in_cf_block) + : Mapper(p, helper, op_id, if_in_cf_block) { + GetAttr("padding_idx", &padding_idx_); + } int32_t GetMinOpsetVersion(bool verbose) override; void Opset7() override; void Opset11() override; diff --git a/paddle2onnx/mapper/tensor/multinomial.cc b/paddle2onnx/mapper/tensor/multinomial.cc new file mode 100644 index 000000000..b437d044e --- /dev/null +++ b/paddle2onnx/mapper/tensor/multinomial.cc @@ -0,0 +1,42 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle2onnx/mapper/tensor/multinomial.h" +#include +#include +#include + +namespace paddle2onnx { +REGISTER_PIR_MAPPER(multinomial, MultinomialMapper) + +int32_t MultinomialMapper::GetMinOpsetVersion(bool verbose) { + if (!IsConstantInput("num_samples")) { + Error() << "num_samples is not a constant input." << std::endl; + return -1; + } + return 7; +} + +void MultinomialMapper::Opset7() { + auto x_info = GetInput("x"); + auto out_info = GetOutput("out"); + double num_samples = 1; + TryGetInputValue("num_samples", &num_samples); + auto node = + helper_->MakeNode("Multinomial", {x_info[0].name}, {out_info[0].name}); + AddAttribute(node, "dtype", GetOnnxDtype(out_info[0].dtype)); + AddAttribute(node, "sample_size", static_cast(num_samples)); +} + +} // namespace paddle2onnx diff --git a/paddle2onnx/mapper/tensor/multinomial.h b/paddle2onnx/mapper/tensor/multinomial.h new file mode 100644 index 000000000..d95da97ed --- /dev/null +++ b/paddle2onnx/mapper/tensor/multinomial.h @@ -0,0 +1,39 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include + +#include "paddle2onnx/mapper/mapper.h" + +namespace paddle2onnx { + +class MultinomialMapper : public Mapper { + public: + MultinomialMapper(const PaddlePirParser& p, + OnnxHelper* helper, + int64_t op_id, + bool if_in_cf_block) + : Mapper(p, helper, op_id, if_in_cf_block) { + GetAttr("replacement", &replacement_); + } + + int32_t GetMinOpsetVersion(bool verbose) override; + void Opset7() override; + + private: + bool replacement_; +}; +} // namespace paddle2onnx diff --git a/paddle2onnx/mapper/tensor/one_hot_v2.cc b/paddle2onnx/mapper/tensor/one_hot_v2.cc index 29df5665d..8de49926f 100644 --- a/paddle2onnx/mapper/tensor/one_hot_v2.cc +++ b/paddle2onnx/mapper/tensor/one_hot_v2.cc @@ -16,6 +16,7 @@ namespace paddle2onnx { REGISTER_MAPPER(one_hot_v2, OneHotV2Mapper) +REGISTER_PIR_MAPPER(one_hot, OneHotV2Mapper) int32_t OneHotV2Mapper::GetMinOpsetVersion(bool verbose) { if (allow_out_of_range_) { @@ -24,7 +25,7 @@ int32_t OneHotV2Mapper::GetMinOpsetVersion(bool verbose) { return -1; } auto output_info = GetOutput("Out"); - if (output_info[0].dtype != dtype_) { + if (!in_pir_mode && output_info[0].dtype != dtype_) { Error() << "dtype attribute and output dtype do not match." << std::endl; return -1; } @@ -47,6 +48,9 @@ void OneHotV2Mapper::Opset9() { if (HasInput("depth_tensor")) { auto input_depth_info = GetInput("depth_tensor"); depth_node = input_depth_info[0].name; + } else if (HasInput("num_classes")) { + auto input_depth_info = GetInput("num_classes"); + depth_node = input_depth_info[0].name; } else { depth_node = helper_->Constant({1}, GetOnnxDtype(input_info[0].dtype), depth_); diff --git a/paddle2onnx/mapper/tensor/one_hot_v2.h b/paddle2onnx/mapper/tensor/one_hot_v2.h index dad4b9da7..0014fe345 100644 --- a/paddle2onnx/mapper/tensor/one_hot_v2.h +++ b/paddle2onnx/mapper/tensor/one_hot_v2.h @@ -22,13 +22,23 @@ namespace paddle2onnx { class OneHotV2Mapper : public Mapper { public: - OneHotV2Mapper(const PaddleParser& p, OnnxHelper* helper, int64_t block_id, + OneHotV2Mapper(const PaddleParser& p, + OnnxHelper* helper, + int64_t block_id, int64_t op_id) : Mapper(p, helper, block_id, op_id) { GetAttr("allow_out_of_range", &allow_out_of_range_); GetAttr("depth", &depth_); GetAttr("dtype", &dtype_); } + + OneHotV2Mapper(const PaddlePirParser& p, + OnnxHelper* helper, + int64_t op_id, + bool if_in_cf_block) + : Mapper(p, helper, op_id, if_in_cf_block) { + allow_out_of_range_ = false; + } int32_t GetMinOpsetVersion(bool verbose) override; void Opset9() override; diff --git a/paddle2onnx/mapper/tensor/put_along_axis.cc b/paddle2onnx/mapper/tensor/put_along_axis.cc new file mode 100644 index 000000000..d1b962ed2 --- /dev/null +++ b/paddle2onnx/mapper/tensor/put_along_axis.cc @@ -0,0 +1,102 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle2onnx/mapper/tensor/put_along_axis.h" +#include +#include +#include + +namespace paddle2onnx { +REGISTER_PIR_MAPPER(put_along_axis, PutAlongAxisMapper) + +int32_t PutAlongAxisMapper::GetMinOpsetVersion(bool verbose) { + if (reduce_ != "assign" && !include_self_) { + Error() << "Only support artt 'include_self' == 'false' when attr 'reduce' " + "== 'assign'" + << std::endl; + return -1; + } + if (reduce_ == "mean") { + Error() << "Not support artt 'reduce' == 'mean' yet." << std::endl; + return -1; + } + if (reduce_ == "amin" || reduce_ == "amax") { + Logger(verbose, 18) << RequireOpset(18) << std::endl; + return 18; + } + if (reduce_ != "assign") { + Logger(verbose, 16) << RequireOpset(16) << std::endl; + return 16; + } + Logger(verbose, 11) << RequireOpset(11) << std::endl; + return 11; +} + +void PutAlongAxisMapper::Opset11() { + auto arr_info = GetInput("arr"); + auto indices_info = GetInput("indices"); + auto values_info = GetInput("values"); + auto out_info = GetOutput("out"); + auto node = helper_->MakeNode( + "ScatterElements", + {arr_info[0].name, indices_info[0].name, values_info[0].name}, + {out_info[0].name}); + AddAttribute(node, "axis", axis_); +} + +void PutAlongAxisMapper::Opset16() { + auto arr_info = GetInput("arr"); + auto indices_info = GetInput("indices"); + auto values_info = GetInput("values"); + auto out_info = GetOutput("out"); + auto node = helper_->MakeNode( + "ScatterElements", + {arr_info[0].name, indices_info[0].name, values_info[0].name}, + {out_info[0].name}); + std::string onnx_reduction = "none"; + if (reduce_ == "assign") { + onnx_reduction = "none"; + } else if (reduce_ == "add") { + onnx_reduction = "add"; + } else if (reduce_ == "multiply") { + onnx_reduction = "mul"; + } + AddAttribute(node, "axis", axis_); + AddAttribute(node, "reduction", onnx_reduction); +} +void PutAlongAxisMapper::Opset18() { + auto arr_info = GetInput("arr"); + auto indices_info = GetInput("indices"); + auto values_info = GetInput("values"); + auto out_info = GetOutput("out"); + auto node = helper_->MakeNode( + "ScatterElements", + {arr_info[0].name, indices_info[0].name, values_info[0].name}, + {out_info[0].name}); + std::string onnx_reduction = "none"; + if (reduce_ == "assign") { + onnx_reduction = "none"; + } else if (reduce_ == "add") { + onnx_reduction = "add"; + } else if (reduce_ == "multiply") { + onnx_reduction = "mul"; + } else if (reduce_ == "amin") { + onnx_reduction = "min"; + } else if (reduce_ == "amax") { + onnx_reduction = "max"; + } + AddAttribute(node, "axis", axis_); + AddAttribute(node, "reduction", onnx_reduction); +} +} // namespace paddle2onnx diff --git a/paddle2onnx/mapper/tensor/put_along_axis.h b/paddle2onnx/mapper/tensor/put_along_axis.h new file mode 100644 index 000000000..840ed92fb --- /dev/null +++ b/paddle2onnx/mapper/tensor/put_along_axis.h @@ -0,0 +1,45 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include + +#include "paddle2onnx/mapper/mapper.h" + +namespace paddle2onnx { + +class PutAlongAxisMapper : public Mapper { + public: + PutAlongAxisMapper(const PaddlePirParser& p, + OnnxHelper* helper, + int64_t op_id, + bool if_in_cf_block) + : Mapper(p, helper, op_id, if_in_cf_block) { + GetAttr("axis", &axis_); + GetAttr("reduce", &reduce_); + GetAttr("include_self", &include_self_); + } + + int32_t GetMinOpsetVersion(bool verbose) override; + void Opset11() override; + void Opset16() override; + void Opset18() override; + + private: + int64_t axis_; + std::string reduce_; + bool include_self_; +}; +} // namespace paddle2onnx diff --git a/paddle2onnx/mapper/tensor/repeat_interleave.cc b/paddle2onnx/mapper/tensor/repeat_interleave.cc index 6cf4dd31a..2b6d0e70d 100644 --- a/paddle2onnx/mapper/tensor/repeat_interleave.cc +++ b/paddle2onnx/mapper/tensor/repeat_interleave.cc @@ -35,8 +35,6 @@ namespace paddle2onnx { if (in_pir_mode) { if (OpType() == "pd_op.repeat_interleave") { GetAttr("repeats", &repeat); - } else { - TryGetInputValue("repeats", &repeats); } } else { GetAttr("Repeats", &repeat); diff --git a/paddle2onnx/mapper/tensor/repeat_interleave.h b/paddle2onnx/mapper/tensor/repeat_interleave.h index d6a0e047c..12cc68fe2 100644 --- a/paddle2onnx/mapper/tensor/repeat_interleave.h +++ b/paddle2onnx/mapper/tensor/repeat_interleave.h @@ -30,7 +30,7 @@ class RepeatInterleaveMapper : public Mapper { int64_t op_id, bool c) : Mapper(p, helper, op_id, c) { in_pir_mode = true; - GetAttr("axis", &dim_); + GetAttr("dim", &dim_); } void Opset9() override; diff --git a/paddle2onnx/mapper/tensor/squeeze2.cc b/paddle2onnx/mapper/tensor/squeeze2.cc index 7383a91f4..f82519228 100644 --- a/paddle2onnx/mapper/tensor/squeeze2.cc +++ b/paddle2onnx/mapper/tensor/squeeze2.cc @@ -20,9 +20,9 @@ REGISTER_PIR_MAPPER(squeeze, Squeeze2Mapper) int32_t Squeeze2Mapper::GetMinOpsetVersion(bool verbose) { if (in_pir_mode) { - if (HasInput("axis")) { - return 13; - } + // if (HasInput("axis")) { + // return 13; + // } return 7; } diff --git a/paddle2onnx/mapper/tensor/unbind.cc b/paddle2onnx/mapper/tensor/unbind.cc index ee66c4914..71ded44fd 100644 --- a/paddle2onnx/mapper/tensor/unbind.cc +++ b/paddle2onnx/mapper/tensor/unbind.cc @@ -1,29 +1,45 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include "paddle2onnx/mapper/tensor/unbind.h" namespace paddle2onnx { REGISTER_MAPPER(unbind, UnbindMapper) +REGISTER_PIR_MAPPER(unbind, UnbindMapper) void UnbindMapper::Opset7() { auto input_info = GetInput("X"); auto output_info = GetOutput("Out"); - + std::vector output_names(output_info.size()); for (size_t i = 0; i < output_info.size(); ++i) { output_names[i] = output_info[i].name; } - int64_t split_axis = axis_; if (split_axis < 0) { split_axis += input_info[0].Rank(); } - - std::vector split_sizes = std::vector(input_info[0].shape[split_axis],1); - helper_->Split(input_info[0].name, output_names, split_sizes, split_axis); + + std::vector split_sizes = + std::vector(input_info[0].shape[split_axis], 1); + auto split_output_names = + helper_->Split(input_info[0].name, split_sizes, split_axis); for (size_t i = 0; i < output_info.size(); ++i) { - std::vector axes{split_axis}; - helper_->Squeeze(output_names[i], output_names[i], axes); + std::vector axes{split_axis}; + helper_->Squeeze(split_output_names[i], output_names[i], axes); } } -} \ No newline at end of file +} // namespace paddle2onnx diff --git a/paddle2onnx/mapper/tensor/unbind.h b/paddle2onnx/mapper/tensor/unbind.h index dd11a0aaf..71b9d9750 100644 --- a/paddle2onnx/mapper/tensor/unbind.h +++ b/paddle2onnx/mapper/tensor/unbind.h @@ -1,3 +1,17 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #pragma once #include #include @@ -8,13 +22,23 @@ namespace paddle2onnx { class UnbindMapper : public Mapper { public: - UnbindMapper(const PaddleParser& p, OnnxHelper* helper, int64_t block_id, - int64_t op_id) + UnbindMapper(const PaddleParser& p, + OnnxHelper* helper, + int64_t block_id, + int64_t op_id) : Mapper(p, helper, block_id, op_id) { - GetAttr("axis", &axis_); - } + GetAttr("axis", &axis_); + } + + UnbindMapper(const PaddlePirParser& p, + OnnxHelper* helper, + int64_t op_id, + bool if_in_cf_block) + : Mapper(p, helper, op_id, if_in_cf_block) { + GetAttr("axis", &axis_); + } void Opset7() override; int64_t axis_; }; -} // namespace paddle2onnx \ No newline at end of file +} // namespace paddle2onnx diff --git a/paddle2onnx/mapper/tensor/where.h b/paddle2onnx/mapper/tensor/where.h index a7fb70325..bd348c387 100644 --- a/paddle2onnx/mapper/tensor/where.h +++ b/paddle2onnx/mapper/tensor/where.h @@ -22,14 +22,16 @@ namespace paddle2onnx { class WhereMapper : public Mapper { public: - WhereMapper(const PaddleParser& p, OnnxHelper* helper, int64_t block_id, + WhereMapper(const PaddleParser& p, + OnnxHelper* helper, + int64_t block_id, int64_t op_id) : Mapper(p, helper, block_id, op_id) {} WhereMapper(const PaddlePirParser& p, OnnxHelper* helper, int64_t op_id, - bool c) - : Mapper(p, helper, op_id) { + bool if_in_cf_block) + : Mapper(p, helper, op_id, if_in_cf_block) { in_pir_mode = true; } diff --git a/paddle2onnx/parser/pir_parser.cc b/paddle2onnx/parser/pir_parser.cc index 202d1f3d3..0e8c35ba8 100644 --- a/paddle2onnx/parser/pir_parser.cc +++ b/paddle2onnx/parser/pir_parser.cc @@ -147,7 +147,7 @@ void PaddlePirParser::GetGlobalBlockOutputValueName() { void PaddlePirParser::GetAllSubBlockOpOutputName( std::vector block_op_lists) const { for (auto op : block_op_lists) { - std::string new_name = "p2o.sub_block" + op->name(); + std::string new_name = "p2o.sub_block." + op->name(); if (_name_counter.find(new_name) != _name_counter.end()) { _name_counter[new_name] += 1; } else { diff --git a/paddle2onnx/parser/pir_parser.h b/paddle2onnx/parser/pir_parser.h index 1c2b196cc..dbda5a0dc 100644 --- a/paddle2onnx/parser/pir_parser.h +++ b/paddle2onnx/parser/pir_parser.h @@ -130,7 +130,8 @@ class PaddlePirParser { std::string attr_value = "value"; std::string attr_values = "values"; pir::Operation* op = temp_op->operand(input_idx).source().defining_op(); - while (!op->HasAttribute(attr_value) && !op->HasAttribute(attr_values)) { + while (op->num_operands() > 0 && !op->HasAttribute(attr_value) && + !op->HasAttribute(attr_values)) { op = op->operand(0).source().defining_op(); } if (op->HasAttribute(attr_value)) { @@ -210,7 +211,8 @@ class PaddlePirParser { std::string attr_name; std::string attr_value = "value"; std::string attr_values = "values"; - while(!op->HasAttribute(attr_value) && !op->HasAttribute(attr_values)) { + while (op->num_operands() > 0 && !op->HasAttribute(attr_value) && + !op->HasAttribute(attr_values)) { op = op->operand(0).source().defining_op(); } if (op->HasAttribute(attr_value)) { @@ -257,6 +259,7 @@ class PaddlePirParser { bool if_in_sub_block, std::string tensor_arr_name) const; std::string GetTensorArrayName(int64_t op_id, bool if_in_sub_block) const; + std::string GenOpInputOutputName(const std::string& name) const; private: bool IsAttrVar(const pir::Operation* op, const int64_t& attr_id) const; @@ -268,7 +271,6 @@ class PaddlePirParser { void GetGlobalBlockInputValueName(); void GetGlobalBlockOutputValueName(); void GetAllOpOutputName(); - std::string GenOpInputOutputName(const std::string& name) const; void AddOpOutputName(pir::Operation* op, std::string var_name, int64_t output_idx) const; diff --git a/tests/test_argsort.py b/tests/test_argsort.py index f3e0d9677..8260d6338 100644 --- a/tests/test_argsort.py +++ b/tests/test_argsort.py @@ -14,7 +14,7 @@ import paddle from onnxbase import APIOnnx -from onnxbase import randtool +from onnxbase import randtool, _test_with_pir class Net(paddle.nn.Layer): @@ -35,6 +35,7 @@ def forward(self, inputs): return x +@_test_with_pir def test_argsort_11(): """ api: paddle.argsort @@ -43,13 +44,15 @@ def test_argsort_11(): op = Net() op.eval() # net, name, ver_list, delta=1e-6, rtol=1e-5 - obj = APIOnnx(op, 'argsort', [11]) + obj = APIOnnx(op, "argsort", [11]) obj.set_input_data( "input_data", - paddle.to_tensor(randtool("float", -1, 1, [3, 10]).astype('float32'))) + paddle.to_tensor(randtool("float", -1, 1, [3, 10]).astype("float32")), + ) obj.run() +@_test_with_pir def test_argsort_12(): """ api: paddle.argsort @@ -58,13 +61,15 @@ def test_argsort_12(): op = Net() op.eval() # net, name, ver_list, delta=1e-6, rtol=1e-5 - obj = APIOnnx(op, 'argsort', [12]) + obj = APIOnnx(op, "argsort", [12]) obj.set_input_data( "input_data", - paddle.to_tensor(randtool("float", -1, 1, [3, 10]).astype('float32'))) + paddle.to_tensor(randtool("float", -1, 1, [3, 10]).astype("float32")), + ) obj.run() +@_test_with_pir def test_argsort_axis(): """ api: paddle.argsort @@ -73,13 +78,15 @@ def test_argsort_axis(): op = Net(axis=1) op.eval() # net, name, ver_list, delta=1e-6, rtol=1e-5 - obj = APIOnnx(op, 'argsort', [12]) + obj = APIOnnx(op, "argsort", [12]) obj.set_input_data( "input_data", - paddle.to_tensor(randtool("float", -1, 1, [3, 10]).astype('float32'))) + paddle.to_tensor(randtool("float", -1, 1, [3, 10]).astype("float32")), + ) obj.run() +@_test_with_pir def test_argsort_descending(): """ api: paddle.argsort @@ -88,14 +95,15 @@ def test_argsort_descending(): op = Net(descending=True) op.eval() # net, name, ver_list, delta=1e-6, rtol=1e-5 - obj = APIOnnx(op, 'argsort', [12]) + obj = APIOnnx(op, "argsort", [12]) obj.set_input_data( "input_data", - paddle.to_tensor( - randtool("float", -1, 1, [3, 3, 10]).astype('float32'))) + paddle.to_tensor(randtool("float", -1, 1, [3, 3, 10]).astype("float32")), + ) obj.run() +@_test_with_pir def test_argsort_descending_1(): """ api: paddle.argsort @@ -104,14 +112,15 @@ def test_argsort_descending_1(): op = Net(descending=True) op.eval() # net, name, ver_list, delta=1e-6, rtol=1e-5 - obj = APIOnnx(op, 'argsort', [7]) + obj = APIOnnx(op, "argsort", [7]) obj.set_input_data( "input_data", - paddle.to_tensor( - randtool("float", -1, 1, [3, 3, 10]).astype('float32'))) + paddle.to_tensor(randtool("float", -1, 1, [3, 3, 10]).astype("float32")), + ) obj.run() +@_test_with_pir def test_argsort_descending_1_axis(): """ api: paddle.argsort @@ -120,9 +129,9 @@ def test_argsort_descending_1_axis(): op = Net(descending=True, axis=1) op.eval() # net, name, ver_list, delta=1e-6, rtol=1e-5 - obj = APIOnnx(op, 'argsort', [7]) + obj = APIOnnx(op, "argsort", [7]) obj.set_input_data( "input_data", - paddle.to_tensor( - randtool("float", -1, 1, [3, 3, 10]).astype('float32'))) + paddle.to_tensor(randtool("float", -1, 1, [3, 3, 10]).astype("float32")), + ) obj.run() diff --git a/tests/test_auto_scan_argsort.py b/tests/test_auto_scan_argsort.py index ab814ef20..8858f859c 100644 --- a/tests/test_auto_scan_argsort.py +++ b/tests/test_auto_scan_argsort.py @@ -13,11 +13,11 @@ # limitations under the License. from auto_scan_test import OPConvertAutoScanTest, BaseNet -from hypothesis import reproduce_failure import hypothesis.strategies as st import numpy as np import unittest import paddle +from onnxbase import _test_with_pir class Net(BaseNet): @@ -31,9 +31,8 @@ def forward(self, input): """ x = paddle.argsort( - input, - axis=self.config['axis'], - descending=self.config['descending']) + input, axis=self.config["axis"], descending=self.config["descending"] + ) return x @@ -45,20 +44,19 @@ class TestArgsortConvert(OPConvertAutoScanTest): def sample_convert_config(self, draw): input_shape = draw( - st.lists( - st.integers( - min_value=2, max_value=5), min_size=2, max_size=5)) + st.lists(st.integers(min_value=2, max_value=5), min_size=2, max_size=5) + ) axis = draw( - st.integers( - min_value=-len(input_shape), max_value=len(input_shape) - 1)) + st.integers(min_value=-len(input_shape), max_value=len(input_shape) - 1) + ) dtype = draw(st.sampled_from(["float32", "float64"])) descending = draw(st.booleans()) def generator_data(): import random - import numpy as np + t = 1 for i in range(len(input_shape)): t = t * input_shape[i] @@ -84,6 +82,7 @@ def generator_data(): return (config, models) + @_test_with_pir def test(self): self.run_and_statis(max_examples=30) diff --git a/tests/test_auto_scan_one_hot_v2.py b/tests/test_auto_scan_one_hot_v2.py index 61007f43e..4079f5509 100755 --- a/tests/test_auto_scan_one_hot_v2.py +++ b/tests/test_auto_scan_one_hot_v2.py @@ -13,12 +13,11 @@ # limitations under the License. from auto_scan_test import OPConvertAutoScanTest, BaseNet -from hypothesis import reproduce_failure import hypothesis.strategies as st -import numpy as np import unittest from onnxbase import randtool import paddle +from onnxbase import _test_only_pir class Net(BaseNet): @@ -45,9 +44,8 @@ class TestOneHotV2Convert(OPConvertAutoScanTest): def sample_convert_config(self, draw): input_shape = draw( - st.lists( - st.integers( - min_value=10, max_value=20), min_size=1, max_size=4)) + st.lists(st.integers(min_value=10, max_value=20), min_size=1, max_size=4) + ) num_classes = draw(st.integers(min_value=10, max_value=20)) @@ -65,13 +63,14 @@ def generator_data(): "opset_version": [9, 13, 15], "input_spec_shape": [], "num_classes": num_classes, - "is_tensor": is_tensor + "is_tensor": is_tensor, } models = Net(config) return (config, models) + @_test_only_pir def test(self): self.run_and_statis(max_examples=30) diff --git a/tests/test_auto_scan_pad3d.py b/tests/test_auto_scan_pad3d.py index ca7f221fe..e41e9a0ad 100755 --- a/tests/test_auto_scan_pad3d.py +++ b/tests/test_auto_scan_pad3d.py @@ -13,9 +13,8 @@ # limitations under the License. from auto_scan_test import OPConvertAutoScanTest, BaseNet -from hypothesis import reproduce_failure import hypothesis.strategies as st -from onnxbase import randtool +from onnxbase import _test_only_pir import numpy as np import unittest import paddle @@ -27,11 +26,9 @@ def forward(self, inputs): mode = self.config["mode"] value = self.config["value"] data_format = self.config["data_format"] - x = paddle.nn.functional.pad(inputs, - pad=pad, - mode=mode, - value=value, - data_format=data_format) + x = paddle.nn.functional.pad( + inputs, pad=pad, mode=mode, value=value, data_format=data_format + ) shape = paddle.shape(x) x = paddle.reshape(x, shape) @@ -46,13 +43,12 @@ class TestPadopsConvert(OPConvertAutoScanTest): def sample_convert_config(self, draw): input_shape = draw( - st.lists( - st.integers( - min_value=10, max_value=20), min_size=4, max_size=5)) + st.lists(st.integers(min_value=10, max_value=20), min_size=4, max_size=5) + ) dtype = "float32" - mode = draw(st.sampled_from(["constant", "reflect", "replicate"])) + mode = draw(st.sampled_from(["constant", "reflect", "replicate", "circular"])) value = draw(st.floats(min_value=0, max_value=10)) @@ -70,25 +66,16 @@ def sample_convert_config(self, draw): pad = None if len(input_shape) == 3: pad = draw( - st.lists( - st.integers( - min_value=0, max_value=4), - min_size=2, - max_size=2)) + st.lists(st.integers(min_value=0, max_value=4), min_size=2, max_size=2) + ) elif len(input_shape) == 4: pad = draw( - st.lists( - st.integers( - min_value=0, max_value=4), - min_size=4, - max_size=4)) + st.lists(st.integers(min_value=0, max_value=4), min_size=4, max_size=4) + ) else: pad = draw( - st.lists( - st.integers( - min_value=0, max_value=4), - min_size=6, - max_size=6)) + st.lists(st.integers(min_value=0, max_value=4), min_size=6, max_size=6) + ) config = { "op_names": ["pad3d"], @@ -99,13 +86,16 @@ def sample_convert_config(self, draw): "mode": mode, "value": value, "pad": pad, - "data_format": data_format + "data_format": data_format, } + if mode == "circular": + config["opset_version"] = [19] model = Net(config) return (config, model) + @_test_only_pir def test(self): self.run_and_statis(max_examples=25, max_duration=-1) @@ -113,15 +103,13 @@ def test(self): class Net2(BaseNet): def forward(self, inputs): data = np.ones(shape=[6], dtype="int32") - pad = paddle.to_tensor(data, dtype='int32') + pad = paddle.to_tensor(data, dtype="int32") mode = self.config["mode"] value = self.config["value"] data_format = self.config["data_format"] - x = paddle.nn.functional.pad(inputs, - pad, - mode=mode, - value=value, - data_format=data_format) + x = paddle.nn.functional.pad( + inputs, pad, mode=mode, value=value, data_format=data_format + ) shape = paddle.shape(x) x = paddle.reshape(x, shape) @@ -136,24 +124,22 @@ class TestPadopsConvert_Constanttensor(OPConvertAutoScanTest): def sample_convert_config(self, draw): input_shape = draw( - st.lists( - st.integers( - min_value=4, max_value=10), min_size=5, max_size=5)) + st.lists(st.integers(min_value=4, max_value=10), min_size=5, max_size=5) + ) dtype = "float32" - mode = draw(st.sampled_from(["constant", "reflect", "replicate"])) + mode = draw(st.sampled_from(["constant", "reflect", "replicate", "circular"])) value = draw(st.floats(min_value=0, max_value=10)) data_format = None - #data_format = draw(st.sampled_from(["NCDHW", "NDHWC"])) + # data_format = draw(st.sampled_from(["NCDHW", "NDHWC"])) data_format = "NCDHW" pad = draw( - st.lists( - st.integers( - min_value=0, max_value=4), min_size=6, max_size=6)) + st.lists(st.integers(min_value=0, max_value=4), min_size=6, max_size=6) + ) config = { "op_names": ["pad3d"], @@ -164,13 +150,16 @@ def sample_convert_config(self, draw): "mode": mode, "value": value, "pad": pad, - "data_format": data_format + "data_format": data_format, } + if mode == "circular": + config["opset_version"] = [19] model = Net2(config) return (config, model) + @_test_only_pir def test(self): self.run_and_statis(max_examples=25, max_duration=-1) diff --git a/tests/test_auto_scan_prelu.py b/tests/test_auto_scan_prelu.py index be6e929c2..c4b3d6f8f 100755 --- a/tests/test_auto_scan_prelu.py +++ b/tests/test_auto_scan_prelu.py @@ -13,11 +13,10 @@ # limitations under the License. from auto_scan_test import OPConvertAutoScanTest, BaseNet -from hypothesis import reproduce_failure import hypothesis.strategies as st -import numpy as np import unittest import paddle +from onnxbase import _test_with_pir class Net(BaseNet): @@ -41,9 +40,8 @@ class TestPreluConvert(OPConvertAutoScanTest): def sample_convert_config(self, draw): input_shape = draw( - st.lists( - st.integers( - min_value=5, max_value=20), min_size=0, max_size=4)) + st.lists(st.integers(min_value=5, max_value=20), min_size=0, max_size=4) + ) if len(input_shape) == 0: weight_shape = [] else: @@ -63,6 +61,7 @@ def sample_convert_config(self, draw): return (config, models) + @_test_with_pir def test(self): self.run_and_statis(max_examples=30) diff --git a/tests/test_auto_scan_put_along_axis.py b/tests/test_auto_scan_put_along_axis.py new file mode 100644 index 000000000..e70d2b23f --- /dev/null +++ b/tests/test_auto_scan_put_along_axis.py @@ -0,0 +1,87 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from auto_scan_test import OPConvertAutoScanTest, BaseNet +import hypothesis.strategies as st +import unittest +import paddle +from onnxbase import _test_only_pir, randtool + + +class Net(BaseNet): + """ + simple Net + """ + + def forward(self, arr, indices, values): + """ + forward + """ + x = paddle.put_along_axis( + arr, indices, values, axis=self.config["axis"], reduce=self.config["reduce"] + ) + return x + + +class TestPutAlongAxisConvert(OPConvertAutoScanTest): + """ + api: paddle.put_along_axis + OPset version: 11, 16, 18 + """ + + def sample_convert_config(self, draw): + input_shape = draw( + st.lists(st.integers(min_value=1, max_value=20), min_size=2, max_size=5) + ) + dtype = draw(st.sampled_from(["float32", "float64"])) + dtype2 = draw(st.sampled_from(["int32", "int64"])) + # dtype3 = draw(st.sampled_from(["float32", "float64"])) + axis = draw(st.integers(min_value=0, max_value=len(input_shape) - 1)) + reduce = draw(st.sampled_from(["assign", "add", "multiply", "amin", "amax"])) + + opset_version = [] + if reduce == "add" or reduce == "multiply": + opset_version.append(16) + elif reduce == "amin" or reduce == "amax": + opset_version.append(18) + else: + opset_version.append(11) + + def generator_data(): + input_data = randtool("int", 0, input_shape[axis], input_shape) + print("wmk" * 10) + print(input_data.shape) + return input_data + + config = { + "op_names": ["put_along_axis"], + "test_data_shapes": [input_shape, generator_data, input_shape], + "test_data_types": [[dtype], [dtype2], [dtype]], + "opset_version": opset_version, + "input_spec_shape": [], + "axis": axis, + "reduce": reduce, + } + + models = Net(config) + + return (config, models) + + @_test_only_pir + def test(self): + self.run_and_statis(max_examples=30) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_bitwise.py b/tests/test_bitwise.py index b898a9c8d..c93e0b84e 100644 --- a/tests/test_bitwise.py +++ b/tests/test_bitwise.py @@ -14,15 +14,19 @@ import paddle from onnxbase import APIOnnx -from onnxbase import randtool +from onnxbase import _test_with_pir + class BitwiseAndNet(paddle.nn.Layer): def __init__(self): super(BitwiseAndNet, self).__init__() + def forward(self, x, y): x = paddle.bitwise_and(x, y) return x + +@_test_with_pir def test_bitwise_and_int_type_18(): """ api: paddle.bitwise_and @@ -31,13 +35,14 @@ def test_bitwise_and_int_type_18(): op = BitwiseAndNet() op.eval() # net, name, ver_list, delta=1e-6, rtol=1e-5 - obj = APIOnnx(op, 'BitwiseAnd', [18]) + obj = APIOnnx(op, "BitwiseAnd", [18]) obj.set_input_data( - "input_data", - paddle.to_tensor([-5, -1, 1]), - paddle.to_tensor([4, 2, -3])) + "input_data", paddle.to_tensor([-5, -1, 1]), paddle.to_tensor([4, 2, -3]) + ) obj.run() + +@_test_with_pir def test_bitwise_and_bool_type(): """ api: paddle.bitwise_and @@ -46,13 +51,16 @@ def test_bitwise_and_bool_type(): op = BitwiseAndNet() op.eval() # net, name, ver_list, delta=1e-6, rtol=1e-5 - obj = APIOnnx(op, 'BitwiseAnd', [7]) + obj = APIOnnx(op, "BitwiseAnd", [7]) obj.set_input_data( "input_data", paddle.to_tensor([True, True, True]), - paddle.to_tensor([False, False, True])) + paddle.to_tensor([False, False, True]), + ) obj.run() + +@_test_with_pir def test_bitwise_and_bool_type_18(): """ api: paddle.bitwise_and @@ -61,23 +69,25 @@ def test_bitwise_and_bool_type_18(): op = BitwiseAndNet() op.eval() # net, name, ver_list, delta=1e-6, rtol=1e-5 - obj = APIOnnx(op, 'BitwiseAnd', [18]) + obj = APIOnnx(op, "BitwiseAnd", [18]) obj.set_input_data( "input_data", paddle.to_tensor([True, True, True]), - paddle.to_tensor([False, False, True])) + paddle.to_tensor([False, False, True]), + ) obj.run() - - class BitwiseNotNet(paddle.nn.Layer): def __init__(self): super(BitwiseNotNet, self).__init__() + def forward(self, x): x = paddle.bitwise_not(x) return x + +@_test_with_pir def test_bitwise_not_int_type_18(): """ api: paddle.bitwise_not @@ -86,12 +96,12 @@ def test_bitwise_not_int_type_18(): op = BitwiseNotNet() op.eval() # net, name, ver_list, delta=1e-6, rtol=1e-5 - obj = APIOnnx(op, 'BitwiseNot', [18]) - obj.set_input_data( - "input_data", - paddle.to_tensor([-5, -1, 1])) + obj = APIOnnx(op, "BitwiseNot", [18]) + obj.set_input_data("input_data", paddle.to_tensor([-5, -1, 1])) obj.run() + +@_test_with_pir def test_bitwise_not_bool_type(): """ api: paddle.bitwise_not @@ -100,13 +110,12 @@ def test_bitwise_not_bool_type(): op = BitwiseNotNet() op.eval() # net, name, ver_list, delta=1e-6, rtol=1e-5 - obj = APIOnnx(op, 'BitwiseNot', [7]) - obj.set_input_data( - "input_data", - paddle.to_tensor([True, True, True]) - ) + obj = APIOnnx(op, "BitwiseNot", [7]) + obj.set_input_data("input_data", paddle.to_tensor([True, True, True])) obj.run() + +@_test_with_pir def test_bitwise_not_bool_type_18(): """ api: paddle.bitwise_not @@ -115,22 +124,21 @@ def test_bitwise_not_bool_type_18(): op = BitwiseNotNet() op.eval() # net, name, ver_list, delta=1e-6, rtol=1e-5 - obj = APIOnnx(op, 'BitwiseNot', [18]) - obj.set_input_data( - "input_data", - paddle.to_tensor([True, True, True]) - ) + obj = APIOnnx(op, "BitwiseNot", [18]) + obj.set_input_data("input_data", paddle.to_tensor([True, True, True])) obj.run() - class BitwiseOrNet(paddle.nn.Layer): def __init__(self): super(BitwiseOrNet, self).__init__() + def forward(self, x, y): x = paddle.bitwise_or(x, y) return x + +@_test_with_pir def test_bitwise_or_int_type_18(): """ api: paddle.bitwise_or @@ -139,13 +147,14 @@ def test_bitwise_or_int_type_18(): op = BitwiseOrNet() op.eval() # net, name, ver_list, delta=1e-6, rtol=1e-5 - obj = APIOnnx(op, 'BitwiseOr', [18]) + obj = APIOnnx(op, "BitwiseOr", [18]) obj.set_input_data( - "input_data", - paddle.to_tensor([-5, -1, 1]), - paddle.to_tensor([4, 2, -3])) + "input_data", paddle.to_tensor([-5, -1, 1]), paddle.to_tensor([4, 2, -3]) + ) obj.run() + +@_test_with_pir def test_bitwise_or_bool_type(): """ api: paddle.bitwise_or @@ -154,13 +163,16 @@ def test_bitwise_or_bool_type(): op = BitwiseOrNet() op.eval() # net, name, ver_list, delta=1e-6, rtol=1e-5 - obj = APIOnnx(op, 'BitwiseOr', [7]) + obj = APIOnnx(op, "BitwiseOr", [7]) obj.set_input_data( "input_data", paddle.to_tensor([True, True, True]), - paddle.to_tensor([False, False, True])) + paddle.to_tensor([False, False, True]), + ) obj.run() + +@_test_with_pir def test_bitwise_or_bool_type_18(): """ api: paddle.bitwise_or @@ -169,22 +181,25 @@ def test_bitwise_or_bool_type_18(): op = BitwiseOrNet() op.eval() # net, name, ver_list, delta=1e-6, rtol=1e-5 - obj = APIOnnx(op, 'BitwiseOr', [18]) + obj = APIOnnx(op, "BitwiseOr", [18]) obj.set_input_data( "input_data", paddle.to_tensor([True, True, True]), - paddle.to_tensor([False, False, True])) + paddle.to_tensor([False, False, True]), + ) obj.run() - class BitwiseXorNet(paddle.nn.Layer): def __init__(self): super(BitwiseXorNet, self).__init__() + def forward(self, x, y): x = paddle.bitwise_xor(x, y) return x + +@_test_with_pir def test_bitwise_xor_int_type_18(): """ api: paddle.bitwise_xor @@ -193,13 +208,14 @@ def test_bitwise_xor_int_type_18(): op = BitwiseXorNet() op.eval() # net, name, ver_list, delta=1e-6, rtol=1e-5 - obj = APIOnnx(op, 'BitwiseXor', [18]) + obj = APIOnnx(op, "BitwiseXor", [18]) obj.set_input_data( - "input_data", - paddle.to_tensor([-5, -1, 1]), - paddle.to_tensor([4, 2, -3])) + "input_data", paddle.to_tensor([-5, -1, 1]), paddle.to_tensor([4, 2, -3]) + ) obj.run() + +@_test_with_pir def test_bitwise_xor_bool_type(): """ api: paddle.bitwise_xor @@ -208,13 +224,16 @@ def test_bitwise_xor_bool_type(): op = BitwiseXorNet() op.eval() # net, name, ver_list, delta=1e-6, rtol=1e-5 - obj = APIOnnx(op, 'BitwiseXor', [7]) + obj = APIOnnx(op, "BitwiseXor", [7]) obj.set_input_data( "input_data", paddle.to_tensor([True, True, True]), - paddle.to_tensor([False, False, True])) + paddle.to_tensor([False, False, True]), + ) obj.run() + +@_test_with_pir def test_bitwise_xor_bool_type_18(): """ api: paddle.bitwise_xor @@ -223,12 +242,14 @@ def test_bitwise_xor_bool_type_18(): op = BitwiseXorNet() op.eval() # net, name, ver_list, delta=1e-6, rtol=1e-5 - obj = APIOnnx(op, 'BitwiseXor', [18]) + obj = APIOnnx(op, "BitwiseXor", [18]) obj.set_input_data( "input_data", paddle.to_tensor([True, True, True]), - paddle.to_tensor([False, False, True])) + paddle.to_tensor([False, False, True]), + ) obj.run() -if __name__ == '__main__': - test_bitwise_not_int_type_18() \ No newline at end of file + +if __name__ == "__main__": + test_bitwise_not_int_type_18() diff --git a/tests/test_ceil.py b/tests/test_ceil.py index b7773aff0..3ac5033f1 100644 --- a/tests/test_ceil.py +++ b/tests/test_ceil.py @@ -14,7 +14,7 @@ import paddle from onnxbase import APIOnnx -from onnxbase import randtool +from onnxbase import randtool, _test_with_pir class Net(paddle.nn.Layer): @@ -33,6 +33,7 @@ def forward(self, inputs): return x +@_test_with_pir def test_ceil_9(): """ api: paddle.ceil @@ -41,13 +42,15 @@ def test_ceil_9(): op = Net() op.eval() # net, name, ver_list, delta=1e-6, rtol=1e-5 - obj = APIOnnx(op, 'ceil', [9]) + obj = APIOnnx(op, "ceil", [9]) obj.set_input_data( "input_data", - paddle.to_tensor(randtool("float", -1, 1, [3, 3]).astype('float32'))) + paddle.to_tensor(randtool("float", -1, 1, [3, 3]).astype("float32")), + ) obj.run() +@_test_with_pir def test_ceil_10(): """ api: paddle.ceil @@ -56,13 +59,15 @@ def test_ceil_10(): op = Net() op.eval() # net, name, ver_list, delta=1e-6, rtol=1e-5 - obj = APIOnnx(op, 'ceil', [10]) + obj = APIOnnx(op, "ceil", [10]) obj.set_input_data( "input_data", - paddle.to_tensor(randtool("float", -1, 1, [3, 3]).astype('float32'))) + paddle.to_tensor(randtool("float", -1, 1, [3, 3]).astype("float32")), + ) obj.run() +@_test_with_pir def test_ceil_11(): """ api: paddle.ceil @@ -71,13 +76,15 @@ def test_ceil_11(): op = Net() op.eval() # net, name, ver_list, delta=1e-6, rtol=1e-5 - obj = APIOnnx(op, 'ceil', [11]) + obj = APIOnnx(op, "ceil", [11]) obj.set_input_data( "input_data", - paddle.to_tensor(randtool("float", -1, 1, [3, 3]).astype('float32'))) + paddle.to_tensor(randtool("float", -1, 1, [3, 3]).astype("float32")), + ) obj.run() +@_test_with_pir def test_ceil_12(): """ api: paddle.ceil @@ -86,8 +93,9 @@ def test_ceil_12(): op = Net() op.eval() # net, name, ver_list, delta=1e-6, rtol=1e-5 - obj = APIOnnx(op, 'ceil', [12]) + obj = APIOnnx(op, "ceil", [12]) obj.set_input_data( "input_data", - paddle.to_tensor(randtool("float", -1, 1, [3, 3]).astype('float32'))) + paddle.to_tensor(randtool("float", -1, 1, [3, 3]).astype("float32")), + ) obj.run() diff --git a/tests/test_nn_Embedding.py b/tests/test_nn_Embedding.py index c2afd84ee..05614a8d6 100644 --- a/tests/test_nn_Embedding.py +++ b/tests/test_nn_Embedding.py @@ -14,7 +14,7 @@ import paddle from onnxbase import APIOnnx -from onnxbase import randtool +from onnxbase import _test_with_pir import numpy as np @@ -31,7 +31,8 @@ def __init__(self): padding_idx=None, sparse=True, weight_attr=None, - name=None) + name=None, + ) def forward(self, inputs): """ @@ -41,6 +42,7 @@ def forward(self, inputs): return x +@_test_with_pir def test_Embedding_base(): """ api: paddle.Embedding @@ -49,8 +51,8 @@ def test_Embedding_base(): op = Net() op.eval() # net, name, ver_list, delta=1e-6, rtol=1e-5 - obj = APIOnnx(op, 'nn_Embedding', [9, 10, 11, 12]) + obj = APIOnnx(op, "nn_Embedding", [9, 10, 11, 12]) obj.set_input_data( - "input_data", - paddle.to_tensor(np.arange(3, 6).reshape((3, 1)).astype(np.int64))) + "input_data", paddle.to_tensor(np.arange(3, 6).reshape((3, 1)).astype(np.int64)) + ) obj.run() diff --git a/tests/test_nn_Pad3D.py b/tests/test_nn_Pad3D.py index 3aac71b59..60fa7d3bc 100644 --- a/tests/test_nn_Pad3D.py +++ b/tests/test_nn_Pad3D.py @@ -14,7 +14,7 @@ import paddle from onnxbase import APIOnnx -from onnxbase import randtool +from onnxbase import randtool, _test_with_pir class Net(paddle.nn.Layer): @@ -24,7 +24,7 @@ class Net(paddle.nn.Layer): def __init__(self): super(Net, self).__init__() - self._pad = paddle.nn.Pad3D(padding=1, mode='constant') + self._pad = paddle.nn.Pad3D(padding=1, mode="constant") def forward(self, inputs): """ @@ -34,6 +34,7 @@ def forward(self, inputs): return x +@_test_with_pir def test_Pad3D_9(): """ api: paddle.Pad3D @@ -42,14 +43,17 @@ def test_Pad3D_9(): op = Net() op.eval() # net, name, ver_list, delta=1e-6, rtol=1e-5 - obj = APIOnnx(op, 'nn_Pad3D', [9]) + obj = APIOnnx(op, "nn_Pad3D", [9]) obj.set_input_data( "input_data", paddle.to_tensor( - randtool("float", -1, 1, [3, 1, 10, 10, 10]).astype('float32'))) + randtool("float", -1, 1, [3, 1, 10, 10, 10]).astype("float32") + ), + ) obj.run() +@_test_with_pir def test_Pad3D_10(): """ api: paddle.nn.Pad3D @@ -58,14 +62,17 @@ def test_Pad3D_10(): op = Net() op.eval() # net, name, ver_list, delta=1e-6, rtol=1e-5 - obj = APIOnnx(op, 'nn_Pad3D', [10]) + obj = APIOnnx(op, "nn_Pad3D", [10]) obj.set_input_data( "input_data", paddle.to_tensor( - randtool("float", -1, 1, [3, 1, 10, 10, 10]).astype('float32'))) + randtool("float", -1, 1, [3, 1, 10, 10, 10]).astype("float32") + ), + ) obj.run() +@_test_with_pir def test_Pad3D_11(): """ api: paddle.nn.Pad3D @@ -74,14 +81,17 @@ def test_Pad3D_11(): op = Net() op.eval() # net, name, ver_list, delta=1e-6, rtol=1e-5 - obj = APIOnnx(op, 'nn_Pad3D', [11]) + obj = APIOnnx(op, "nn_Pad3D", [11]) obj.set_input_data( "input_data", paddle.to_tensor( - randtool("float", -1, 1, [3, 1, 10, 10, 10]).astype('float32'))) + randtool("float", -1, 1, [3, 1, 10, 10, 10]).astype("float32") + ), + ) obj.run() +@_test_with_pir def test_Pad3D_12(): """ api: paddle.nn.Pad3D @@ -90,9 +100,47 @@ def test_Pad3D_12(): op = Net() op.eval() # net, name, ver_list, delta=1e-6, rtol=1e-5 - obj = APIOnnx(op, 'nn_Pad3D', [12]) + obj = APIOnnx(op, "nn_Pad3D", [12]) obj.set_input_data( "input_data", paddle.to_tensor( - randtool("float", -1, 1, [3, 1, 10, 10, 10]).astype('float32'))) + randtool("float", -1, 1, [3, 1, 10, 10, 10]).astype("float32") + ), + ) + obj.run() + + +class Net2(paddle.nn.Layer): + """ + simple Net + """ + + def __init__(self): + super(Net2, self).__init__() + self._pad = paddle.nn.Pad3D(padding=1, mode="circular") + + def forward(self, inputs): + """ + forward + """ + x = self._pad(inputs) + return x + + +@_test_with_pir +def test_Pad3D_19(): + """ + api: paddle.nn.Pad3D + op version: 19 + """ + op = Net2() + op.eval() + # net, name, ver_list, delta=1e-6, rtol=1e-5 + obj = APIOnnx(op, "nn_Pad3D", [19]) + obj.set_input_data( + "input_data", + paddle.to_tensor( + randtool("float", -1, 1, [3, 1, 10, 10, 10]).astype("float32") + ), + ) obj.run() diff --git a/tests/test_prelu.py b/tests/test_prelu.py index 9c1bad9cf..c2fa420d1 100644 --- a/tests/test_prelu.py +++ b/tests/test_prelu.py @@ -15,6 +15,7 @@ import paddle from onnxbase import APIOnnx from onnxbase import randtool +from onnxbase import _test_only_pir class Net(paddle.nn.Layer): @@ -33,6 +34,7 @@ def forward(self, inputs): return x +@_test_only_pir def test_prelu_9(): """ api: paddle.prelu @@ -41,14 +43,15 @@ def test_prelu_9(): op = Net() op.eval() # net, name, ver_list, delta=1e-6, rtol=1e-5 - obj = APIOnnx(op, 'prelu', [9]) + obj = APIOnnx(op, "prelu", [9]) obj.set_input_data( "input_data", - paddle.to_tensor( - randtool("float", -1, 1, [3, 3, 3]).astype('float32'))) + paddle.to_tensor(randtool("float", -1, 1, [3, 3, 3]).astype("float32")), + ) obj.run() +@_test_only_pir def test_prelu_10(): """ api: paddle.prelu @@ -57,14 +60,15 @@ def test_prelu_10(): op = Net() op.eval() # net, name, ver_list, delta=1e-6, rtol=1e-5 - obj = APIOnnx(op, 'prelu', [10]) + obj = APIOnnx(op, "prelu", [10]) obj.set_input_data( "input_data", - paddle.to_tensor( - randtool("float", -1, 1, [3, 3, 3]).astype('float32'))) + paddle.to_tensor(randtool("float", -1, 1, [3, 3, 3]).astype("float32")), + ) obj.run() +@_test_only_pir def test_prelu_11(): """ api: paddle.prelu @@ -73,14 +77,15 @@ def test_prelu_11(): op = Net() op.eval() # net, name, ver_list, delta=1e-6, rtol=1e-5 - obj = APIOnnx(op, 'prelu', [11]) + obj = APIOnnx(op, "prelu", [11]) obj.set_input_data( "input_data", - paddle.to_tensor( - randtool("float", -1, 1, [3, 3, 3]).astype('float32'))) + paddle.to_tensor(randtool("float", -1, 1, [3, 3, 3]).astype("float32")), + ) obj.run() +@_test_only_pir def test_prelu_12(): """ api: paddle.prelu @@ -89,9 +94,9 @@ def test_prelu_12(): op = Net() op.eval() # net, name, ver_list, delta=1e-6, rtol=1e-5 - obj = APIOnnx(op, 'prelu', [12]) + obj = APIOnnx(op, "prelu", [12]) obj.set_input_data( "input_data", - paddle.to_tensor( - randtool("float", -1, 1, [3, 3, 3]).astype('float32'))) + paddle.to_tensor(randtool("float", -1, 1, [3, 3, 3]).astype("float32")), + ) obj.run() diff --git a/tests/test_tanh.py b/tests/test_tanh.py index 9a3ccb514..5a2f40703 100644 --- a/tests/test_tanh.py +++ b/tests/test_tanh.py @@ -14,7 +14,7 @@ import paddle from onnxbase import APIOnnx -from onnxbase import randtool +from onnxbase import randtool, _test_with_pir class Net(paddle.nn.Layer): @@ -33,6 +33,7 @@ def forward(self, inputs): return x +@_test_with_pir def test_tanh_9(): """ api: paddle.tanh @@ -41,14 +42,15 @@ def test_tanh_9(): op = Net() op.eval() # net, name, ver_list, delta=1e-6, rtol=1e-5 - obj = APIOnnx(op, 'tanh', [9]) + obj = APIOnnx(op, "tanh", [9]) obj.set_input_data( "input_data", - paddle.to_tensor( - randtool("float", -1, 1, [3, 3, 3]).astype('float32'))) + paddle.to_tensor(randtool("float", -1, 1, [3, 3, 3]).astype("float32")), + ) obj.run() +@_test_with_pir def test_tanh_10(): """ api: paddle.tanh @@ -57,14 +59,15 @@ def test_tanh_10(): op = Net() op.eval() # net, name, ver_list, delta=1e-6, rtol=1e-5 - obj = APIOnnx(op, 'tanh', [10]) + obj = APIOnnx(op, "tanh", [10]) obj.set_input_data( "input_data", - paddle.to_tensor( - randtool("float", -1, 1, [3, 3, 3]).astype('float32'))) + paddle.to_tensor(randtool("float", -1, 1, [3, 3, 3]).astype("float32")), + ) obj.run() +@_test_with_pir def test_tanh_11(): """ api: paddle.tanh @@ -73,14 +76,15 @@ def test_tanh_11(): op = Net() op.eval() # net, name, ver_list, delta=1e-6, rtol=1e-5 - obj = APIOnnx(op, 'tanh', [11]) + obj = APIOnnx(op, "tanh", [11]) obj.set_input_data( "input_data", - paddle.to_tensor( - randtool("float", -1, 1, [3, 3, 3]).astype('float32'))) + paddle.to_tensor(randtool("float", -1, 1, [3, 3, 3]).astype("float32")), + ) obj.run() +@_test_with_pir def test_tanh_12(): """ api: paddle.tanh @@ -89,9 +93,9 @@ def test_tanh_12(): op = Net() op.eval() # net, name, ver_list, delta=1e-6, rtol=1e-5 - obj = APIOnnx(op, 'tanh', [12]) + obj = APIOnnx(op, "tanh", [12]) obj.set_input_data( "input_data", - paddle.to_tensor( - randtool("float", -1, 1, [3, 3, 3]).astype('float32'))) + paddle.to_tensor(randtool("float", -1, 1, [3, 3, 3]).astype("float32")), + ) obj.run() diff --git a/tests/test_unbind.py b/tests/test_unbind.py index d30de6649..3912181f0 100644 --- a/tests/test_unbind.py +++ b/tests/test_unbind.py @@ -1,5 +1,19 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import paddle -from onnxbase import APIOnnx +from onnxbase import APIOnnx, _test_with_pir class Net(paddle.nn.Layer): @@ -19,6 +33,7 @@ def forward(self, inputs, axis=1): return x +@_test_with_pir def test_unbind(): """ api: paddle.unbind @@ -27,21 +42,11 @@ def test_unbind(): op = Net() op.eval() # net, name, ver_list, delta=1e-6, rtol=1e-5 - obj = APIOnnx(op, 'unbind', [7]) - input_data = paddle.to_tensor([[ - [1, 2, 3], - [4, 5, 6], - [7, 8, 9] - ], - [ - [11, 22, 33], - [44, 55, 66], - [77, 88, 99] - ]]).astype('float32') + obj = APIOnnx(op, "unbind", [7]) + input_data = paddle.to_tensor( + [[[1, 2, 3], [4, 5, 6], [7, 8, 9]], [[11, 22, 33], [44, 55, 66], [77, 88, 99]]] + ).astype("float32") print(input_data) - obj.set_input_data( - "input_data", - input_data - ) + obj.set_input_data("input_data", input_data) obj.run()