Cleanup splitting strategy and make it more reusable (#11712)
This has the side effect of no longer forcing loop unrolling on the first
part of the split.
nicolasvasilache authored Jan 5, 2023
1 parent 52643c3 commit 663b1a7
Showing 3 changed files with 110 additions and 85 deletions.
@@ -63,6 +63,40 @@ auto unpackRegisteredMatchCallback(ImplicitLocOpBuilder &b,
return std::tuple_cat(a);
}

/// Compute the (splitPoint, vectorSize) pair to break [0, upperBound) into
/// [0, splitPoint) and [splitPoint, upperBound) such that `splitPoint` is a
/// multiple of `minMultiple * vectorSize`.
/// `vectorSize` is the maximal power of 2, not larger than `maxVectorSize`,
/// for which such a `splitPoint` can be computed.
///
/// If such a positive multiple exists:
/// 1. if it is `upperBound`, then `upperBound` is an even multiple of
/// `minMultiple` * `vectorSize` and we can tile evenly without splitting.
/// In this case we return (0, vectorSize).
/// 2. otherwise, it is a split point at which we can split with vectorSize
/// to obtain the largest divisible tiling.
/// In this case we return (splitPoint, vectorSize).
/// Otherwise we return (0, 1) to signify no splitting and a vector size of 1.
// TODO: support the dynamic case, taking future stride and alignment into
// account and returning Values. The op then needs to become part of the
// transform dialect.
static std::pair<int64_t, int64_t> computeSplitPoint(int64_t upperBound,
int64_t minMultiple,
int64_t maxVectorSize) {
assert((maxVectorSize & (maxVectorSize - 1)) == 0 && "must be a power of 2");
if (ShapedType::isDynamic(upperBound)) return std::make_pair(0l, 1l);
for (int64_t vectorSize = maxVectorSize; vectorSize >= 1; vectorSize >>= 1) {
int64_t splitPoint =
iree_compiler::previousMultipleOf(upperBound, minMultiple * vectorSize);
if (splitPoint > 0) {
return (upperBound == splitPoint)
? std::make_pair(0l, vectorSize)
: std::make_pair(splitPoint, vectorSize);
}
}
return std::make_pair(0l, 1l);
}
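
To make the helper concrete, here is an editorial sketch of a few evaluations; it is not part of the commit. It assumes `previousMultipleOf(n, m)` rounds `n` down to the nearest positive multiple of `m` and returns 0 when none exists, and that `<cassert>` is available; the sizes match the tests further down in this commit.

// Editorial sketch, not part of the change.
static void checkComputeSplitPointExamples() {
  // 13 elements, minMultiple 1: the largest multiple of 4 below 13 is 12.
  auto p0 = computeSplitPoint(/*upperBound=*/13, /*minMultiple=*/1,
                              /*maxVectorSize=*/4);
  assert(p0.first == 12 && p0.second == 4);
  // 32 elements divide evenly by 4: no split is needed, vector size 4.
  auto p1 = computeSplitPoint(32, 1, 4);
  assert(p1.first == 0 && p1.second == 4);
  // 13 elements with minMultiple 64: no positive multiple exists at any
  // vector size, so no split and vector size 1.
  auto p2 = computeSplitPoint(13, 64, 4);
  assert(p2.first == 0 && p2.second == 1);
}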

/// Post-bufferization mapping to blocks and threads.
/// Takes a handle to a func.func and returns an updated handle to a
/// func.func.
@@ -445,75 +479,69 @@ createReductionStrategyStagedThreadDistributionStep(
blockCombinerOpH);
}

/// Search for the maximal `numThreadsXInBlock` times `vectorSize` smaller than
/// `value`; where `vectorSize` is in `{4, 2, 1}`.
/// If such a positive multiple exists:
/// 1. if it is `value`, then value is an even multiple of
/// `numThreadsXInBlock` * `vectorSize` and we can tile evenly without
/// splitting.
/// 2. otherwise, it is a split point at which we can split with vectorSize
/// to obtain the largest divisible tiling.
/// Return splitPoint and vector size.
static std::pair<int64_t, int64_t> computeElementwiseSplitPoint(
int64_t value, int numThreadsXInBlock) {
int64_t zero = 0;
if (ShapedType::isDynamic(value)) return std::make_pair(zero, 1);
for (int64_t vectorSize : {4, 2, 1}) {
int64_t splitPoint = iree_compiler::previousMultipleOf(
value, numThreadsXInBlock * vectorSize);
if (splitPoint > 0) {
return (value == splitPoint) ? std::make_pair(zero, vectorSize)
: std::make_pair(splitPoint, vectorSize);
}
}
return std::make_pair(zero, 1);
}

static void createElementwiseStrategyThreadStep(
/// Given a handle `elementwiseH` to an op of rank `rank` with sizes
/// `elementwiseSizes`, build a schedule that tiles the most minor dimension
/// and, when `numThreads > 1`, maps it to an scf.foreach_thread op that is
/// itself mapped to the `gpu.thread x` dimension.
/// The schedule first splits off the largest possible multiple of
/// `numThreads * maxVectorSize` to form a maximally divisible region.
/// Assumes the most minor dimension of the op is the last one.
// TODO: More robustness w.r.t. selecting the most minor dimension, otherwise
// performance may suffer.
// TODO: The split point should be dynamic and aware of future strides /
// alignments to also guarantee proper vector alignment.
static void create1DSplittingStrategyWithOptionalThreadMapping(
ImplicitLocOpBuilder &b, Value elementwiseH, int64_t rank,
SmallVector<int64_t> elementwiseSizes, int64_t numThreadsXInBlock) {
assert(rank > 0 && "nonnegative rank expected");
SmallVector<int64_t> trailingTileSizes(rank, 0);
// The following assumes we only want to tile the most-minor dimension of the
// trailing operation. This may be a completely wrong choice.
// TODO: More robustness to permutations of most-minor dimensions.
// TODO: Split point should be aware of future stride / alignment.
SmallVector<int64_t> elementwiseSizes, int64_t numThreads,
int64_t maxVectorSize = 4) {
if (rank == 0) return;

int64_t mostMinorDim = rank - 1;
int64_t mostMinorSize = elementwiseSizes[mostMinorDim];
auto [splitPoint, vectorSize] =
computeElementwiseSplitPoint(mostMinorSize, numThreadsXInBlock);
computeSplitPoint(mostMinorSize, numThreads, maxVectorSize);

SmallVector<int64_t> scfForTileSizes = trailingTileSizes,
foreachTileSizes = trailingTileSizes;
scfForTileSizes[mostMinorDim] = numThreadsXInBlock * vectorSize;
foreachTileSizes[mostMinorDim] = numThreadsXInBlock;
SmallVector<int64_t> scfForTileSizes(rank, 0), foreachTileSizes(rank, 0);
scfForTileSizes[mostMinorDim] = numThreads * vectorSize;
foreachTileSizes[mostMinorDim] = numThreads;

auto threadX = mlir::gpu::GPUThreadMappingAttr::get(b.getContext(),
mlir::gpu::Threads::DimX);
// Split, tile and map the most minor dimension to `gpu.thread x`.
if (splitPoint > 0) {
auto pdlOperation = pdl::OperationType::get(b.getContext());
auto split =
b.create<transform::SplitOp>(pdlOperation, pdlOperation, elementwiseH,
b.getI64IntegerAttr(mostMinorDim), Value(),
b.getI64IntegerAttr(splitPoint));
elementwiseH = split.getFirst();
if (vectorSize > 1) {
auto res = iree_compiler::buildTileFuseToScfFor(
b, elementwiseH, {},
getAsOpFoldResult(b.getI64ArrayAttr({scfForTileSizes})));
elementwiseH = res.tiledOpH;
}
if (numThreads > 1) {
iree_compiler::buildTileFuseDistToForeachThreadWithNumThreads(
b, elementwiseH, {},
getAsOpFoldResult(b.getI64ArrayAttr(foreachTileSizes)),
b.getArrayAttr({threadX}));
}
elementwiseH = split.getSecond();
}
// Tile and map the most minor dimension of the remainder to `gpu.thread x`.
if (vectorSize > 1) {
auto res = iree_compiler::buildTileFuseToScfFor(
b, elementwiseH, {},
getAsOpFoldResult(b.getI64ArrayAttr({scfForTileSizes})));
elementwiseH = res.tiledOpH;
}
if (numThreads > 1) {
iree_compiler::buildTileFuseDistToForeachThreadWithNumThreads(
b, res.tiledOpH, {},
b, elementwiseH, {},
getAsOpFoldResult(b.getI64ArrayAttr(foreachTileSizes)),
b.getArrayAttr({threadX}));
elementwiseH = split.getSecond();
}

auto res = iree_compiler::buildTileFuseToScfFor(
b, elementwiseH, {},
getAsOpFoldResult(b.getI64ArrayAttr({scfForTileSizes})));
iree_compiler::buildTileFuseDistToForeachThreadWithNumThreads(
b, res.tiledOpH, {},
getAsOpFoldResult(b.getI64ArrayAttr(foreachTileSizes)),
b.getArrayAttr({threadX}));
}
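
As an editorial illustration, not part of the change, the scalar loop nest below sketches the shape of the schedule this helper builds for a hypothetical most-minor size of 13 with `numThreads = 1` and `maxVectorSize = 4`, where `computeSplitPoint` yields (12, 4): a divisible main part processed in vector<4> steps plus a small remainder, with no forced unrolling of the main part.

// Editorial sketch, not part of the change: scalar analogue of the schedule
// for mostMinorSize = 13, numThreads = 1, maxVectorSize = 4, i.e.
// (splitPoint, vectorSize) == (12, 4).
static void scalarAnalogueOfSplitSchedule() {
  constexpr int64_t size = 13, splitPoint = 12, vectorSize = 4;
  // First part of the split: an scf.for-like loop over vector<4> chunks.
  for (int64_t i = 0; i < splitPoint; i += vectorSize) {
    // ... process elements [i, i + vectorSize) as one vector<4> ...
  }
  // Second part of the split: the remainder [splitPoint, size), smaller than
  // one vector, which ends up as (near-)scalar code.
  for (int64_t i = splitPoint; i < size; ++i) {
    // ... process element i ...
  }
}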

static void createReductionStrategyStagedThreadDistribution(
@@ -523,7 +551,7 @@ static void createReductionStrategyStagedThreadDistribution(
// Map the potential maybeTiledLeadingH.
// TODO: Consider fusing leading elementwise into threads.
if (strategy.captures.maybeLeadingRank > 0) {
createElementwiseStrategyThreadStep(
create1DSplittingStrategyWithOptionalThreadMapping(
b, maybeTiledLeadingH, strategy.captures.maybeLeadingRank,
strategy.captures.leadingOpSizes, strategy.getNumThreadsXInBlock());
}
@@ -545,7 +573,7 @@ static void createReductionStrategyStagedThreadDistribution(

// Map the potential maybeTiledTrailingH.
if (strategy.captures.maybeTrailingRank > 0) {
createElementwiseStrategyThreadStep(
create1DSplittingStrategyWithOptionalThreadMapping(
b, maybeTiledTrailingH, strategy.captures.maybeTrailingRank,
strategy.captures.trailingOpSizes, strategy.getNumThreadsXInBlock());
}
@@ -632,34 +660,29 @@ static void createSmallReductionStrategyThreadDistribution(
maybeLeadingH =
b.create<FuseIntoContainingOp>(maybeLeadingH, tileResult.foreachThreadH);

// Scalarize all ops to ensure vectorization.
// 1. Scalarize all ops to ensure vectorization.
auto pdlOperation = pdl::OperationType::get(b.getContext());
fillH = b.create<ScalarizeOp>(pdlOperation, fillH);
maybeLeadingH = b.create<ScalarizeOp>(pdlOperation, maybeLeadingH);
Value tiledH = b.create<ScalarizeOp>(pdlOperation, tileResult.tiledOpH);
Value fusedH = b.create<ScalarizeOp>(
pdlOperation, tileResult.resultingFusedOpsHandles.front());

auto [blockReductionH, maybeBlockTrailingH] =
iree_compiler::buildSelectFirstNonEmpty(b, fusedH, tiledH);

// Splitting into explicit vector<4> helps a lot on alignment.
// TODO: first split should be dynamic and based on the future stride.
int64_t reductionDimensionSize = strategy.captures.reductionOpSizes.back();
if (ShapedType::isDynamic(reductionDimensionSize)) return;

for (int64_t i = 0, e = (reductionDimensionSize - 1) / 4; i < e; ++i) {
auto split = b.create<transform::SplitOp>(
pdlOperation, pdlOperation, blockReductionH,
b.getI64IntegerAttr(strategy.captures.reductionRank - 1), Value(),
b.getI64IntegerAttr(4));
blockReductionH = split.getSecond();
auto split2 = b.create<transform::SplitOp>(
pdlOperation, pdlOperation, maybeBlockTrailingH,
b.getI64IntegerAttr(strategy.captures.reductionRank - 1), Value(),
b.getI64IntegerAttr(4));
maybeBlockTrailingH = split2.getSecond();
}
// 2. Apply the 1d splitting strategy to the reduction part while specifying
// a single thread. This triggers the splitting but not the thread mapping
// part.
create1DSplittingStrategyWithOptionalThreadMapping(
b, blockReductionH, strategy.captures.reductionRank,
strategy.captures.reductionOpSizes,
/*numThreads=*/1);

// 3. Apply the 1d splitting strategy to the trailing elementwise.
create1DSplittingStrategyWithOptionalThreadMapping(
b, maybeBlockTrailingH, strategy.captures.maybeTrailingRank,
strategy.captures.trailingOpSizes,
strategy.getNumThreadsInBlock().back());
}
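
For orientation, an editorial sketch grounded in the test deltas below and not part of the change: the removed loop peeled the reduction dimension into explicit vector<4> chunks, emitting a pair of transform::SplitOp per chunk and thereby forcing unrolling, whereas the shared 1-D splitting strategy emits at most one split plus one tiling. This presumably also explains why the accompanying test grows its reduction dimension from 32 to 34, so that a split is still exercised.

// Editorial sketch, not part of the change, under the previousMultipleOf
// assumption stated earlier.
// Old path on a reduction dimension of size 32: (32 - 1) / 4 == 7 peeled
// chunks, two transform::SplitOp per chunk -> 14 splits, fully unrolled
// (the removed CHECK-COUNT-14 in the updated test).
constexpr int64_t oldSplitOpsForSize32 = 2 * ((32 - 1) / 4);  // == 14
// New path: computeSplitPoint(32, 1, 4) == (0, 4)  -> no split, one tiling
// by 4; computeSplitPoint(34, 1, 4) == (32, 4) -> one split after 32, then
// the same tiling, keeping the main part as a loop.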

/// Builds the transform IR tiling reductions for CUDA targets. Supports
@@ -687,7 +710,7 @@ static void createCudaSmallReductionStrategy(
b, maybeLeadingHBlock, gridFillH, gridReductionH,
maybeTiledTrailingHBlock, strategy);

// Step 4-6. Common trailing steps.
// Step 4-5. Common trailing steps.
createCommonTrailingStrategy(b, variantH, strategy);
}

@@ -30,17 +30,18 @@ hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb",

// Small reduction computes the whole reduction on a single thread.
// CHECK-LABEL: func.func @small_reduction
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[C12:.*]] = arith.constant 12 : index
// CHECK-NOT: memref.alloc()
// CHECK: gpu.thread_id x
// CHECK: vector.transfer_read {{.*}}: memref<1024x13xf32>, vector<4xf32>
// CHECK: vector.multi_reduction <add>, %{{.*}} : vector<4xf32> to f32
// CHECK: vector.transfer_read {{.*}}: memref<1024x13xf32>, vector<4xf32>
// CHECK: vector.multi_reduction <add>, %{{.*}} : vector<4xf32> to f32
// CHECK: vector.transfer_read {{.*}}: memref<1024x13xf32>, vector<4xf32>
// CHECK: vector.multi_reduction <add>, %{{.*}} : vector<4xf32> to f32
// CHECK: vector.broadcast {{.*}}: f32 to vector<f32>
// CHECK: scf.for %{{.*}} = %[[C0]] to %[[C12]] step %[[C4]] {
// CHECK: vector.transfer_read {{.*}}: memref<1024x13xf32>, vector<4xf32>
// CHECK: vector.multi_reduction <add>, %{{.*}} : vector<4xf32> to f32
// CHECK: vector.transfer_write {{.*}} : vector<f32>, memref<f32
// CHECK-NOT: gpu.barrier
// CHECK: vector.transfer_read {{.*}}: memref<f32{{.*}}>, vector<f32>
// CHECK: arith.addf %{{.*}} : f32
// CHECK: vector.transfer_write {{.*}} : vector<f32>, memref<f32
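
// Editorial note, not part of the test: with a reduction dimension of 13
// mapped to a single thread, computeSplitPoint(13, 1, 4) is assumed to yield
// (12, 4), so the schedule splits off [0, 12), tiles it by 4 (the scf.for
// from %[[C0]] to %[[C12]] step %[[C4]] over vector<4xf32> checked above) and
// handles the one remaining element with the scalar f32 add at the end.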

// -----

@@ -97,23 +97,23 @@ hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb",
// -----


hal.executable @group_reduction_32 {
hal.executable @group_reduction_34 {
hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb", {target_arch = "sm_35"}> {
hal.executable.export public @group_reduction_32 ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>) {
hal.executable.export public @group_reduction_34 ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>) {
^bb0(%arg0: !hal.device, %arg1: index, %arg2: index):
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
hal.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @group_reduction_32() {
func.func @group_reduction_34() {
%c0 = arith.constant 0 : index
%cst = arith.constant -0.000000e+00 : f32
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:tensor<8x32xf32>>
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:tensor<8x34xf32>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:tensor<8xf32>>
%2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [8, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<8x32xf32>> -> tensor<8x32xf32>
%2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [8, 34], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<8x34xf32>> -> tensor<8x34xf32>
%3 = tensor.empty() : tensor<8xf32>
%4 = linalg.fill ins(%cst : f32) outs(%3 : tensor<8xf32>) -> tensor<8xf32>
%5 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%2 : tensor<8x32xf32>) outs(%4 : tensor<8xf32>) {
%5 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%2 : tensor<8x34xf32>) outs(%4 : tensor<8xf32>) {
^bb0(%in: f32, %out: f32):
%6 = arith.addf %in, %out : f32
linalg.yield %6 : f32
@@ -128,11 +128,12 @@ hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb",
// Overall, the schedule is the same as above, but with larger tile sizes.
// Checking only the tile sizes.

// CHECK-LABEL: func.func @group_reduction_32
// CHECK-LABEL: func.func @group_reduction_34
// CHECK: transform.structured.canonicalized_sequence failures(propagate)
// CHECK: transform.iree.tile_to_foreach_thread_and_workgroup_count_region %{{.*}} num_threads [] tile_sizes [64](mapping = [#gpu.block<x>])
// CHECK: transform.structured.tile_to_foreach_thread_op %{{.*}} num_threads [64] tile_sizes [](mapping = [#gpu.thread<x>])
// CHECK-COUNT-4: transform.structured.scalarize %{{.*}}
// CHECK-COUNT-14: transform.structured.split %{{.*}} after 4 {dimension = 1 : i64}
// CHECK: transform.structured.split %{{.*}} after 32 {dimension = 1 : i64}
// CHECK: transform.structured.tile %{{.*}}[0, 4]
// CHECK: transform.iree.map_nested_foreach_thread_to_gpu_threads %{{.*}} {workgroup_size = [64, 1, 1]}
// CHECK-NOT: transform.iree.vector.to_warp_execute_on_lane_0
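
// Editorial note, not part of the test: with a reduction dimension of 34,
// computeSplitPoint(34, 1, 4) is assumed to yield (32, 4), which matches the
// single `split after 32` on dimension 1 and the `tile [0, 4]` above, in
// place of the 14 explicit `split after 4` ops the old 8x32 version produced.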
