From e84d51f75894162887929a00daba932ca055a111 Mon Sep 17 00:00:00 2001 From: "Rickert, Jonas" Date: Mon, 25 Nov 2024 10:44:24 +0000 Subject: [PATCH] Replace deprecated use of .cast/dyn_cast/isa with the mlir freestanding functions and misc cleanups. Also remove some unecessary semicolons and standardize to use mlir instead of llvm casts --- src/Conversion/ONNXToTOSA/DialectBuilder.cpp | 9 +++--- src/Conversion/ONNXToTOSA/Math/Conv2D.cpp | 2 +- .../ONNXToTOSA/Math/Elementwise.cpp | 32 +++++++++---------- src/Conversion/ONNXToTOSA/Math/MatMul.cpp | 2 +- src/Conversion/ONNXToTOSA/Math/Reduce.cpp | 12 +++---- src/Conversion/ONNXToTOSA/NN/AveragePool.cpp | 4 +-- .../ONNXToTOSA/NN/MaxPoolSingleOut.cpp | 2 +- .../ONNXToTOSA/ONNXToTOSACommon.cpp | 22 ++++++------- .../ONNXToTOSA/ONNXToTOSACommon.hpp | 4 +-- .../ONNXToTOSA/ONNXToTOSACommon.hpp.inc | 10 +++--- .../ONNXToTOSA/ONNXToTOSALegalizeUtils.cpp | 2 +- .../ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp | 6 ++-- src/Conversion/ONNXToTOSA/Tensor/Concat.cpp | 2 +- src/Conversion/ONNXToTOSA/Tensor/Expand.cpp | 4 +-- src/Conversion/ONNXToTOSA/Tensor/Flatten.cpp | 2 +- src/Conversion/ONNXToTOSA/Tensor/Gather.cpp | 4 +-- .../ONNXToTOSA/Tensor/PaddingOp.cpp | 4 +-- src/Conversion/ONNXToTOSA/Tensor/Resize.cpp | 4 +-- .../ONNXToTOSA/Tensor/Transpose.cpp | 4 +-- 19 files changed, 64 insertions(+), 67 deletions(-) diff --git a/src/Conversion/ONNXToTOSA/DialectBuilder.cpp b/src/Conversion/ONNXToTOSA/DialectBuilder.cpp index ce262b4d06..60522c7303 100644 --- a/src/Conversion/ONNXToTOSA/DialectBuilder.cpp +++ b/src/Conversion/ONNXToTOSA/DialectBuilder.cpp @@ -176,7 +176,8 @@ std::optional TosaBuilder::gather(Value resultValue, Value inputValue, Value indicesValue, int32_t batchDims, int32_t axis) { return tosa::convertGatherOp(rewriter(), loc(), resultValue, inputValue, indicesValue, batchDims, axis); -}; +} + Value TosaBuilder::reshape(mlir::Value value, llvm::ArrayRef shape) { auto shapeAttr = rewriter().getDenseI64ArrayAttr(shape); auto valueType = mlir::cast(value.getType()); @@ -246,7 +247,7 @@ template Value TosaBuilder::binaryOp( template Value TosaBuilder::unaryOp(mlir::Value &input) { - auto inputType = input.getType().cast(); + auto inputType = cast(input.getType()); Type newValueType = RankedTensorType::get( llvm::SmallVector(inputType.getRank(), ShapedType::kDynamic), inputType.getElementType()); @@ -304,7 +305,7 @@ Value TosaBuilder::select( lhs = valueVec[1]; rhs = valueVec[2]; } - auto lhsType = lhs.getType().cast(); + auto lhsType = cast(lhs.getType()); Type newValueType = RankedTensorType::get( llvm::SmallVector(lhsType.getRank(), ShapedType::kDynamic), lhsType.getElementType()); @@ -326,7 +327,7 @@ mlir::Value TosaBuilder::castToNewTensorElementType( } Value TosaBuilder::sqrt(mlir::Value &input) { - auto inputType = input.getType().cast(); + auto inputType = cast(input.getType()); auto oneHalf = this->getSplattedConst( 0.5, inputType.getElementType(), inputType.getShape()); return this->binaryOp(input, oneHalf); diff --git a/src/Conversion/ONNXToTOSA/Math/Conv2D.cpp b/src/Conversion/ONNXToTOSA/Math/Conv2D.cpp index 478323ab1b..85da6933a8 100644 --- a/src/Conversion/ONNXToTOSA/Math/Conv2D.cpp +++ b/src/Conversion/ONNXToTOSA/Math/Conv2D.cpp @@ -101,7 +101,7 @@ class ONNXConvOpLoweringToTOSA : public ConversionPattern { ConversionPatternRewriter &rewriter) const final { OpAdaptor adaptor(operands, op->getAttrDictionary()); auto loc = op->getLoc(); - auto convOp = llvm::cast(op); + auto convOp = mlir::cast(op); TosaBuilder tosaBuilder(rewriter, loc); diff --git a/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp index 8ceba4260d..eeee6ab7f0 100644 --- a/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp +++ b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp @@ -80,12 +80,12 @@ LogicalResult checkBasicTosaRequirementsForBinaryOps( ConversionPatternRewriter &rewriter, Operation *op, OpAdaptorT adaptor, Type resultType) { Value lhs = adaptor.getOperands()[0]; - auto lhsType = lhs.getType().dyn_cast(); + auto lhsType = dyn_cast(lhs.getType()); Value rhs = adaptor.getOperands()[1]; - auto rhsType = rhs.getType().dyn_cast(); + auto rhsType = dyn_cast(rhs.getType()); - auto resultTensorType = resultType.dyn_cast(); + auto resultTensorType = dyn_cast(resultType); if (!lhsType || !rhsType || !resultTensorType) { return rewriter.notifyMatchFailure(op, "Tosa only supports TensorTypes"); } @@ -121,9 +121,9 @@ class ONNXElementwiseUnaryOpLoweringToTOSA ConversionPatternRewriter &rewriter) const override { Value input = *adaptor.getODSOperands(0).begin(); - auto inputType = input.getType().dyn_cast(); + auto inputType = dyn_cast(input.getType()); Value output = op.getResult(); - auto outputType = output.getType().dyn_cast(); + auto outputType = dyn_cast(output.getType()); if (!inputType || !outputType) { return rewriter.notifyMatchFailure(op, "Tosa only supports TensorTypes"); @@ -248,7 +248,7 @@ class ONNXLeakyReluOpLoweringToTOSA LogicalResult matchAndRewrite(ONNXLeakyReluOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto outputType = op.getResult().getType().cast(); + auto outputType = cast(op.getResult().getType()); if (failed(IsIntOrFloat::checkType( rewriter, outputType.getElementType(), op))) { return failure(); @@ -279,15 +279,13 @@ class ONNXComparisonOpLoweringToTOSA : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { Value input1 = adaptor.getA(); - auto input1ElemType = - input1.getType().template cast().getElementType(); + auto input1ElemType = cast(input1.getType()).getElementType(); if (failed(IsIntOrFloat::checkType(rewriter, input1ElemType, op))) { return failure(); } Value input2 = adaptor.getB(); - auto input2ElemType = - input2.getType().template cast().getElementType(); + auto input2ElemType = cast(input2.getType()).getElementType(); if (input1ElemType != input2ElemType) { return failure(); } @@ -432,7 +430,7 @@ class ONNXSqrtOpLoweringToTOSA : public OpConversionPattern { LogicalResult matchAndRewrite(ONNXSqrtOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto resultTensorType = op.getResult().getType().cast(); + auto resultTensorType = cast(op.getResult().getType()); if (failed(IsFloat::checkType( rewriter, resultTensorType.getElementType(), op))) { return failure(); @@ -454,7 +452,7 @@ class ONNXEluOpLoweringToTOSA : public OpConversionPattern { // ELU(x) = x if x >= 0 // alpha * (exp(x) - 1.) if x < 0 - auto resultTensorType = op.getResult().getType().cast(); + auto resultTensorType = cast(op.getResult().getType()); if (failed(IsFloat::checkType( rewriter, resultTensorType.getElementType(), op))) { return failure(); @@ -496,7 +494,7 @@ class ONNXHardSigmoidOpLoweringToTOSA // - tosa.mul(clamp, alpha) Value input = adaptor.getX(); - auto resultType = op.getResult().getType().template cast(); + auto resultType = cast(op.getResult().getType()); auto resultElementType = resultType.getElementType(); TosaBuilder tosaBuilder(rewriter, op->getLoc()); @@ -536,7 +534,7 @@ class ONNXPReluOpLoweringToTOSA : public OpConversionPattern { LogicalResult matchAndRewrite(ONNXPReluOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto outputType = op.getResult().getType().cast(); + auto outputType = cast(op.getResult().getType()); if (failed(IsIntOrFloat::checkType( rewriter, outputType.getElementType(), op))) { return failure(); @@ -554,7 +552,7 @@ class ONNXSoftplusOpLoweringToTOSA LogicalResult matchAndRewrite(ONNXSoftplusOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto outputType = op.getResult().getType().cast(); + auto outputType = cast(op.getResult().getType()); if (failed(IsFloat::checkType(rewriter, outputType.getElementType(), op))) { return failure(); } @@ -579,7 +577,7 @@ class ONNXSeluOpLoweringToTOSA : public OpConversionPattern { LogicalResult matchAndRewrite(ONNXSeluOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto outputType = op.getResult().getType().cast(); + auto outputType = cast(op.getResult().getType()); if (failed(IsFloat::checkType(rewriter, outputType.getElementType(), op))) { return failure(); } @@ -618,7 +616,7 @@ class ONNXThresholdedReluOpLoweringToTOSA LogicalResult matchAndRewrite(ONNXThresholdedReluOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto outputType = op.getResult().getType().cast(); + auto outputType = cast(op.getResult().getType()); if (failed(IsIntOrFloat::checkType( rewriter, outputType.getElementType(), op))) { return failure(); diff --git a/src/Conversion/ONNXToTOSA/Math/MatMul.cpp b/src/Conversion/ONNXToTOSA/Math/MatMul.cpp index 0674af600f..219ed882d7 100644 --- a/src/Conversion/ONNXToTOSA/Math/MatMul.cpp +++ b/src/Conversion/ONNXToTOSA/Math/MatMul.cpp @@ -50,7 +50,7 @@ Value reshapeUpTo3DTensor(Value tensor, TosaBuilder &builder) { } return builder.reshape(tensor, newShape); -}; +} // Obtaining the rank broadcasted shapes of tensors makes it easier to // construct the input and output reshaping logic. diff --git a/src/Conversion/ONNXToTOSA/Math/Reduce.cpp b/src/Conversion/ONNXToTOSA/Math/Reduce.cpp index da111f98f6..bb567b771d 100644 --- a/src/Conversion/ONNXToTOSA/Math/Reduce.cpp +++ b/src/Conversion/ONNXToTOSA/Math/Reduce.cpp @@ -43,7 +43,7 @@ DenseIntElementsAttr getAxesLatestsVersionAttr(ONNXReduceOp op) { if (noOpIfAxesEmpty == 0) { // Default behaviour when "axes" is none and "noop_with_empty_axes" is // set to false, it is to reduce all dims - const int64_t numberOfAxes = input.getType().cast().getRank(); + const int64_t numberOfAxes = cast(input.getType()).getRank(); auto iotaRange = llvm::iota_range(0, numberOfAxes, /*Inclusive=*/false); targetAxes = SmallVector(iotaRange.begin(), iotaRange.end()); @@ -86,7 +86,7 @@ DenseIntElementsAttr getAxesLegacyVersionAttr(ONNXReduceOp op) { SmallVector targetAxes; if (!axes) { // if not present all axes are reduced - const int64_t numberOfAxes = input.getType().cast().getRank(); + const int64_t numberOfAxes = cast(input.getType()).getRank(); auto iotaRange = llvm::iota_range(0, numberOfAxes, /*Inclusive=*/false); targetAxes = SmallVector(iotaRange.begin(), iotaRange.end()); @@ -111,14 +111,12 @@ class ONNXReduceOpLoweringToTOSA : public OpConversionPattern { LogicalResult matchAndRewrite(ONNXReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto inputType = - adaptor.getData().getType().template dyn_cast(); + auto inputType = dyn_cast(adaptor.getData().getType()); if (!inputType) return rewriter.notifyMatchFailure(op, "input type not a ranked tensor."); - auto outputType = this->getTypeConverter() - ->convertType(op.getResult().getType()) - .template cast(); + auto outputType = cast( + this->getTypeConverter()->convertType(op.getResult().getType())); return (*lowerFn)(op, inputType, outputType, rewriter); } diff --git a/src/Conversion/ONNXToTOSA/NN/AveragePool.cpp b/src/Conversion/ONNXToTOSA/NN/AveragePool.cpp index 44f3ad5f02..f7be640f7b 100644 --- a/src/Conversion/ONNXToTOSA/NN/AveragePool.cpp +++ b/src/Conversion/ONNXToTOSA/NN/AveragePool.cpp @@ -41,14 +41,14 @@ LogicalResult handleIncludePadAttr( return rewriter.notifyMatchFailure(op, "Could not infer shapes"); } - auto inputType = input.getType().cast(); + auto inputType = cast(input.getType()); if (inputType.getShape().size() != 4) { return rewriter.notifyMatchFailure(op, "TOSA only supports 2d pooling"); } llvm::SmallVector pads = tosa::createOrderedPadAttrForWindowBasedOps(rewriter, - input.getType().cast().getShape(), shapeHelper, + cast(input.getType()).getShape(), shapeHelper, /*ceilMode*/ 0, {0, 1, 2, 3}); // Create Padding and ConstPad tosa::ConstOp's diff --git a/src/Conversion/ONNXToTOSA/NN/MaxPoolSingleOut.cpp b/src/Conversion/ONNXToTOSA/NN/MaxPoolSingleOut.cpp index d730539405..8261445c4e 100644 --- a/src/Conversion/ONNXToTOSA/NN/MaxPoolSingleOut.cpp +++ b/src/Conversion/ONNXToTOSA/NN/MaxPoolSingleOut.cpp @@ -36,7 +36,7 @@ class ONNXMaxPoolSingleOutOpLoweringToTOSA : public ConversionPattern { using OpAdaptor = typename ONNXMaxPoolSingleOutOp::Adaptor; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { - auto maxpoolOp = llvm::cast(op); + auto maxpoolOp = mlir::cast(op); OpAdaptor adaptor(operands, op->getAttrDictionary()); Value input = adaptor.getX(); diff --git a/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.cpp b/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.cpp index d9517d80c9..e81f5fba10 100644 --- a/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.cpp +++ b/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.cpp @@ -57,9 +57,9 @@ std::optional convertGatherOp(PatternRewriter &rewriter, Location loc, TosaBuilder tosaBuilder(rewriter, loc); - auto resultType = resultValue.getType().dyn_cast(); - auto inputType = inputValue.getType().dyn_cast(); - auto indicesType = indicesValue.getType().dyn_cast(); + auto resultType = dyn_cast(resultValue.getType()); + auto inputType = dyn_cast(inputValue.getType()); + auto indicesType = dyn_cast(indicesValue.getType()); if (!resultType || !inputType || !indicesType) return std::nullopt; @@ -143,7 +143,7 @@ std::optional convertGatherOp(PatternRewriter &rewriter, Location loc, // onnx allows i64 indices, but tosa does not. if (indicesType.getElementType().isInteger(64)) { indicesType = - indicesType.clone(rewriter.getI32Type()).dyn_cast(); + dyn_cast(indicesType.clone(rewriter.getI32Type())); indicesValue = CreateOpAndInfer( rewriter, loc, indicesType, indicesValue) .getResult(); @@ -293,7 +293,7 @@ std::optional convertReduceOpCommon(PatternRewriter &rewriter, bool is_quantized, double input_scale, int64_t input_zp, double output_scale, int64_t output_zp) { RankedTensorType input_type = - input_value.getType().dyn_cast(); + dyn_cast(input_value.getType()); if (!input_type) return std::nullopt; @@ -362,14 +362,14 @@ std::optional convertReduceMeanOp(PatternRewriter &rewriter, // op2 = mul(op1, 1.0 / num_elements_on_reduced_axis) RankedTensorType input_type = - input_value.getType().dyn_cast(); + dyn_cast(input_value.getType()); if (!input_type) return std::nullopt; bool input_is_qtype = - input_type.getElementType().isa(); + isa(input_type.getElementType()); bool output_is_qtype = - output_type.getElementType().isa(); + isa(output_type.getElementType()); if (input_is_qtype != output_is_qtype) { op->emitOpError("ConvertReduceSumOp: input/output tensor should " @@ -378,7 +378,7 @@ std::optional convertReduceMeanOp(PatternRewriter &rewriter, } // Only supports float type mean() if it's non-quantized - if (!input_is_qtype && !output_type.getElementType().isa()) { + if (!input_is_qtype && !isa(output_type.getElementType())) { op->emitWarning( "Failed convertReduceMean: input unquantized type but output element " "not FloatType!"); @@ -403,9 +403,9 @@ std::optional convertReduceMeanOp(PatternRewriter &rewriter, if (input_is_qtype) { auto input_qtype = - input_type.getElementType().cast(); + cast(input_type.getElementType()); auto output_qtype = - output_type.getElementType().cast(); + cast(output_type.getElementType()); // Combine 'div_scale' as part of output rescale output_scale = div_scale * input_qtype.getScale() / output_qtype.getScale(); diff --git a/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp b/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp index a9e792baef..56d7bddac8 100644 --- a/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp +++ b/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp @@ -87,12 +87,12 @@ namespace onnx_mlir { //===----------------------------------------------------------------------===// inline bool isTOSABool(mlir::Type type) { - mlir::IntegerType intType = type.dyn_cast(); + mlir::IntegerType intType = mlir::dyn_cast(type); return intType && intType.isSignless() && intType.getWidth() == 1; } inline bool isTOSAInt(mlir::Type type) { - mlir::IntegerType intType = type.dyn_cast(); + mlir::IntegerType intType = mlir::dyn_cast(type); std::set intWidth{1, 8, 16, 32, 48, 64}; return intType && (intType.isSignless() || intType.isUnsignedInteger()) && (intWidth.find(intType.getWidth()) != intWidth.end()); diff --git a/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp.inc b/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp.inc index aa326a28b0..22ae929f9f 100644 --- a/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp.inc +++ b/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp.inc @@ -126,13 +126,13 @@ llvm::SmallVector createOrderedPadAttrForWindowBasedOps( inline mlir::LogicalResult getAvgPool2dAccType(mlir::PatternRewriter &rewriter, mlir::Value input, mlir::TypeAttr &accType) { - auto inputTy = llvm::dyn_cast(input.getType()); + auto inputTy = mlir::dyn_cast(input.getType()); if (!inputTy) return mlir::failure(); auto inputETy = inputTy.getElementType(); if (auto quantType = - llvm::dyn_cast(inputETy)) + mlir::dyn_cast(inputETy)) inputETy = quantType.getStorageType(); // Tosa supports FP16 and FP32 accumulator type for FP16 input. When the time @@ -180,7 +180,7 @@ mlir::FailureOr convertPoolOp( llvm::SmallVector kernelShapeVec; llvm::transform(kernelShape, std::back_inserter(kernelShapeVec), [](const mlir::Attribute &pad) { - return pad.cast().getInt(); + return mlir::cast(pad).getInt(); }); const int64_t ceilMode = adaptor.getCeilMode(); @@ -216,7 +216,7 @@ mlir::FailureOr convertPoolOp( pads[0], pads[2] + ceilConstants[0], pads[1], pads[3] + ceilConstants[1]}; mlir::FailureOr resizedInput = tosaBuilder.resizeWindowBasedOps( - input, input.getType().cast().getShape(), + input, mlir::cast(input.getType()).getShape(), {kernelShapeVec[0], kernelShapeVec[1]}, reorderedPads, shapeHelper.strides, shapeHelper.dilations); @@ -257,4 +257,4 @@ mlir::FailureOr convertPoolOp( // Construct the old result shape out of the new one mlir::Value transpose = tosaBuilder.transpose(input, {0, 3, 1, 2}); return transpose; -}; +} diff --git a/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.cpp b/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.cpp index 0da45a365c..6c1501294f 100644 --- a/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.cpp +++ b/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.cpp @@ -99,7 +99,7 @@ Value buildRescale(PatternRewriter &rewriter, Operation *op, Value buildRescaleToInt32(PatternRewriter &rewriter, Operation *op, Value input_val, double input_scale, int64_t input_zp) { // Output is always int32 type - auto input_type = input_val.getType().dyn_cast(); + auto input_type = dyn_cast(input_val.getType()); assert(input_type); auto output_type = input_type.clone(rewriter.getI32Type()); diff --git a/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp b/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp index 8cb4f4db65..1bb43eebd4 100644 --- a/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp +++ b/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp @@ -63,11 +63,11 @@ mlir::Value expandShape(mlir::PatternRewriter &rewriter, mlir::Location loc, // op. This allows shape inference during the framework to TOSA lowering. template TosaOp CreateOpAndInfer(mlir::PatternRewriter &rewriter, mlir::Location loc, - mlir::Type result_ty, Args &&... args) { + mlir::Type result_ty, Args &&...args) { auto op = rewriter.create(loc, result_ty, args...); mlir::InferShapedTypeOpInterface shapeInterface = - llvm::dyn_cast(op.getOperation()); + mlir::dyn_cast(op.getOperation()); if (!shapeInterface) return op; @@ -92,7 +92,7 @@ TosaOp CreateOpAndInfer(mlir::PatternRewriter &rewriter, mlir::Location loc, template void CreateReplaceOpAndInfer(mlir::PatternRewriter &rewriter, - mlir::Operation *op, mlir::Type result_ty, Args &&... args) { + mlir::Operation *op, mlir::Type result_ty, Args &&...args) { auto result = CreateOpAndInfer(rewriter, op->getLoc(), result_ty, args...); rewriter.replaceOp(op, result->getResults()); diff --git a/src/Conversion/ONNXToTOSA/Tensor/Concat.cpp b/src/Conversion/ONNXToTOSA/Tensor/Concat.cpp index d8f557fa81..76fa0f9b8e 100644 --- a/src/Conversion/ONNXToTOSA/Tensor/Concat.cpp +++ b/src/Conversion/ONNXToTOSA/Tensor/Concat.cpp @@ -46,7 +46,7 @@ class ONNXConcatLoweringToTOSA : public OpConversionPattern { Type newConcatOutputType = RankedTensorType::get( llvm::SmallVector(inputRank, ShapedType::kDynamic), - resultType.cast().getElementType()); + cast(resultType).getElementType()); tosa::CreateReplaceOpAndInfer( rewriter, op, newConcatOutputType, inputs, axis); diff --git a/src/Conversion/ONNXToTOSA/Tensor/Expand.cpp b/src/Conversion/ONNXToTOSA/Tensor/Expand.cpp index 133e0c4c87..941fa76d62 100644 --- a/src/Conversion/ONNXToTOSA/Tensor/Expand.cpp +++ b/src/Conversion/ONNXToTOSA/Tensor/Expand.cpp @@ -60,9 +60,9 @@ class ONNXExpandLoweringToTOSA : public OpConversionPattern { castArrayRef(shapeWideNums.get()); auto inputType = - llvm::dyn_cast_or_null(adaptor.getInput().getType()); + mlir::dyn_cast_or_null(adaptor.getInput().getType()); auto outputType = - llvm::dyn_cast_or_null(op.getResult().getType()); + mlir::dyn_cast_or_null(op.getResult().getType()); if (!inputType || !outputType || !inputType.hasStaticShape() || !outputType.hasStaticShape()) { return rewriter.notifyMatchFailure( diff --git a/src/Conversion/ONNXToTOSA/Tensor/Flatten.cpp b/src/Conversion/ONNXToTOSA/Tensor/Flatten.cpp index 315fb64011..e6207dee61 100644 --- a/src/Conversion/ONNXToTOSA/Tensor/Flatten.cpp +++ b/src/Conversion/ONNXToTOSA/Tensor/Flatten.cpp @@ -44,7 +44,7 @@ class ONNXFlattenLoweringToTOSA : public OpConversionPattern { TosaBuilder tosaBuilder(rewriter, loc); int64_t axis = adaptor.getAxis(); - auto inputType = input.getType().cast(); + auto inputType = cast(input.getType()); // onnx allows values beetween [-r, r] where r is the rank. if (axis == inputType.getRank()) { diff --git a/src/Conversion/ONNXToTOSA/Tensor/Gather.cpp b/src/Conversion/ONNXToTOSA/Tensor/Gather.cpp index dbb3ea20e0..c34525872a 100644 --- a/src/Conversion/ONNXToTOSA/Tensor/Gather.cpp +++ b/src/Conversion/ONNXToTOSA/Tensor/Gather.cpp @@ -57,7 +57,7 @@ class ONNXGatherLoweringToTOSA : public OpConversionPattern { // onnx allows values beetween [-r, r-1] where r is the rank axis = tosa::convertNegativeAxis(axis, inputRank); - auto indicesType = indices.getType().cast(); + auto indicesType = cast(indices.getType()); APInt indicesVal; if (indicesType.getRank() == 0 && @@ -76,7 +76,7 @@ class ONNXGatherLoweringToTOSA : public OpConversionPattern { SmallVector newIndicesValues; newIndicesValues.resize(indicesType.getNumElements()); - ArrayRef inputShape = inputType.cast().getShape(); + ArrayRef inputShape = cast(inputType).getShape(); // ONNX allows negative indices and TOSA doesn't. // We will emit ops to compute diff --git a/src/Conversion/ONNXToTOSA/Tensor/PaddingOp.cpp b/src/Conversion/ONNXToTOSA/Tensor/PaddingOp.cpp index 6caf6fd027..c9fc9b6ea1 100644 --- a/src/Conversion/ONNXToTOSA/Tensor/PaddingOp.cpp +++ b/src/Conversion/ONNXToTOSA/Tensor/PaddingOp.cpp @@ -79,11 +79,11 @@ class ONNXPadOpLoweringToTOSA : public OpConversionPattern { getTypeConverter()->convertType(op.getResult().getType()); float valueFloat = 0.0F; - if (!constValue.getType().dyn_cast()) { + if (!isa(constValue.getType())) { auto valueAttr = tosa::getValueFromTosaConst(constValue); auto valueIt = valueAttr.getValues().begin(); // Need float for F32 Type - float valueFloat = (*valueIt).cast().getValueAsDouble(); + float valueFloat = cast(*valueIt).getValueAsDouble(); TosaBuilder tosaBuilder(rewriter, loc); Value constTosaTensor = diff --git a/src/Conversion/ONNXToTOSA/Tensor/Resize.cpp b/src/Conversion/ONNXToTOSA/Tensor/Resize.cpp index b2b6f17c10..1c5936bab2 100644 --- a/src/Conversion/ONNXToTOSA/Tensor/Resize.cpp +++ b/src/Conversion/ONNXToTOSA/Tensor/Resize.cpp @@ -70,7 +70,7 @@ ScaleHelper normalize(int64_t output, int64_t input, bool pytorchHalfPixel, // We can compute this directly based on previous values. border = denominator * (output - 1) - numerator * (input - 1) + offset; return ScaleHelper(numerator, denominator, offset, border); -}; +} void valuesFromAxis(ArrayAttr *axis, llvm::SmallVectorImpl &axisVec) { auto axisRange = axis->getAsRange(); @@ -172,7 +172,7 @@ class ONNXResizeOpLoweringToTOSA : public ConversionPattern { LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { - auto resizeOp = llvm::cast(op); + auto resizeOp = mlir::cast(op); Location loc = op->getLoc(); OpAdaptor adaptor(operands, op->getAttrDictionary()); diff --git a/src/Conversion/ONNXToTOSA/Tensor/Transpose.cpp b/src/Conversion/ONNXToTOSA/Tensor/Transpose.cpp index fb5213de49..514b0e5cda 100644 --- a/src/Conversion/ONNXToTOSA/Tensor/Transpose.cpp +++ b/src/Conversion/ONNXToTOSA/Tensor/Transpose.cpp @@ -48,7 +48,7 @@ class ONNXTransposeLoweringToTOSA Value input = adaptor.getData(); - auto inputType = input.getType().dyn_cast(); + auto inputType = dyn_cast(input.getType()); if (!inputType) return rewriter.notifyMatchFailure(op, "input not a ranked tensor"); @@ -61,7 +61,7 @@ class ONNXTransposeLoweringToTOSA op, "input element type not supported"); } - auto outputType = op.getResult().getType().dyn_cast(); + auto outputType = dyn_cast(op.getResult().getType()); if (!outputType) return rewriter.notifyMatchFailure(op, "output not a ranked tensor");