Skip to content

Commit

Permalink
extractslice at lower bitwidth
Browse files Browse the repository at this point in the history
Signed-off-by: Ian Wood <[email protected]>
  • Loading branch information
IanWood1 committed Jun 3, 2024
1 parent 4fddac0 commit 1c11573
Showing 1 changed file with 101 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,104 @@ struct FoldSuccessiveTensorInsertSliceOps
}
};

//===----------------------------------------------------------------------===//
// ExtractAtSmallBitwidth
//===----------------------------------------------------------------------===//

/// Rewrites a `tensor.extract_slice` of a dequantize-like `linalg.generic`
/// so the slice is taken at the narrow (quantized) bitwidth instead of the
/// wide one, reducing the amount of data the slice moves:
///
/// ```
/// %wide  = linalg.generic i8 -> f32   // dequantize
/// %slice = tensor.extract_slice %wide : f32
/// ```
///
/// becomes
///
/// ```
/// %wide   = linalg.generic i8 -> f32  // dequantize (original, kept)
/// %narrow = linalg.generic f32 -> i8  // requantize full tensor
/// %slice  = tensor.extract_slice %narrow : i8
/// %result = linalg.generic i8 -> f32  // dequantize just the slice
/// ```
struct ExtractAtSmallBitwidth
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    // The slice source must come from a dequantize-like generic op.
    auto sourceGenericOp =
        sliceOp.getSource().getDefiningOp<linalg::GenericOp>();
    if (!sourceGenericOp || !isDequantizationLikeOp(sourceGenericOp)) {
      return rewriter.notifyMatchFailure(
          sliceOp, "expected source to be dequantize-like generic op");
    }

    // The generic's result carries the wide element type; its first DPS
    // input carries the original narrow (quantized) element type.
    auto wideType =
        dyn_cast<RankedTensorType>(sourceGenericOp.getResult(0).getType());
    auto narrowType = dyn_cast<RankedTensorType>(
        sourceGenericOp.getDpsInputs().front().getType());
    if (!wideType || !narrowType) {
      return rewriter.notifyMatchFailure(sliceOp,
                                         "expected ranked tensor types");
    }

    Type wideElemType = wideType.getElementType();
    Type narrowElemType = narrowType.getElementType();
    // The cast regions below build `arith.fptosi` / `arith.sitofp`, so the
    // pattern only applies to an int -> float widening conversion.
    if (!isa<FloatType>(wideElemType) ||
        !narrowElemType.isSignlessInteger() ||
        narrowElemType.getIntOrFloatBitWidth() >=
            wideElemType.getIntOrFloatBitWidth()) {
      return rewriter.notifyMatchFailure(
          sliceOp, "expected int to float conversion with widening bitwidth");
    }

    RankedTensorType sliceResultType = sliceOp.getResultType();
    // `tensor.empty` below is built from static shapes only; bail out on
    // dynamic dimensions rather than asserting.
    if (!narrowType.hasStaticShape() || !sliceResultType.hasStaticShape()) {
      return rewriter.notifyMatchFailure(sliceOp, "expected static shapes");
    }

    Location loc = sliceOp.getLoc();

    // Requantize the full dequantized tensor back to the narrow type so the
    // slice can be taken at the smaller bitwidth. This generic has exactly
    // one input and one output, so it needs exactly two (identity) indexing
    // maps -- reusing the source op's map array would mismatch whenever the
    // dequantize op has auxiliary inputs such as scales/zero-points.
    int64_t sourceRank = narrowType.getRank();
    SmallVector<AffineMap> narrowMaps(
        2, rewriter.getMultiDimIdentityMap(sourceRank));
    SmallVector<utils::IteratorType> narrowIterators(
        sourceRank, utils::IteratorType::parallel);
    Value narrowEmpty = rewriter.create<tensor::EmptyOp>(
        loc, narrowType.getShape(), narrowElemType);
    auto narrowGeneric = rewriter.create<linalg::GenericOp>(
        loc, narrowType, sourceGenericOp.getResult(0), narrowEmpty,
        narrowMaps, narrowIterators,
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          // float -> int truncation back to the storage type.
          auto castOp = nestedBuilder.create<arith::FPToSIOp>(
              nestedLoc, narrowElemType, args[0]);
          nestedBuilder.create<linalg::YieldOp>(nestedLoc,
                                                castOp.getResult());
        });

    // Slice at the narrow bitwidth, reusing the original offsets/sizes/
    // strides. The new generic is not dequantize-like, so this pattern
    // cannot re-fire on the new slice.
    auto narrowSliceType =
        RankedTensorType::get(sliceResultType.getShape(), narrowElemType);
    auto narrowSliceOp = rewriter.create<tensor::ExtractSliceOp>(
        loc, narrowSliceType, narrowGeneric.getResult(0),
        sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
        sliceOp.getMixedStrides());

    // Finally dequantize just the sliced data back to the wide type.
    int64_t sliceRank = sliceResultType.getRank();
    SmallVector<AffineMap> wideMaps(
        2, rewriter.getMultiDimIdentityMap(sliceRank));
    SmallVector<utils::IteratorType> wideIterators(
        sliceRank, utils::IteratorType::parallel);
    Value wideEmpty = rewriter.create<tensor::EmptyOp>(
        loc, sliceResultType.getShape(), sliceResultType.getElementType());
    auto wideGeneric = rewriter.create<linalg::GenericOp>(
        loc, sliceResultType, narrowSliceOp.getResult(), wideEmpty,
        wideMaps, wideIterators,
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          auto castOp = nestedBuilder.create<arith::SIToFPOp>(
              nestedLoc, wideElemType, args[0]);
          nestedBuilder.create<linalg::YieldOp>(nestedLoc,
                                                castOp.getResult());
        });

    rewriter.replaceOp(sliceOp, wideGeneric->getResults());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// GatherFusionPattern
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -209,9 +307,9 @@ struct FusionPreprocessingPass
FusionPreprocessingPass> {
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
patterns.add<FoldSuccessiveTensorInsertSliceOps,
GenericOpInterchangePattern, GatherFusionPattern>(
&getContext());
patterns
.add<FoldSuccessiveTensorInsertSliceOps, GenericOpInterchangePattern,
GatherFusionPattern, ExtractAtSmallBitwidth>(&getContext());

// Fold away `tensor.dim` operations that can be resolved in terms of its
// operand shapes.
Expand Down

0 comments on commit 1c11573

Please sign in to comment.