Skip to content

Commit

Permalink
Replace deprecated use of .cast/dyn_cast/isa with the mlir freestandi…
Browse files Browse the repository at this point in the history
…ng functions and misc cleanups.

Also remove some unecessary semicolons and standardize to use mlir instead of llvm casts
  • Loading branch information
jorickert committed Nov 25, 2024
1 parent 3ba5299 commit 02d9b38
Show file tree
Hide file tree
Showing 19 changed files with 62 additions and 65 deletions.
9 changes: 5 additions & 4 deletions src/Conversion/ONNXToTOSA/DialectBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,8 @@ std::optional<Value> TosaBuilder::gather(Value resultValue, Value inputValue,
Value indicesValue, int32_t batchDims, int32_t axis) {
return tosa::convertGatherOp(rewriter(), loc(), resultValue, inputValue,
indicesValue, batchDims, axis);
};
}

Value TosaBuilder::reshape(mlir::Value value, llvm::ArrayRef<int64_t> shape) {
auto shapeAttr = rewriter().getDenseI64ArrayAttr(shape);
auto valueType = mlir::cast<ShapedType>(value.getType());
Expand Down Expand Up @@ -246,7 +247,7 @@ template Value TosaBuilder::binaryOp<mlir::tosa::PowOp>(

template <typename T>
Value TosaBuilder::unaryOp(mlir::Value &input) {
auto inputType = input.getType().cast<ShapedType>();
auto inputType = cast<ShapedType>(input.getType());
Type newValueType = RankedTensorType::get(
llvm::SmallVector<int64_t, 4>(inputType.getRank(), ShapedType::kDynamic),
inputType.getElementType());
Expand Down Expand Up @@ -304,7 +305,7 @@ Value TosaBuilder::select(
lhs = valueVec[1];
rhs = valueVec[2];
}
auto lhsType = lhs.getType().cast<ShapedType>();
auto lhsType = cast<ShapedType>(lhs.getType());
Type newValueType = RankedTensorType::get(
llvm::SmallVector<int64_t, 4>(lhsType.getRank(), ShapedType::kDynamic),
lhsType.getElementType());
Expand All @@ -326,7 +327,7 @@ mlir::Value TosaBuilder::castToNewTensorElementType(
}

Value TosaBuilder::sqrt(mlir::Value &input) {
auto inputType = input.getType().cast<ShapedType>();
auto inputType = cast<ShapedType>(input.getType());
auto oneHalf = this->getSplattedConst(
0.5, inputType.getElementType(), inputType.getShape());
return this->binaryOp<mlir::tosa::PowOp>(input, oneHalf);
Expand Down
2 changes: 1 addition & 1 deletion src/Conversion/ONNXToTOSA/Math/Conv2D.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ class ONNXConvOpLoweringToTOSA : public ConversionPattern {
ConversionPatternRewriter &rewriter) const final {
OpAdaptor adaptor(operands, op->getAttrDictionary());
auto loc = op->getLoc();
auto convOp = llvm::cast<ONNXConvOp>(op);
auto convOp = mlir::cast<ONNXConvOp>(op);

TosaBuilder tosaBuilder(rewriter, loc);

Expand Down
32 changes: 15 additions & 17 deletions src/Conversion/ONNXToTOSA/Math/Elementwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,12 @@ LogicalResult checkBasicTosaRequirementsForBinaryOps(
ConversionPatternRewriter &rewriter, Operation *op, OpAdaptorT adaptor,
Type resultType) {
Value lhs = adaptor.getOperands()[0];
auto lhsType = lhs.getType().dyn_cast<TensorType>();
auto lhsType = dyn_cast<TensorType>(lhs.getType());

Value rhs = adaptor.getOperands()[1];
auto rhsType = rhs.getType().dyn_cast<TensorType>();
auto rhsType = dyn_cast<TensorType>(rhs.getType());

auto resultTensorType = resultType.dyn_cast<TensorType>();
auto resultTensorType = dyn_cast<TensorType>(resultType);
if (!lhsType || !rhsType || !resultTensorType) {
return rewriter.notifyMatchFailure(op, "Tosa only supports TensorTypes");
}
Expand Down Expand Up @@ -121,9 +121,9 @@ class ONNXElementwiseUnaryOpLoweringToTOSA
ConversionPatternRewriter &rewriter) const override {

Value input = *adaptor.getODSOperands(0).begin();
auto inputType = input.getType().dyn_cast<TensorType>();
auto inputType = dyn_cast<TensorType>(input.getType());
Value output = op.getResult();
auto outputType = output.getType().dyn_cast<TensorType>();
auto outputType = dyn_cast<TensorType>(output.getType());

if (!inputType || !outputType) {
return rewriter.notifyMatchFailure(op, "Tosa only supports TensorTypes");
Expand Down Expand Up @@ -248,7 +248,7 @@ class ONNXLeakyReluOpLoweringToTOSA
LogicalResult matchAndRewrite(ONNXLeakyReluOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

auto outputType = op.getResult().getType().cast<TensorType>();
auto outputType = cast<TensorType>(op.getResult().getType());
if (failed(IsIntOrFloat::checkType(
rewriter, outputType.getElementType(), op))) {
return failure();
Expand Down Expand Up @@ -279,15 +279,13 @@ class ONNXComparisonOpLoweringToTOSA : public OpConversionPattern<OnnxCompOp> {
ConversionPatternRewriter &rewriter) const override {

Value input1 = adaptor.getA();
auto input1ElemType =
input1.getType().template cast<TensorType>().getElementType();
auto input1ElemType = cast<TensorType>(input1.getType()).getElementType();
if (failed(IsIntOrFloat::checkType(rewriter, input1ElemType, op))) {
return failure();
}

Value input2 = adaptor.getB();
auto input2ElemType =
input2.getType().template cast<TensorType>().getElementType();
auto input2ElemType = cast<TensorType>(input2.getType()).getElementType();
if (input1ElemType != input2ElemType) {
return failure();
}
Expand Down Expand Up @@ -432,7 +430,7 @@ class ONNXSqrtOpLoweringToTOSA : public OpConversionPattern<ONNXSqrtOp> {
LogicalResult matchAndRewrite(ONNXSqrtOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

auto resultTensorType = op.getResult().getType().cast<TensorType>();
auto resultTensorType = cast<TensorType>(op.getResult().getType());
if (failed(IsFloat::checkType(
rewriter, resultTensorType.getElementType(), op))) {
return failure();
Expand All @@ -454,7 +452,7 @@ class ONNXEluOpLoweringToTOSA : public OpConversionPattern<ONNXEluOp> {
// ELU(x) = x if x >= 0
// alpha * (exp(x) - 1.) if x < 0

auto resultTensorType = op.getResult().getType().cast<TensorType>();
auto resultTensorType = cast<TensorType>(op.getResult().getType());
if (failed(IsFloat::checkType(
rewriter, resultTensorType.getElementType(), op))) {
return failure();
Expand Down Expand Up @@ -496,7 +494,7 @@ class ONNXHardSigmoidOpLoweringToTOSA
// - tosa.mul(clamp, alpha)
Value input = adaptor.getX();

auto resultType = op.getResult().getType().template cast<TensorType>();
auto resultType = cast<TensorType>(op.getResult().getType());
auto resultElementType = resultType.getElementType();

TosaBuilder tosaBuilder(rewriter, op->getLoc());
Expand Down Expand Up @@ -536,7 +534,7 @@ class ONNXPReluOpLoweringToTOSA : public OpConversionPattern<ONNXPReluOp> {
LogicalResult matchAndRewrite(ONNXPReluOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

auto outputType = op.getResult().getType().cast<TensorType>();
auto outputType = cast<TensorType>(op.getResult().getType());
if (failed(IsIntOrFloat::checkType(
rewriter, outputType.getElementType(), op))) {
return failure();
Expand All @@ -554,7 +552,7 @@ class ONNXSoftplusOpLoweringToTOSA
LogicalResult matchAndRewrite(ONNXSoftplusOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

auto outputType = op.getResult().getType().cast<TensorType>();
auto outputType = cast<TensorType>(op.getResult().getType());
if (failed(IsFloat::checkType(rewriter, outputType.getElementType(), op))) {
return failure();
}
Expand All @@ -579,7 +577,7 @@ class ONNXSeluOpLoweringToTOSA : public OpConversionPattern<ONNXSeluOp> {
LogicalResult matchAndRewrite(ONNXSeluOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

auto outputType = op.getResult().getType().cast<TensorType>();
auto outputType = cast<TensorType>(op.getResult().getType());
if (failed(IsFloat::checkType(rewriter, outputType.getElementType(), op))) {
return failure();
}
Expand Down Expand Up @@ -618,7 +616,7 @@ class ONNXThresholdedReluOpLoweringToTOSA
LogicalResult matchAndRewrite(ONNXThresholdedReluOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

auto outputType = op.getResult().getType().cast<TensorType>();
auto outputType = cast<TensorType>(op.getResult().getType());
if (failed(IsIntOrFloat::checkType(
rewriter, outputType.getElementType(), op))) {
return failure();
Expand Down
2 changes: 1 addition & 1 deletion src/Conversion/ONNXToTOSA/Math/MatMul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ Value reshapeUpTo3DTensor(Value tensor, TosaBuilder &builder) {
}

return builder.reshape(tensor, newShape);
};
}

// Obtaining the rank broadcasted shapes of tensors makes it easier to
// construct the input and output reshaping logic.
Expand Down
12 changes: 5 additions & 7 deletions src/Conversion/ONNXToTOSA/Math/Reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ DenseIntElementsAttr getAxesLatestsVersionAttr(ONNXReduceOp op) {
if (noOpIfAxesEmpty == 0) {
// Default behaviour when "axes" is none and "noop_with_empty_axes" is
// set to false, it is to reduce all dims
const int64_t numberOfAxes = input.getType().cast<ShapedType>().getRank();
const int64_t numberOfAxes = cast<ShapedType>(input.getType()).getRank();
auto iotaRange =
llvm::iota_range<int64_t>(0, numberOfAxes, /*Inclusive=*/false);
targetAxes = SmallVector<int64_t>(iotaRange.begin(), iotaRange.end());
Expand Down Expand Up @@ -86,7 +86,7 @@ DenseIntElementsAttr getAxesLegacyVersionAttr(ONNXReduceOp op) {
SmallVector<int64_t> targetAxes;
if (!axes) {
// if not present all axes are reduced
const int64_t numberOfAxes = input.getType().cast<ShapedType>().getRank();
const int64_t numberOfAxes = cast<ShapedType>(input.getType()).getRank();
auto iotaRange =
llvm::iota_range<int64_t>(0, numberOfAxes, /*Inclusive=*/false);
targetAxes = SmallVector<int64_t>(iotaRange.begin(), iotaRange.end());
Expand All @@ -111,14 +111,12 @@ class ONNXReduceOpLoweringToTOSA : public OpConversionPattern<ONNXReduceOp> {
LogicalResult matchAndRewrite(ONNXReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

auto inputType =
adaptor.getData().getType().template dyn_cast<RankedTensorType>();
auto inputType = dyn_cast<RankedTensorType>(adaptor.getData().getType());
if (!inputType)
return rewriter.notifyMatchFailure(op, "input type not a ranked tensor.");

auto outputType = this->getTypeConverter()
->convertType(op.getResult().getType())
.template cast<RankedTensorType>();
auto outputType = cast<RankedTensorType>(
this->getTypeConverter()->convertType(op.getResult().getType()));

return (*lowerFn)(op, inputType, outputType, rewriter);
}
Expand Down
4 changes: 2 additions & 2 deletions src/Conversion/ONNXToTOSA/NN/AveragePool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,14 @@ LogicalResult handleIncludePadAttr(
return rewriter.notifyMatchFailure(op, "Could not infer shapes");
}

auto inputType = input.getType().cast<mlir::TensorType>();
auto inputType = cast<mlir::TensorType>(input.getType());
if (inputType.getShape().size() != 4) {
return rewriter.notifyMatchFailure(op, "TOSA only supports 2d pooling");
}

llvm::SmallVector<int64_t, 4> pads =
tosa::createOrderedPadAttrForWindowBasedOps(rewriter,
input.getType().cast<mlir::TensorType>().getShape(), shapeHelper,
cast<mlir::TensorType>(input.getType()).getShape(), shapeHelper,
/*ceilMode*/ 0, {0, 1, 2, 3});

// Create Padding and ConstPad tosa::ConstOp's
Expand Down
2 changes: 1 addition & 1 deletion src/Conversion/ONNXToTOSA/NN/MaxPoolSingleOut.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class ONNXMaxPoolSingleOutOpLoweringToTOSA : public ConversionPattern {
using OpAdaptor = typename ONNXMaxPoolSingleOutOp::Adaptor;
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto maxpoolOp = llvm::cast<ONNXMaxPoolSingleOutOp>(op);
auto maxpoolOp = mlir::cast<ONNXMaxPoolSingleOutOp>(op);
OpAdaptor adaptor(operands, op->getAttrDictionary());

Value input = adaptor.getX();
Expand Down
22 changes: 11 additions & 11 deletions src/Conversion/ONNXToTOSA/ONNXToTOSACommon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ std::optional<Value> convertGatherOp(PatternRewriter &rewriter, Location loc,

TosaBuilder tosaBuilder(rewriter, loc);

auto resultType = resultValue.getType().dyn_cast<ShapedType>();
auto inputType = inputValue.getType().dyn_cast<RankedTensorType>();
auto indicesType = indicesValue.getType().dyn_cast<RankedTensorType>();
auto resultType = dyn_cast<ShapedType>(resultValue.getType());
auto inputType = dyn_cast<RankedTensorType>(inputValue.getType());
auto indicesType = dyn_cast<RankedTensorType>(indicesValue.getType());

if (!resultType || !inputType || !indicesType)
return std::nullopt;
Expand Down Expand Up @@ -143,7 +143,7 @@ std::optional<Value> convertGatherOp(PatternRewriter &rewriter, Location loc,
// onnx allows i64 indices, but tosa does not.
if (indicesType.getElementType().isInteger(64)) {
indicesType =
indicesType.clone(rewriter.getI32Type()).dyn_cast<RankedTensorType>();
dyn_cast<RankedTensorType>(indicesType.clone(rewriter.getI32Type()));
indicesValue = CreateOpAndInfer<mlir::tosa::CastOp>(
rewriter, loc, indicesType, indicesValue)
.getResult();
Expand Down Expand Up @@ -293,7 +293,7 @@ std::optional<Value> convertReduceOpCommon(PatternRewriter &rewriter,
bool is_quantized, double input_scale, int64_t input_zp,
double output_scale, int64_t output_zp) {
RankedTensorType input_type =
input_value.getType().dyn_cast<RankedTensorType>();
dyn_cast<RankedTensorType>(input_value.getType());
if (!input_type)
return std::nullopt;

Expand Down Expand Up @@ -362,14 +362,14 @@ std::optional<Value> convertReduceMeanOp(PatternRewriter &rewriter,
// op2 = mul(op1, 1.0 / num_elements_on_reduced_axis)

RankedTensorType input_type =
input_value.getType().dyn_cast<RankedTensorType>();
dyn_cast<RankedTensorType>(input_value.getType());
if (!input_type)
return std::nullopt;

bool input_is_qtype =
input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
isa<mlir::quant::UniformQuantizedType>(input_type.getElementType());
bool output_is_qtype =
output_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
isa<mlir::quant::UniformQuantizedType>(output_type.getElementType());

if (input_is_qtype != output_is_qtype) {
op->emitOpError("ConvertReduceSumOp: input/output tensor should "
Expand All @@ -378,7 +378,7 @@ std::optional<Value> convertReduceMeanOp(PatternRewriter &rewriter,
}

// Only supports float type mean() if it's non-quantized
if (!input_is_qtype && !output_type.getElementType().isa<mlir::FloatType>()) {
if (!input_is_qtype && !isa<mlir::FloatType>(output_type.getElementType())) {
op->emitWarning(
"Failed convertReduceMean: input unquantized type but output element "
"not FloatType!");
Expand All @@ -403,9 +403,9 @@ std::optional<Value> convertReduceMeanOp(PatternRewriter &rewriter,

if (input_is_qtype) {
auto input_qtype =
input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
cast<mlir::quant::UniformQuantizedType>(input_type.getElementType());
auto output_qtype =
output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
cast<mlir::quant::UniformQuantizedType>(output_type.getElementType());

// Combine 'div_scale' as part of output rescale
output_scale = div_scale * input_qtype.getScale() / output_qtype.getScale();
Expand Down
4 changes: 2 additions & 2 deletions src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,12 @@ namespace onnx_mlir {
//===----------------------------------------------------------------------===//

inline bool isTOSABool(mlir::Type type) {
mlir::IntegerType intType = type.dyn_cast<mlir::IntegerType>();
mlir::IntegerType intType = mlir::dyn_cast<mlir::IntegerType>(type);
return intType && intType.isSignless() && intType.getWidth() == 1;
}

inline bool isTOSAInt(mlir::Type type) {
mlir::IntegerType intType = type.dyn_cast<mlir::IntegerType>();
mlir::IntegerType intType = mlir::dyn_cast<mlir::IntegerType>(type);
std::set<unsigned> intWidth{1, 8, 16, 32, 48, 64};
return intType && (intType.isSignless() || intType.isUnsignedInteger()) &&
(intWidth.find(intType.getWidth()) != intWidth.end());
Expand Down
10 changes: 5 additions & 5 deletions src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,13 @@ llvm::SmallVector<int64_t, 4> createOrderedPadAttrForWindowBasedOps(

inline mlir::LogicalResult getAvgPool2dAccType(mlir::PatternRewriter &rewriter,
mlir::Value input, mlir::TypeAttr &accType) {
auto inputTy = llvm::dyn_cast<mlir::ShapedType>(input.getType());
auto inputTy = mlir::dyn_cast<mlir::ShapedType>(input.getType());
if (!inputTy)
return mlir::failure();
auto inputETy = inputTy.getElementType();

if (auto quantType =
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy))
mlir::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy))
inputETy = quantType.getStorageType();

// Tosa supports FP16 and FP32 accumulator type for FP16 input. When the time
Expand Down Expand Up @@ -180,7 +180,7 @@ mlir::FailureOr<mlir::Value> convertPoolOp(
llvm::SmallVector<int64_t, 4> kernelShapeVec;
llvm::transform(kernelShape, std::back_inserter(kernelShapeVec),
[](const mlir::Attribute &pad) {
return pad.cast<mlir::IntegerAttr>().getInt();
return mlir::cast<mlir::IntegerAttr>(pad).getInt();
});

const int64_t ceilMode = adaptor.getCeilMode();
Expand Down Expand Up @@ -216,7 +216,7 @@ mlir::FailureOr<mlir::Value> convertPoolOp(
pads[0], pads[2] + ceilConstants[0], pads[1], pads[3] + ceilConstants[1]};

mlir::FailureOr<mlir::Value> resizedInput = tosaBuilder.resizeWindowBasedOps(
input, input.getType().cast<mlir::RankedTensorType>().getShape(),
input, mlir::cast<mlir::RankedTensorType>(input.getType()).getShape(),
{kernelShapeVec[0], kernelShapeVec[1]}, reorderedPads,
shapeHelper.strides, shapeHelper.dilations);

Expand Down Expand Up @@ -257,4 +257,4 @@ mlir::FailureOr<mlir::Value> convertPoolOp(
// Construct the old result shape out of the new one
mlir::Value transpose = tosaBuilder.transpose(input, {0, 3, 1, 2});
return transpose;
};
}
2 changes: 1 addition & 1 deletion src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ Value buildRescale(PatternRewriter &rewriter, Operation *op,
Value buildRescaleToInt32(PatternRewriter &rewriter, Operation *op,
Value input_val, double input_scale, int64_t input_zp) {
// Output is always int32 type
auto input_type = input_val.getType().dyn_cast<mlir::ShapedType>();
auto input_type = dyn_cast<mlir::ShapedType>(input_val.getType());
assert(input_type);
auto output_type = input_type.clone(rewriter.getI32Type());

Expand Down
2 changes: 1 addition & 1 deletion src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ TosaOp CreateOpAndInfer(mlir::PatternRewriter &rewriter, mlir::Location loc,
auto op = rewriter.create<TosaOp>(loc, result_ty, args...);

mlir::InferShapedTypeOpInterface shapeInterface =
llvm::dyn_cast<mlir::InferShapedTypeOpInterface>(op.getOperation());
mlir::dyn_cast<mlir::InferShapedTypeOpInterface>(op.getOperation());
if (!shapeInterface)
return op;

Expand Down
2 changes: 1 addition & 1 deletion src/Conversion/ONNXToTOSA/Tensor/Concat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class ONNXConcatLoweringToTOSA : public OpConversionPattern<ONNXConcatOp> {

Type newConcatOutputType = RankedTensorType::get(
llvm::SmallVector<int64_t, 4>(inputRank, ShapedType::kDynamic),
resultType.cast<ShapedType>().getElementType());
cast<ShapedType>(resultType).getElementType());

tosa::CreateReplaceOpAndInfer<mlir::tosa::ConcatOp>(
rewriter, op, newConcatOutputType, inputs, axis);
Expand Down
4 changes: 2 additions & 2 deletions src/Conversion/ONNXToTOSA/Tensor/Expand.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ class ONNXExpandLoweringToTOSA : public OpConversionPattern<ONNXExpandOp> {
castArrayRef<int64_t, WideNum>(shapeWideNums.get());

auto inputType =
llvm::dyn_cast_or_null<RankedTensorType>(adaptor.getInput().getType());
mlir::dyn_cast_or_null<RankedTensorType>(adaptor.getInput().getType());
auto outputType =
llvm::dyn_cast_or_null<RankedTensorType>(op.getResult().getType());
mlir::dyn_cast_or_null<RankedTensorType>(op.getResult().getType());
if (!inputType || !outputType || !inputType.hasStaticShape() ||
!outputType.hasStaticShape()) {
return rewriter.notifyMatchFailure(
Expand Down
Loading

0 comments on commit 02d9b38

Please sign in to comment.