Generalize Operand Quantization in FuseQuantizeOps #3327

Merged 4 commits on May 13, 2024
181 changes: 99 additions & 82 deletions lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp
@@ -13,6 +13,7 @@
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include <stack>

using namespace mlir;
using namespace mlir::torch;
@@ -27,98 +28,112 @@ template <typename SrcOp> struct QuantInfo {
template <> struct QuantInfo<AtenReluOp> {
static constexpr unsigned operandsToQuantize[1] = {0};
};
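For readers outside this file: `QuantInfo` maps each source op to the operand indices the fusion patterns should try to quantize. The primary template is visible only as a context line in the hunk header above; the sketch below shows its presumed shape, an assumption inferred from the `AtenReluOp` specialization rather than from text in this diff.

```cpp
// Presumed primary template (assumption; only its first line appears as diff
// context above): binary ops such as matmul quantize operands 0 and 1, while
// unary ops like relu specialize QuantInfo to touch only operand 0.
template <typename SrcOp> struct QuantInfo {
  static constexpr unsigned operandsToQuantize[2] = {0, 1};
};
```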
template <typename SrcOp>
class QuantizeOperands : public OpRewritePattern<SrcOp> {
public:
using OpRewritePattern<SrcOp>::OpRewritePattern;

LogicalResult matchAndRewrite(SrcOp op,
PatternRewriter &rewriter) const override {
llvm::SmallVector<Value> operands(op->getOperands());

bool dequanted = false;
auto f = [&dequanted](Value operand) {
if (auto dequant = operand.getDefiningOp<AtenDequantizeTensorOp>()) {
operand = dequant.getOperand();
dequanted = true;
}
if (auto dequant = operand.getDefiningOp<AtenDequantizeSelfOp>()) {
operand = dequant.getOperand();
dequanted = true;
}
return operand;
};

for (unsigned i : QuantInfo<SrcOp>::operandsToQuantize) {
operands[i] = f(operands[i]);
}

if (!dequanted) {
return rewriter.notifyMatchFailure(op, "no dequantizations found");
}

rewriter.replaceOpWithNewOp<SrcOp>(op, op.getType(), operands);
return success();
}
};
// A QCommutingOp is an Op satisfying:
// 1. Has at most one tensor operand at index 0
// 2. Has a single output, which is a tensor
// 3. Satisfies the commutation relation:
// [MPTQT -> Dequant -> Op(float)] = [Op(int) -> MPTQT -> Dequant]
// where MPTQT = "Aten_MakePerTensorQuantizedTensorOp"
// and Dequant = "AtenDequantizeSelfOp" or "AtenDequantizeTensorOp"
bool isQCommutingOp(mlir::Operation *op) {
// if adding a new commuting op here, be sure to add a
// RemoveUnused pattern for that op to clean up afterwards
return llvm::isa<AtenTransposeIntOp, AtenReshapeOp, AtenSliceTensorOp>(op);
}
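A quick sanity check on commutation relation (3), as a self-contained numeric sketch rather than anything from the pass itself; the scale, zero point, and 2x2 shape are illustrative assumptions. Per-tensor dequantization computes `scale * (x - zeroPoint)` elementwise, so an op that merely moves elements, such as transpose, gives the same result whether it runs before or after the dequantize.

```cpp
// Standalone demo (not part of torch-mlir): dequant and a 2x2 transpose
// commute exactly, because dequantization is elementwise and transpose only
// permutes element positions.
#include <array>
#include <cassert>

int main() {
  const float scale = 0.5f;
  const int zeroPoint = 3;
  // A 2x2 "quantized tensor" in row-major order.
  std::array<int, 4> q = {4, 7, 1, 9};

  auto dequant = [&](int v) { return scale * (v - zeroPoint); };
  auto transposeF = [](std::array<float, 4> t) {
    return std::array<float, 4>{t[0], t[2], t[1], t[3]};
  };
  auto transposeI = [](std::array<int, 4> t) {
    return std::array<int, 4>{t[0], t[2], t[1], t[3]};
  };

  // [MPTQT -> Dequant -> Op(float)]: dequantize first, then transpose.
  std::array<float, 4> lhs = transposeF(
      {dequant(q[0]), dequant(q[1]), dequant(q[2]), dequant(q[3])});

  // [Op(int) -> MPTQT -> Dequant]: transpose the ints, then dequantize.
  std::array<int, 4> qt = transposeI(q);
  std::array<float, 4> rhs = {dequant(qt[0]), dequant(qt[1]), dequant(qt[2]),
                              dequant(qt[3])};

  assert(lhs == rhs); // both orderings agree exactly
  return 0;
}
```

Note that the per-tensor assumption is doing real work here: with per-channel parameters, a transpose would move the quantized axis, and the relation would need an axis remap.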

template <typename SrcOp>
class QuantizeTransposedOperands : public OpRewritePattern<SrcOp> {
// The following conversion takes patterns of the form [op0 -> MPTQT -> dequant
// -> Op1 -> Op2 -> ... -> Opk -> SrcOp] to [op0 -> Int(Op1) -> Int(Op2) -> ...
// -> Int(Opk) -> MPTQT -> SrcOp] for any sequence of Q-commuting ops
// {Op1, Op2, ..., Opk} with k <= depth.
// With depth = 0, this conversion simply fuses any immediately quantizable
// operands: [MPTQT -> Dequant -> SrcOp (float operands)] becomes
// [MPTQT -> SrcOp (int operands)].
template <typename SrcOp, unsigned depth>
class QuantizeOperandsPastCommutingOps : public OpRewritePattern<SrcOp> {
public:
using OpRewritePattern<SrcOp>::OpRewritePattern;

LogicalResult matchAndRewrite(SrcOp op,
PatternRewriter &rewriter) const override {

mlir::Location loc = op.getLoc();
llvm::SmallVector<Value> operands(op->getOperands());
unsigned numOperands = operands.size();
bool dequanted = false;
for (unsigned i = 0; i < numOperands; i++) {
if (auto trans = operands[i].getDefiningOp<AtenTransposeIntOp>()) {
auto transOperands = trans.getOperands();
Value dequantOperand;
if (auto dequant =
transOperands[0].getDefiningOp<AtenDequantizeSelfOp>()) {
dequantOperand = dequant.getOperand();
if (auto quant =
dequantOperand
.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
auto quantOperands = quant.getOperands();
auto qType = quantOperands[0]
.getType()
.cast<ValueTensorType>()
.getOptionalDtype();
auto torchQType =
cast<ValueTensorType>(quant.getType()).getOptionalDtype();
auto transQTy =
rewriter.getType<ValueTensorType>(trans.getResult()
.getType()
.cast<ValueTensorType>()
.getOptionalSizes(),
qType);
auto newQuantTy =
rewriter.getType<ValueTensorType>(trans.getResult()
.getType()
.cast<ValueTensorType>()
.getOptionalSizes(),
torchQType);
Value newTrans = rewriter.create<AtenTransposeIntOp>(
op.getLoc(), transQTy, quantOperands[0], transOperands[1],
transOperands[2]);
Value newQuant =
rewriter.create<Aten_MakePerTensorQuantizedTensorOp>(
op.getLoc(), newQuantTy, newTrans, quantOperands[1],
quantOperands[2]);
operands[i] = newQuant;
dequanted = true;
}

for (unsigned i : QuantInfo<SrcOp>::operandsToQuantize) {
Value operand = operands[i];
std::stack<mlir::Operation *> commutingOpStack;
Value dequantOpd, MPTQTOpd;
for (unsigned k = 0; k < depth + 1; k++) {
auto currOp = operand.getDefiningOp();
// Case 0 : currOp is a nullptr (e.g., operand is a block argument)
if (!currOp)
break;
// Case 1 : currOp is a q commuting op (continue loop)
if (isQCommutingOp(currOp)) {
commutingOpStack.push(currOp);
// continue the trace from currOp's tensor input on the next k-iteration
operand = currOp->getOperand(0);
continue;
}
// Case 2 : currOp is a dequant op (end loop)
if (llvm::isa<AtenDequantizeSelfOp, AtenDequantizeTensorOp>(currOp)) {
dequantOpd = currOp->getOperand(0);
auto MPTQTOp =
dequantOpd.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>();
MPTQTOpd = MPTQTOp.getOperand(0);
}
// either a dequant was found or chain broken, so break loop
break;
}

// move to next operand if this trace was unsuccessful
if (!MPTQTOpd)
continue;

// a successful trace occurred, so set dequanted to true
dequanted = true;

// rewrite stack
Value oldOpd = MPTQTOpd;
Type intDType =
cast<ValueTensorType>(MPTQTOpd.getType()).getOptionalDtype();
while (!commutingOpStack.empty()) {
// pop the top of the commuting op stack and replace its first operand
// with oldOpd
auto currOp = commutingOpStack.top();
commutingOpStack.pop();
llvm::SmallVector<Value> currOperands(currOp->getOperands());
currOperands[0] = oldOpd;
// get new result type
auto oldType = cast<ValueTensorType>(currOp->getResultTypes()[0]);
auto intType =
rewriter.getType<ValueTensorType>(oldType.getSizes(), intDType);
// rewrite currOp to have new operands and result type
// store this as oldOpd for next loop
oldOpd = rewriter
.create(loc, (currOp->getName()).getIdentifier(),
currOperands, intType, currOp->getAttrs())
->getResult(0);
}

// stack is empty, so oldOpd is now the corrected version of the
// SrcOp's original operand
// convert operand -> SrcOp to oldOpd -> newMPTQTOp -> SrcOp
auto MPTQTOperands = dequantOpd.getDefiningOp()->getOperands();
auto qTorchType =
cast<ValueTensorType>(dequantOpd.getType()).getOptionalDtype();
auto newMPTQTType = rewriter.getType<ValueTensorType>(
cast<ValueTensorType>(operands[i].getType()).getSizes(), qTorchType);
operands[i] = rewriter.create<Aten_MakePerTensorQuantizedTensorOp>(
loc, newMPTQTType, oldOpd, MPTQTOperands[1], MPTQTOperands[2]);
}

if (!dequanted) {
return rewriter.notifyMatchFailure(
op, "no dequantized transpose inputs found.");
return rewriter.notifyMatchFailure(op, "No dequantizations found.");
}

rewriter.replaceOpWithNewOp<SrcOp>(op, op.getType(), operands);
return success();
}
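One implementation detail worth flagging in the rewrite loop above: because the commuting ops are heterogeneous (transpose, reshape, slice), each one is rebuilt through MLIR's generic, name-based `create` overload instead of a concrete op class. A minimal sketch of that idiom, factored into a helper whose name is mine, not the PR's:

```cpp
// Hypothetical helper (not in the PR) isolating the generic-creation idiom
// from the stack-rewrite loop: rebuild an op of the same registered kind
// with substituted operands and a new (integer-dtype) result type.
static Value recreateWithNewOperandsAndType(PatternRewriter &rewriter,
                                            Location loc, Operation *currOp,
                                            ValueRange newOperands,
                                            Type newResultType) {
  return rewriter
      .create(loc, currOp->getName().getIdentifier(), newOperands,
              newResultType, currOp->getAttrs())
      ->getResult(0);
}
```

This is what lets a single pattern handle every entry in `isQCommutingOp` without a per-op case.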
@@ -356,11 +371,13 @@ class FuseQuantizedOpsPass : public FuseQuantizedOpsBase<FuseQuantizedOpsPass> {
RemoveUnused<AtenDequantizeTensorOp>,
RemoveUnused<AtenQuantizePerTensorOp>,
RemoveUnused<Aten_MakePerTensorQuantizedTensorOp>,
RemoveUnused<AtenTransposeIntOp>, QuantizeOperands<AtenConvolutionOp>,
QuantizeOperands<AtenMatmulOp>, QuantizeOperands<AtenReluOp>,
QuantizeTransposedOperands<AtenMatmulOp>,
QuantizeAccumulator<AtenMatmulOp>, QuantizeOperands<AtenMmOp>,
QuantizeTransposedOperands<AtenMmOp>, QuantizeAccumulator<AtenMmOp>,
RemoveUnused<AtenTransposeIntOp>, RemoveUnused<AtenSliceTensorOp>,
RemoveUnused<AtenReshapeOp>,
QuantizeOperandsPastCommutingOps<AtenConvolutionOp, 0>,
QuantizeOperandsPastCommutingOps<AtenReluOp, 0>,
QuantizeOperandsPastCommutingOps<AtenMatmulOp, 2>,
QuantizeOperandsPastCommutingOps<AtenMmOp, 1>,
QuantizeAccumulator<AtenMmOp>, QuantizeAccumulator<AtenMatmulOp>,
QuantizeResultLikeOperand<AtenReluOp>, QuantizeBias<AtenConvolutionOp>>(
context);

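A closing note on extensibility: the comment inside `isQCommutingOp` asks that any new commuting op arrive together with a matching `RemoveUnused` cleanup pattern, since fusion strands the old float-domain chain. A hypothetical sketch, assuming one wanted to also commute past `AtenPermuteOp` (which this PR does not handle):

```cpp
// Hypothetical extension (not in this PR): let the trace walk through
// aten.permute as well. Both edits must land together, or dead float-domain
// permutes would survive the fusion.
bool isQCommutingOp(mlir::Operation *op) {
  return llvm::isa<AtenTransposeIntOp, AtenReshapeOp, AtenSliceTensorOp,
                   AtenPermuteOp>(op);
}
// ...and in FuseQuantizedOpsPass, add the matching cleanup to the pattern
// list:
//   RemoveUnused<AtenPermuteOp>,
```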