Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into qfusionupdate2
Browse files Browse the repository at this point in the history
  • Loading branch information
zjgarvey committed May 10, 2024
2 parents 673ffcf + 00efec0 commit 03719af
Show file tree
Hide file tree
Showing 28 changed files with 937 additions and 281 deletions.
12 changes: 9 additions & 3 deletions docs/development.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,6 @@ While this is running, you can already setup the Python venv and dependencies in

## Setup your Python VirtualEnvironment and Dependencies

Also, ensure that you have the appropriate `python-dev` package installed
to access the Python development libraries / headers.

```shell
python -m venv mlir_venv
source mlir_venv/bin/activate
Expand All @@ -26,6 +23,15 @@ python -m pip install -r requirements.txt
python -m pip install -r torchvision-requirements.txt
```

Also, ensure that you have the appropriate `python-dev` package installed
to access the Python development libraries / headers. For example, on
Ubuntu/Debian you can install it with the following `apt` command.

```shell
sudo apt install python3-dev
```


## (Optional) Set up pre-commit

This project uses [pre-commit](https://pre-commit.com/) in its CI. You can
Expand Down
65 changes: 38 additions & 27 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1463,56 +1463,67 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
});
patterns.onOp(
"CumSum", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Location loc = binder.getLoc();
Torch::ValueTensorType resultType;
Value operand;
Value axisTensor;
Value operand, axisTensor;
int64_t exclusive, reverse;
if (binder.tensorOperands(operand, axisTensor) ||
binder.s64IntegerAttr(exclusive, "exclusive", 0) ||
binder.s64IntegerAttr(reverse, "reverse", 0) ||
binder.tensorResultType(resultType))
return failure();

int64_t exclusive;
int64_t reverse;
// if bind succeeds and either is set, fail because not implemented
if (!binder.s64IntegerAttr(exclusive, "exclusive", 0))
if (exclusive != 0)
return rewriter.notifyMatchFailure(
binder.op, "unsupported onnx.CumSum conversion: exclusive");
if (!binder.s64IntegerAttr(reverse, "reverse", 0))
if (reverse != 0)
return rewriter.notifyMatchFailure(
binder.op, "unsupported onnx.CumSum conversion: reverse");
Torch::BaseTensorType resultTensorType =
cast<Torch::BaseTensorType>(resultType);
if (!resultTensorType.hasDtype()) {
return rewriter.notifyMatchFailure(
binder.op, "expected result type to have a dtype");
}

// deal with neg axis: if (axis < 0) axis += rank
int64_t rank =
cast<Torch::ValueTensorType>(operand.getType()).getSizes().size();
Value rankVal = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), rank));
Value zero = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Value cstZero = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(0));
Value cstOne = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(1));

Value axisScalar = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(), axisTensor);
Value isNegative = rewriter.create<Torch::AtenLtIntOp>(
binder.getLoc(), axisScalar, zero);
binder.getLoc(), axisScalar, cstZero);
isNegative =
rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(), isNegative);
Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
binder.getLoc(), isNegative, rankVal);
Value dim = rewriter.create<Torch::AtenAddIntOp>(
Value axis = rewriter.create<Torch::AtenAddIntOp>(
binder.getLoc(), axisScalar, finalOffset);
Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());

Torch::BaseTensorType resultTensorType =
cast<Torch::BaseTensorType>(resultType);
if (!resultTensorType.hasDtype()) {
return rewriter.notifyMatchFailure(
binder.op, "expected result type to have a dtype");
Value res;
if (reverse) {
Value dims = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
rewriter.getType<Torch::ListType>(
rewriter.getType<Torch::IntType>()),
SmallVector<Value>{axis});
Value flip = rewriter.create<Torch::AtenFlipOp>(
binder.getLoc(), resultType, operand, dims);
Value cumsum = rewriter.create<Torch::AtenCumsumOp>(
binder.getLoc(), resultType, flip, axis, none);
res = rewriter.create<Torch::AtenFlipOp>(binder.getLoc(), resultType,
cumsum, dims);
} else {
res = rewriter.create<Torch::AtenCumsumOp>(
binder.getLoc(), resultType, operand, axis, none);
}
// resultTensorType.print(llvm::outs());
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
rewriter.replaceOpWithNewOp<Torch::AtenCumsumOp>(binder.op, resultType,
operand, dim, none);

if (exclusive)
res = rewriter.create<Torch::AtenSubTensorOp>(
binder.getLoc(), resultType, res, operand, cstOne);
rewriter.replaceOp(binder.op, res);
return success();
});
patterns.onOp(
Expand Down
23 changes: 16 additions & 7 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,20 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
if (gridShape[3] != 2)
return rewriter.notifyMatchFailure(binder.op,
"gridShape[3] expected to be 2");
std::string mode;
if (binder.customOpNameStringAttr(mode, "mode", "linear"))
std::string iModeString;
int64_t iModeInt;
if (binder.customOpNameStringAttr(iModeString, "mode", "linear"))
return rewriter.notifyMatchFailure(binder.op, "mode bind failure");
if (mode != "linear" && mode != "bilinear")

if (iModeString == "linear" || iModeString == "bilinear") {
iModeInt = 0;
} else if (iModeString == "nearest") {
iModeInt = 1;
} else {
return rewriter.notifyMatchFailure(
binder.op, "currently only mode : linear supported");
binder.op, "currently only mode : linear and nearest supported");
}

std::string padding;
if (binder.customOpNameStringAttr(padding, "padding_mode", "zeros"))
return rewriter.notifyMatchFailure(binder.op,
Expand All @@ -143,7 +151,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(

Value interpolationMode = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
rewriter.getIntegerAttr(rewriter.getIntegerType(64), iModeInt));

Value paddingMode = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
Expand Down Expand Up @@ -651,7 +660,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
result = rewriter.create<Torch::AtenMaximumOp>(
binder.getLoc(), resultType, result, operands[i]);
}
rewriter.replaceOp(binder.op, result.getDefiningOp());
rewriter.replaceOp(binder.op, result);
return success();
});
patterns.onOp(
Expand All @@ -667,7 +676,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
result = rewriter.create<Torch::AtenMinimumOp>(
binder.getLoc(), resultType, result, operands[i]);
}
rewriter.replaceOp(binder.op, result.getDefiningOp());
rewriter.replaceOp(binder.op, result);
return success();
});
patterns.onOp("Neg", 1,
Expand Down
Loading

0 comments on commit 03719af

Please sign in to comment.