Use multi-warp reduction in the staged reduction strategy (#11706)
This revision adds support for warp reduction at the whole-block level.
The multi-warp reduction enables a simpler strategy that needs no
special handling for the second stage.
It also exposes more parallelism and removes the need for a sequential
loop to reduce the `k` per-warp partial results down to a single warp.
Previously, we would reduce `N -> k warps -> 1 warp (using 1 warp and k
iterations of a for loop) -> 1 value`.
Now we reduce `N -> k warps -> k partial values (using all k warps) -> 1 value`.
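
For intuition, here is a minimal CUDA sketch of the block-level (multi-warp) reduction described above. It is not the code this strategy emits (the strategy produces MLIR), and the kernel name, launch shape, and power-of-two warp count are assumptions for the example; it only mirrors the `gpu.shuffle xor` / `gpu.shuffle idx` / `arith.minui` patterns the updated tests check for.

```cuda
// Illustrative sketch only -- the transform strategy generates equivalent MLIR
// (vector ops + gpu.shuffle); this just shows the shape of the computation.
#include <cuda_runtime.h>

__device__ float warp_reduce_add(float v) {
  // log2(32) = 5 butterfly steps; cf. `CHECK-COUNT-5: gpu.shuffle xor` below.
  for (int offset = 16; offset > 0; offset >>= 1)
    v += __shfl_xor_sync(0xffffffffu, v, offset);
  return v;
}

// One block reduces one row of `n` floats: N -> k per-warp partials -> 1 value.
__global__ void block_reduce_add(const float* in, float* out, int n) {
  extern __shared__ float partials[];      // k slots, one per warp
  const int lane = threadIdx.x % 32;
  const int warp = threadIdx.x / 32;
  const int num_warps = blockDim.x / 32;   // k, assumed a power of two here

  // Per-thread sequential accumulation, then stage 1: reduce within each warp.
  float acc = 0.f;
  for (int i = threadIdx.x; i < n; i += blockDim.x)
    acc += in[blockIdx.x * n + i];
  acc = warp_reduce_add(acc);
  if (lane == 0) partials[warp] = acc;
  __syncthreads();

  // Stage 2, the change in this revision: every warp re-reduces the k partials
  // with shuffles instead of one warp looping k times. The clamped index mirrors
  // the `arith.minui` in the updated tests; lanes >= k read a duplicate value
  // but never land in lane 0's butterfly group, so the result stays correct.
  float v = partials[min(lane, num_warps - 1)];
  for (int offset = num_warps / 2; offset > 0; offset >>= 1)
    v += __shfl_xor_sync(0xffffffffu, v, offset);
  v = __shfl_sync(0xffffffffu, v, 0);      // broadcast, cf. gpu.shuffle idx

  if (threadIdx.x == 0) out[blockIdx.x] = v;
}
```

A hypothetical launch such as `block_reduce_add<<<rows, 256, (256 / 32) * sizeof(float)>>>(in, out, cols)` would reduce each row with 8 warps, which is why the larger test below expects only three cross-warp shuffle steps (log2(8) = 3) in the second stage.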

Additional empirical performance improvements are added on top:
* Use multi-warp reduction in the staged reduction strategy
* Add support to capture all sizes
* Add Split support to the elementwise part
* Tighten the dimensioning of strategies
* Update tests
nicolasvasilache authored Jan 5, 2023
1 parent 56d65bf commit 52643c3
Showing 6 changed files with 266 additions and 342 deletions.


@@ -17,8 +17,8 @@ namespace iree_compiler {
// Low-level reusable builder APIs, these should follow MLIR-style builders.
//===----------------------------------------------------------------------===//

static constexpr unsigned kCudaWarpSize = 32;
static constexpr unsigned kCudaMaxNumThreads = 1024;
static constexpr int64_t kCudaWarpSize = 32;
static constexpr int64_t kCudaMaxNumThreads = 1024;

/// Post-bufferization mapping to blocks and threads.
/// Takes a handle to a func.func and returns an updated handle to a
@@ -33,14 +33,14 @@ hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb",
// CHECK-NOT: memref.alloc()
// CHECK: gpu.thread_id x
// CHECK: vector.transfer_read {{.*}}: memref<1024x13xf32>, vector<4xf32>
// CHECK: vector.reduction <add>, %{{.*}} : vector<4xf32> into f32
// CHECK: vector.multi_reduction <add>, %{{.*}} : vector<4xf32> to f32
// CHECK: vector.transfer_read {{.*}}: memref<1024x13xf32>, vector<4xf32>
// CHECK: vector.reduction <add>, %{{.*}} : vector<4xf32> into f32
// CHECK: vector.multi_reduction <add>, %{{.*}} : vector<4xf32> to f32
// CHECK: vector.transfer_read {{.*}}: memref<1024x13xf32>, vector<4xf32>
// CHECK: vector.reduction <add>, %{{.*}} : vector<4xf32> into f32
// CHECK: vector.transfer_read {{.*}}: memref<1024x13xf32>, vector<f32>
// CHECK: vector.multi_reduction <add>, %{{.*}} : vector<4xf32> to f32
// CHECK: vector.broadcast {{.*}}: f32 to vector<f32>
// CHECK: arith.addf %{{.*}} : f32
// CHECK: vector.transfer_write {{.*}} : vector<f32>, memref<1024xf32>
// CHECK: vector.transfer_write {{.*}} : vector<f32>, memref<f32

// -----

@@ -75,162 +75,34 @@ hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb",
// CHECK-LABEL: func.func @group_reduction
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[workgroup_id_x:.*]] = hal.interface.workgroup.id[0] : index
// CHECK-DAG: %[[SHMEM_ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x32xf32, 3>
// CHECK-DAG: %[[SHMEM_ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x64xf32, 3>
// CHECK-DAG: %[[TIDX:.]] = gpu.thread_id x

// Fusion occurred, no barrier before the loop
// CHECK-NOT: gpu.barrier
// Local per-thread scf.for-based reduction.
// CHECK: scf.for
// CHECK: vector.transfer_read {{.*}} vector<4xf32>
// CHECK: vector.transfer_read {{.*}} vector<f32>
// CHECK: vector.reduction <add>{{.*}} : vector<4xf32> into f32
// CHECK: vector.broadcast {{.*}} : f32 to vector<f32>
// No barrier within the loop
// CHECK-NOT: gpu.barrier
// CHECK: vector.transfer_read {{.*}} memref<8x64xf32>, vector<f32>
// CHECK: vector.transfer_read {{.*}} memref<1x64xf32, 3>, vector<f32>
// CHECK: arith.addf {{.*}} : f32
// CHECK: vector.transfer_write {{.*}} vector<f32>

// Distributed reduction: everyone loads then 5 xor + addf expected.
// CHECK: vector.transfer_read %{{.*}} memref<8xf32>, vector<f32>
// CHECK: vector.transfer_read %{{.*}} memref<1x32xf32, 3>, vector<1xf32>
// CHECK-COUNT-5: gpu.shuffle xor{{.*}}{{[[:space:]].*}}{{.*}} arith.addf

// CHECK: %[[RES:.*]] = arith.addf %{{.*}}
// CHECK: %[[RES_VEC:.*]] = vector.broadcast %[[RES]] : f32 to vector<f32>
// CHECK: %[[CONDXIS0:.*]] = arith.cmpi eq, %[[TIDX]], %[[C0]] : index
// CHECK: scf.if %[[CONDXIS0]]
// CHECK: vector.transfer_write %[[RES_VEC]]
// CHECK: gpu.barrier
// CHECK: memref.dealloc %[[SHMEM_ALLOC]]

// -----

hal.executable @group_reduction_elementwise {
hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb", {target_arch = "sm_35"}> {
hal.executable.export public @group_reduction_elementwise ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>) {
^bb0(%arg0: !hal.device, %arg1: index):
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1
hal.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @group_reduction_elementwise() {
%c0 = arith.constant 0 : index
%cst = arith.constant -0.000000e+00 : f32
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:tensor<8x64xf32>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:tensor<8xf32>>
%2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [8, 64], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<8x64xf32>> -> tensor<8x64xf32>
%3 = tensor.empty() : tensor<8xf32>
%4 = linalg.fill ins(%cst : f32) outs(%3 : tensor<8xf32>) -> tensor<8xf32>
%5 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%2 : tensor<8x64xf32>) outs(%4 : tensor<8xf32>) {
^bb0(%in: f32, %out: f32):
%7 = arith.addf %in, %out : f32
linalg.yield %7 : f32
} -> tensor<8xf32>
%6 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%5 : tensor<8xf32>) outs(%3 : tensor<8xf32>) {
^bb0(%in: f32, %out: f32):
%7 = math.sqrt %in : f32
linalg.yield %7 : f32
} -> tensor<8xf32>
flow.dispatch.tensor.store %6, %1, offsets = [0], sizes = [8], strides = [1] : tensor<8xf32> -> !flow.dispatch.tensor<writeonly:tensor<8xf32>>
return
}
}
}
}

// CHECK-LABEL: func.func @group_reduction_elementwise
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[workgroup_id_x:.*]] = hal.interface.workgroup.id[0] : index
// CHECK-DAG: %[[SHMEM_ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x32xf32, 3>
// CHECK-DAG: %[[TIDX:.]] = gpu.thread_id x

// Fusion occurred, no barrier before the loop
// CHECK-NOT: gpu.barrier
// Local per-thread scf.for-based reduction.
// CHECK: scf.for
// CHECK: vector.transfer_read {{.*}} vector<4xf32>
// CHECK: vector.transfer_read {{.*}} vector<f32>
// CHECK: vector.reduction <add>{{.*}} : vector<4xf32> into f32
// CHECK: vector.broadcast {{.*}} : f32 to vector<f32>
// No barrier within the loop
// No barrier within the loop.
// CHECK-NOT: gpu.barrier
// CHECK: vector.transfer_write {{.*}} vector<f32>

// Distributed reduction: everyone loads then 5 xor + addf expected.
// CHECK: vector.transfer_read %{{.*}} memref<1xf32, 3>, vector<f32>
// CHECK: vector.transfer_read %{{.*}} memref<1x32xf32, 3>, vector<1xf32>
// CHECK-COUNT-5: gpu.shuffle xor{{.*}}{{[[:space:]].*}}{{.*}} arith.addf

// CHECK: %[[PARTIAL:.*]] = arith.addf %{{.*}}
// CHECK: %[[RES_VEC:.*]] = vector.broadcast %[[PARTIAL]] : f32 to vector<f32>
// CHECK: %[[CONDXIS0:.*]] = arith.cmpi eq, %[[TIDX]], %[[C0]] : index
// CHECK: scf.if %[[CONDXIS0]]
// CHECK: vector.transfer_write %[[RES_VEC]]

// CHECK: gpu.barrier
// CHECK: math.sqrt
// CHECK: }
// Barrier after the loop.
// CHECK: gpu.barrier
// CHECK: memref.dealloc %[[SHMEM_ALLOC]]

// -----

hal.executable @group_elementwise_reduction {
hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb", {target_arch = "sm_35"}> {
hal.executable.export public @group_elementwise_reduction ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>) {
^bb0(%arg0: !hal.device, %arg1: index, %arg2: index):
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
hal.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @group_elementwise_reduction() {
%c0 = arith.constant 0 : index
%cst = arith.constant -0.000000e+00 : f32
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:tensor<8x64xf32>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:tensor<8xf32>>
%2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [8, 64], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<8x64xf32>> -> tensor<8x64xf32>
%3 = tensor.empty() : tensor<8xf32>
%4 = linalg.fill ins(%cst : f32) outs(%3 : tensor<8xf32>) -> tensor<8xf32>
%5 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%2 : tensor<8x64xf32>) outs(%4 : tensor<8xf32>) {
^bb0(%in: f32, %out: f32):
%6 = arith.addf %in, %in : f32
%7 = arith.addf %6, %6 : f32
%8 = arith.addf %7, %out : f32
linalg.yield %8 : f32
} -> tensor<8xf32>
flow.dispatch.tensor.store %5, %1, offsets = [0], sizes = [8], strides = [1] : tensor<8xf32> -> !flow.dispatch.tensor<writeonly:tensor<8xf32>>
return
}
}
}
}

// CHECK-LABEL: func.func @group_elementwise_reduction
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[workgroup_id_x:.*]] = hal.interface.workgroup.id[0] : index
// CHECK-DAG: %[[SHMEM_ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x32xf32, 3>
// CHECK-DAG: %[[TIDX:.]] = gpu.thread_id x

// Fusion occurred, no barrier before the loop
// CHECK-NOT: gpu.barrier
// Local per-thread scf.for-based reduction.
// CHECK: scf.for
// CHECK: vector.transfer_read {{.*}} vector<4xf32>
// CHECK: vector.transfer_read {{.*}} vector<f32>
// CHECK: arith.addf{{.*}} : vector<4xf32>
// CHECK: arith.addf{{.*}} : vector<4xf32>
// CHECK: vector.reduction <add>{{.*}} : vector<4xf32> into f32
// CHECK: vector.broadcast {{.*}} : f32 to vector<f32>
// No barrier within the loop
// CHECK-NOT: gpu.barrier
// CHECK: vector.transfer_write {{.*}} vector<f32>

// Distributed reduction: everyone loads then 5 xor + addf expected.
// CHECK: vector.transfer_read %{{.*}} memref<8xf32>, vector<f32>
// CHECK: vector.transfer_read %{{.*}} memref<1x32xf32, 3>, vector<1xf32>
// CHECK: vector.transfer_read %{{.*}} memref<1x64xf32, 3>, vector<1xf32>
// CHECK-COUNT-5: gpu.shuffle xor{{.*}}{{[[:space:]].*}}{{.*}} arith.addf

// CHECK: %[[RES:.*]] = arith.addf %{{.*}}
// CHECK: %[[RES_VEC:.*]] = vector.broadcast %[[RES]] : f32 to vector<f32>
// CHECK: arith.minui
// CHECK: memref.load
// CHECK: gpu.shuffle xor{{.*}}{{[[:space:]].*}}{{.*}} arith.addf
// CHECK: gpu.shuffle idx
// CHECK: %[[RES:.*]] = arith.addf %{{.*}} : f32
// CHECK: %[[RES_VEC:.*]] = vector.broadcast %{{.*}} : f32 to vector<f32>
// CHECK: %[[CONDXIS0:.*]] = arith.cmpi eq, %[[TIDX]], %[[C0]] : index
// CHECK: scf.if %[[CONDXIS0]]
// CHECK: vector.transfer_write %[[RES_VEC]]
@@ -277,28 +149,35 @@ hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb",
// CHECK-LABEL: func.func @group_elementwise_reduction_elementwise
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[workgroup_id_x:.*]] = hal.interface.workgroup.id[0] : index
// CHECK-DAG: %[[SHMEM_ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x32xf32, 3>
// CHECK-DAG: %[[SHMEM_ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x64xf32, 3>
// CHECK-DAG: %[[TIDX:.]] = gpu.thread_id x

// Fusion occurred, no barrier before the loop
// CHECK-NOT: gpu.barrier
// Local per-thread scf.for-based reduction.
// CHECK: scf.for
// CHECK: vector.transfer_read {{.*}} vector<4xf32>
// CHECK: vector.transfer_read {{.*}} vector<f32>
// CHECK: arith.addf{{.*}} : vector<4xf32>
// CHECK: arith.addf{{.*}} : vector<4xf32>
// CHECK: vector.reduction <add>{{.*}} : vector<4xf32> into f32
// CHECK: vector.transfer_read {{.*}} vector<f32>
// CHECK: arith.addf{{.*}} : f32
// CHECK: arith.addf{{.*}} : f32
// CHECK: arith.addf{{.*}} : f32
// CHECK: vector.broadcast {{.*}} : f32 to vector<f32>
// CHECK: vector.transfer_write {{.*}} vector<f32>
// No barrier within the loop
// CHECK-NOT: gpu.barrier
// CHECK: vector.transfer_write {{.*}} vector<f32>
// CHECK: }
// Barrier after the loop
// CHECK: gpu.barrier

// Distributed reduction: everyone loads then 5 xor + addf expected.
// CHECK: vector.transfer_read %{{.*}} memref<1xf32, 3>, vector<f32>
// CHECK: vector.transfer_read %{{.*}} memref<1x32xf32, 3>, vector<1xf32>
// CHECK: vector.transfer_read %{{.*}} memref<1x64xf32, 3>, vector<1xf32>
// CHECK-COUNT-5: gpu.shuffle xor{{.*}}{{[[:space:]].*}}{{.*}} arith.addf

// CHECK: arith.minui
// CHECK: memref.load
// CHECK: gpu.shuffle xor{{.*}}{{[[:space:]].*}}{{.*}} arith.addf
// CHECK: gpu.shuffle idx
// CHECK: %[[PARTIAL:.*]] = arith.addf %{{.*}}
// CHECK: %[[RES_VEC:.*]] = vector.broadcast %[[PARTIAL]] : f32 to vector<f32>
// CHECK: %[[CONDXIS0:.*]] = arith.cmpi eq, %[[TIDX]], %[[C0]] : index
@@ -323,12 +202,12 @@ hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb",
func.func @group_reduction_larger() {
%c0 = arith.constant 0 : index
%cst = arith.constant -0.000000e+00 : f32
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:tensor<33x256xf32>>
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:tensor<33x1024xf32>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:tensor<33xf32>>
%2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [33, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<33x256xf32>> -> tensor<33x256xf32>
%2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [33, 1024], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<33x1024xf32>> -> tensor<33x1024xf32>
%3 = tensor.empty() : tensor<33xf32>
%4 = linalg.fill ins(%cst : f32) outs(%3 : tensor<33xf32>) -> tensor<33xf32>
%5 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%2 : tensor<33x256xf32>) outs(%4 : tensor<33xf32>) {
%5 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%2 : tensor<33x1024xf32>) outs(%4 : tensor<33xf32>) {
^bb0(%in: f32, %out: f32):
%6 = arith.addf %in, %out : f32
linalg.yield %6 : f32
Expand All @@ -343,7 +222,7 @@ hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb",
// CHECK-LABEL: func.func @group_reduction_larger
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[workgroup_id_x:.*]] = hal.interface.workgroup.id[0] : index
// CHECK-DAG: %[[SHMEM_ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x64xf32, 3>
// CHECK-DAG: %[[SHMEM_ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x256xf32, 3>
// CHECK-DAG: %[[TIDX:.]] = gpu.thread_id x

// Fusion occurred, no barrier before the loop
@@ -361,10 +240,13 @@ hal.executable.variant public @cuda_nvptx_fb, target = <"cuda", "cuda-nvptx-fb",
// CHECK-DAG: %[[TIDY:.]] = gpu.thread_id y
// Distributed reduction: everyone loads then 5 xor + addf expected.
// CHECK: vector.transfer_read %{{.*}} memref<33xf32>, vector<f32>
// CHECK: %[[IDX:.*]] = affine.apply{{.*}}%[[TIDX]]
// CHECK: vector.transfer_read %{{.*}}[%[[TIDY]], %[[IDX]]]{{.*}} memref<1x64xf32, 3>, vector<2xf32>
// CHECK: vector.transfer_read %{{.*}}[%[[TIDY]], %[[TIDX]]]{{.*}} memref<1x256xf32, 3>, vector<1xf32>
// CHECK-COUNT-5: gpu.shuffle xor{{.*}}{{[[:space:]].*}}{{.*}} arith.addf

// CHECK: arith.minui
// CHECK: memref.load
// CHECK-COUNT-3: gpu.shuffle xor{{.*}}{{[[:space:]].*}}{{.*}} arith.addf
// CHECK: gpu.shuffle idx
// CHECK: %[[RES:.*]] = arith.addf %{{.*}}
// CHECK: %[[RES_VEC:.*]] = vector.broadcast %[[RES]] : f32 to vector<f32>
// CHECK: %[[CONDXIS0:.*]] = arith.cmpi eq, %[[TIDX]], %[[C0]] : index