diff --git a/tuner/examples/test/README_.md b/tuner/examples/test/README_.md
new file mode 100644
index 000000000..14669602f
--- /dev/null
+++ b/tuner/examples/test/README_.md
@@ -0,0 +1,32 @@
+# Example Tuner Test
+
+An example of tuning a dispatch and a full model.
+
+## Environment
+Follow the instructions in [`/tuner/README.md`](../README.md).
+
+## Running the Tuner
+
+### Choose a model to tune
+This example uses the simple `double_mmt.mlir` file.
+
+### Generate a benchmark file
+Use the usual `iree-compile` command for your model and add
+`--iree-hal-dump-executable-files-to=dump`. For example:
+```shell
+iree-compile double_mmt.mlir --iree-hal-target-backends=rocm --iree-hip-target=gfx942 --iree-hal-dump-executable-files-to=dump -o /dev/null
+```
+
+Next, copy the dumped `*_benchmark.mlir` file to a temporary directory of your choice.
+This file will be the input to the dispatch tuner. For this example, the provided `mmt_benchmark.mlir` file (generated from `double_mmt.mlir`) can be used.
+
+### Recommended Trial Run
+For an initial trial to test the tuning loop, use:
+```shell
+python -m examples.test double_mmt.mlir mmt_benchmark.mlir --num-candidates=20
+```
+
+### Basic Usage
+```shell
+python -m examples.test double_mmt.mlir mmt_benchmark.mlir
+```
diff --git a/tuner/examples/test/__init__.py b/tuner/examples/test/__init__.py
new file mode 100644
index 000000000..a85ba359d
--- /dev/null
+++ b/tuner/examples/test/__init__.py
@@ -0,0 +1,5 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
diff --git a/tuner/examples/test/__main__.py b/tuner/examples/test/__main__.py
new file mode 100644
index 000000000..4f426e110
--- /dev/null
+++ b/tuner/examples/test/__main__.py
@@ -0,0 +1,9 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from .
import tuner_test + +tuner_test.main() diff --git a/tuner/examples/test/conv_benchmark.mlir b/tuner/examples/test/conv_benchmark.mlir new file mode 100644 index 000000000..7e8533a95 --- /dev/null +++ b/tuner/examples/test/conv_benchmark.mlir @@ -0,0 +1,68 @@ +module { + util.global private @__device_0 = #hal.device.target<"hip", {legacy_sync}, [#hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target, , , , , , , , , , , , , , , ], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, ukernels = "none"}>]> : !hal.device + hal.executable private @main_0_dispatch_0 { + hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target, , , , , , , , , , , , , , , ], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, ukernels = "none"}>) { + hal.executable.export public @main_0_dispatch_0_conv_2d_nhwc_hwcf_2x32x32x1280x3x3x1280_i8xi8xi32 ordinal(0) layout(#hal.pipeline.layout, #hal.pipeline.binding, #hal.pipeline.binding], flags = Indirect>) { + ^bb0(%arg0: !hal.device): + %x, %y, %z = flow.dispatch.workgroup_count_from_slice + hal.return %x, %y, %z : index, index, index + } + builtin.module { + func.func @main_0_dispatch_0_conv_2d_nhwc_hwcf_2x32x32x1280x3x3x1280_i8xi8xi32() attributes {translation_info = #iree_codegen.translation_info} { + %cst = arith.constant 0.000000e+00 : f16 + %c0 = arith.constant 0 : index + %0 = hal.interface.binding.subspan layout(, #hal.pipeline.binding, #hal.pipeline.binding], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan layout(, #hal.pipeline.binding, #hal.pipeline.binding], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor> + %2 = hal.interface.binding.subspan layout(, #hal.pipeline.binding, #hal.pipeline.binding], flags = Indirect>) binding(2) alignment(64) offset(%c0) flags(Indirect) : !flow.dispatch.tensor> + %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [2, 34, 34, 1280], strides = [1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<2x34x34x1280xi8> + %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [3, 3, 1280, 1280], strides = [1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<3x3x1280x1280xi8> + %5 = tensor.empty() : tensor<2x32x32x1280xi32> + %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<2x32x32x1280xi32>) -> tensor<2x32x32x1280xi32> + %7 = linalg.conv_2d_nhwc_hwcf {lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout, promote_operands = [0, 1], reduction = [0, 0, 0, 0, 1, 1, 64], subgroup_m_count = 1 : i64, subgroup_n_count = 4 : i64, workgroup = [1, 1, 32, 256, 0, 0, 0]}>, root_op} ins(%3, %4 : tensor<2x34x34x1280xi8>, tensor<3x3x1280x1280xi8>) outs(%6 : tensor<2x32x32x1280xi32>) -> tensor<2x32x32x1280xi32> + flow.dispatch.tensor.store %7, %2, offsets = [0, 0, 0, 0], sizes = [2, 32, 32, 1280], strides = [1, 1, 1, 1] : tensor<2x32x32x1280xi32> -> !flow.dispatch.tensor> + return + } + 
} + } + } + util.global private mutable @main_0_dispatch_0_rocm_hsaco_fb_main_0_dispatch_0_conv_2d_nhwc_hwcf_2x32x32x1280x3x3x1280_i8xi8xi32_buffer : !hal.buffer + util.initializer { + %c28190720 = arith.constant 28190720 : index + %device, %queue_affinity = hal.device.resolve on(<@__device_0>) : !hal.device, i64 + %allocator = hal.device.allocator<%device : !hal.device> : !hal.allocator + %buffer = hal.allocator.allocate<%allocator : !hal.allocator> affinity(%queue_affinity) type("DeviceVisible|DeviceLocal") usage("TransferSource|TransferTarget|Transfer|DispatchStorageRead|DispatchStorageWrite|DispatchStorage") : !hal.buffer{%c28190720} + util.global.store %buffer, @main_0_dispatch_0_rocm_hsaco_fb_main_0_dispatch_0_conv_2d_nhwc_hwcf_2x32x32x1280x3x3x1280_i8xi8xi32_buffer : !hal.buffer + util.return + } + util.func public @main_0_dispatch_0_rocm_hsaco_fb_main_0_dispatch_0_conv_2d_nhwc_hwcf_2x32x32x1280x3x3x1280_i8xi8xi32(%arg0: i32) attributes {iree.abi.stub, iree.reflection = {iree.benchmark = "dispatch"}} { + %c-1_i32 = arith.constant -1 : i32 + %0 = util.null : !hal.fence + %c1 = arith.constant 1 : index + %c10485760 = arith.constant 10485760 : index + %c17704960 = arith.constant 17704960 : index + %c14745600 = arith.constant 14745600 : index + %c2959360 = arith.constant 2959360 : index + %c0 = arith.constant 0 : index + %1 = arith.index_cast %arg0 : i32 to index + %device, %queue_affinity = hal.device.resolve on(<@__device_0>) : !hal.device, i64 + %cmd = hal.command_buffer.create device(%device : !hal.device) mode("OneShot|AllowInlineExecution") categories(Dispatch) affinity(%queue_affinity) : !hal.command_buffer + %main_0_dispatch_0_rocm_hsaco_fb_main_0_dispatch_0_conv_2d_nhwc_hwcf_2x32x32x1280x3x3x1280_i8xi8xi32_buffer = util.global.load @main_0_dispatch_0_rocm_hsaco_fb_main_0_dispatch_0_conv_2d_nhwc_hwcf_2x32x32x1280x3x3x1280_i8xi8xi32_buffer : !hal.buffer + %workgroup_x, %workgroup_y, %workgroup_z = hal.executable.calculate_workgroups device(%device : !hal.device) target(@main_0_dispatch_0::@rocm_hsaco_fb::@main_0_dispatch_0_conv_2d_nhwc_hwcf_2x32x32x1280x3x3x1280_i8xi8xi32) : index, index, index + %exe = hal.executable.lookup device(%device : !hal.device) executable(@main_0_dispatch_0) : !hal.executable + %ordinal = hal.executable.export.ordinal target(@main_0_dispatch_0::@rocm_hsaco_fb::@main_0_dispatch_0_conv_2d_nhwc_hwcf_2x32x32x1280x3x3x1280_i8xi8xi32) : index + scf.for %arg1 = %c0 to %1 step %c1 { + hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%exe : !hal.executable)[%ordinal] workgroups([%workgroup_x, %workgroup_y, %workgroup_z]) bindings([ + (%main_0_dispatch_0_rocm_hsaco_fb_main_0_dispatch_0_conv_2d_nhwc_hwcf_2x32x32x1280x3x3x1280_i8xi8xi32_buffer : !hal.buffer)[%c0, %c2959360], + (%main_0_dispatch_0_rocm_hsaco_fb_main_0_dispatch_0_conv_2d_nhwc_hwcf_2x32x32x1280x3x3x1280_i8xi8xi32_buffer : !hal.buffer)[%c2959360, %c14745600], + (%main_0_dispatch_0_rocm_hsaco_fb_main_0_dispatch_0_conv_2d_nhwc_hwcf_2x32x32x1280x3x3x1280_i8xi8xi32_buffer : !hal.buffer)[%c17704960, %c10485760] + ]) flags("None") + hal.command_buffer.execution_barrier<%cmd : !hal.command_buffer> source("Dispatch|CommandRetire") target("CommandIssue|Dispatch") flags("None") + } + hal.command_buffer.finalize<%cmd : !hal.command_buffer> + %fence = hal.fence.create device(%device : !hal.device) flags("None") : !hal.fence + hal.device.queue.execute<%device : !hal.device> affinity(%queue_affinity) wait(%0) signal(%fence) commands([%cmd]) + %status = hal.fence.await until([%fence]) 
timeout_millis(%c-1_i32) : i32
+    util.status.check_ok %status, "failed to wait on timepoint"
+    util.return
+  }
+}
diff --git a/tuner/examples/test/conv_nhwc.mlir b/tuner/examples/test/conv_nhwc.mlir
new file mode 100644
index 000000000..22b4a73ec
--- /dev/null
+++ b/tuner/examples/test/conv_nhwc.mlir
@@ -0,0 +1,11 @@
+!convA_0 = tensor<2x34x34x1280xi8>
+!convB_0 = tensor<3x3x1280x1280xi8>
+!convC_0 = tensor<2x32x32x1280xi32>
+
+func.func @main_0(%arg0: !convA_0, %arg1: !convB_0) -> !convC_0 {
+  %cst = arith.constant 0 : i32
+  %5 = tensor.empty() : !convC_0
+  %6 = linalg.fill ins(%cst : i32) outs(%5 : !convC_0) -> !convC_0
+  %8 = linalg.conv_2d_nhwc_hwcf ins(%arg0, %arg1 : !convA_0, !convB_0) outs(%6 : !convC_0) -> !convC_0
+  return %8 : !convC_0
+}
diff --git a/tuner/examples/test/double_mmt.mlir b/tuner/examples/test/double_mmt.mlir
new file mode 100644
index 000000000..a3bd4c7b0
--- /dev/null
+++ b/tuner/examples/test/double_mmt.mlir
@@ -0,0 +1,16 @@
+!matA_0 = tensor<2048x2048xf16>
+!matB_0 = tensor<2048x2048xf16>
+!matC_0 = tensor<2048x2048xf32>
+
+!matC_1 = tensor<2048x2048xf32>
+
+func.func @main(%arg0: !matA_0, %arg1: !matB_0) -> !matC_1 {
+  %cst = arith.constant 0.000000e+00 : f32
+  %5 = tensor.empty() : !matC_0
+  %6 = linalg.fill ins(%cst : f32) outs(%5 : !matC_0) -> !matC_0
+  %7 = linalg.matmul_transpose_b ins(%arg0, %arg1 : !matA_0, !matB_0) outs(%6 : !matC_0) -> !matC_0
+  %8 = tensor.empty() : !matC_1
+  %9 = linalg.fill ins(%cst : f32) outs(%8 : !matC_1) -> !matC_1
+  %10 = linalg.matmul_transpose_b ins(%7, %7 : !matC_0, !matC_0) outs(%9 : !matC_1) -> !matC_1
+  return %10 : !matC_1
+}
diff --git a/tuner/examples/test/mmt_benchmark.mlir b/tuner/examples/test/mmt_benchmark.mlir
new file mode 100644
index 000000000..a3cc95b95
--- /dev/null
+++ b/tuner/examples/test/mmt_benchmark.mlir
@@ -0,0 +1,73 @@
+module {
+  util.global private @__device_0 = #hal.device.target<"hip", {legacy_sync}, [#hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target, , , , , , , , , , , , , , , ], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, ukernels = "none"}>]> : !hal.device
+  hal.executable private @main_dispatch_0 {
+    hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target, , , , , , , , , , , , , , , ], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, ukernels = "none"}>) {
+      hal.executable.export public @main_dispatch_0_matmul_transpose_b_2048x2048x2048_f16xf16xf32 ordinal(0) layout(#hal.pipeline.layout, #hal.pipeline.binding, #hal.pipeline.binding], flags = Indirect>) {
+      ^bb0(%arg0: !hal.device):
+        %x, %y, %z = flow.dispatch.workgroup_count_from_slice
+        hal.return %x, %y, %z : index, index, index
+      }
+      builtin.module {
+        func.func @main_dispatch_0_matmul_transpose_b_2048x2048x2048_f16xf16xf32() attributes {translation_info = #iree_codegen.translation_info} {
+          %cst = arith.constant 0.000000e+00 : f32
+          %c0 = arith.constant 0 : index
+          %0 = hal.interface.binding.subspan layout(, #hal.pipeline.binding,
#hal.pipeline.binding], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan layout(, #hal.pipeline.binding, #hal.pipeline.binding], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor> + %2 = hal.interface.binding.subspan layout(, #hal.pipeline.binding, #hal.pipeline.binding], flags = Indirect>) binding(2) alignment(64) offset(%c0) flags(Indirect) : !flow.dispatch.tensor> + %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [2048, 2048], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<2048x2048xf16> + %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [2048, 2048], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<2048x2048xf16> + %5 = tensor.empty() : tensor<2048x2048xf32> + %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<2048x2048xf32>) -> tensor<2048x2048xf32> + %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%3, %4 : tensor<2048x2048xf16>, tensor<2048x2048xf16>) outs(%6 : tensor<2048x2048xf32>) attrs = {lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout, promote_operands = [0, 1], reduction = [0, 0, 64], subgroup_m_count = 2 : i64, subgroup_n_count = 2 : i64, workgroup = [64, 128, 0]}>, root_op} { + ^bb0(%in: f16, %in_0: f16, %out: f32): + %8 = arith.extf %in : f16 to f32 + %9 = arith.extf %in_0 : f16 to f32 + %10 = arith.mulf %8, %9 : f32 + %11 = arith.addf %out, %10 : f32 + linalg.yield %11 : f32 + } -> tensor<2048x2048xf32> + flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [2048, 2048], strides = [1, 1] : tensor<2048x2048xf32> -> !flow.dispatch.tensor> + return + } + } + } + } + util.global private mutable @main_dispatch_0_rocm_hsaco_fb_main_dispatch_0_matmul_transpose_b_2048x2048x2048_f16xf16xf32_buffer : !hal.buffer + util.initializer { + %c33554432 = arith.constant 33554432 : index + %device, %queue_affinity = hal.device.resolve on(<@__device_0>) : !hal.device, i64 + %allocator = hal.device.allocator<%device : !hal.device> : !hal.allocator + %buffer = hal.allocator.allocate<%allocator : !hal.allocator> affinity(%queue_affinity) type("DeviceVisible|DeviceLocal") usage("TransferSource|TransferTarget|Transfer|DispatchStorageRead|DispatchStorageWrite|DispatchStorage") : !hal.buffer{%c33554432} + util.global.store %buffer, @main_dispatch_0_rocm_hsaco_fb_main_dispatch_0_matmul_transpose_b_2048x2048x2048_f16xf16xf32_buffer : !hal.buffer + util.return + } + util.func public @main_dispatch_0_rocm_hsaco_fb_main_dispatch_0_matmul_transpose_b_2048x2048x2048_f16xf16xf32(%arg0: i32) attributes {iree.abi.stub, iree.reflection = {iree.benchmark = "dispatch"}} { + %c-1_i32 = arith.constant -1 : i32 + %0 = util.null : !hal.fence + %c1 = arith.constant 1 : index + %c16777216 = arith.constant 16777216 : index + %c8388608 = arith.constant 8388608 : index + %c0 = arith.constant 0 : index + %1 = arith.index_cast %arg0 : i32 to index + %device, %queue_affinity = hal.device.resolve on(<@__device_0>) : !hal.device, i64 + %cmd = hal.command_buffer.create device(%device : !hal.device) mode("OneShot|AllowInlineExecution") categories(Dispatch) affinity(%queue_affinity) : !hal.command_buffer + %main_dispatch_0_rocm_hsaco_fb_main_dispatch_0_matmul_transpose_b_2048x2048x2048_f16xf16xf32_buffer = util.global.load 
@main_dispatch_0_rocm_hsaco_fb_main_dispatch_0_matmul_transpose_b_2048x2048x2048_f16xf16xf32_buffer : !hal.buffer + %workgroup_x, %workgroup_y, %workgroup_z = hal.executable.calculate_workgroups device(%device : !hal.device) target(@main_dispatch_0::@rocm_hsaco_fb::@main_dispatch_0_matmul_transpose_b_2048x2048x2048_f16xf16xf32) : index, index, index + %exe = hal.executable.lookup device(%device : !hal.device) executable(@main_dispatch_0) : !hal.executable + %ordinal = hal.executable.export.ordinal target(@main_dispatch_0::@rocm_hsaco_fb::@main_dispatch_0_matmul_transpose_b_2048x2048x2048_f16xf16xf32) : index + scf.for %arg1 = %c0 to %1 step %c1 { + hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%exe : !hal.executable)[%ordinal] workgroups([%workgroup_x, %workgroup_y, %workgroup_z]) bindings([ + (%main_dispatch_0_rocm_hsaco_fb_main_dispatch_0_matmul_transpose_b_2048x2048x2048_f16xf16xf32_buffer : !hal.buffer)[%c0, %c8388608], + (%main_dispatch_0_rocm_hsaco_fb_main_dispatch_0_matmul_transpose_b_2048x2048x2048_f16xf16xf32_buffer : !hal.buffer)[%c8388608, %c8388608], + (%main_dispatch_0_rocm_hsaco_fb_main_dispatch_0_matmul_transpose_b_2048x2048x2048_f16xf16xf32_buffer : !hal.buffer)[%c16777216, %c16777216] + ]) flags("None") + hal.command_buffer.execution_barrier<%cmd : !hal.command_buffer> source("Dispatch|CommandRetire") target("CommandIssue|Dispatch") flags("None") + } + hal.command_buffer.finalize<%cmd : !hal.command_buffer> + %fence = hal.fence.create device(%device : !hal.device) flags("None") : !hal.fence + hal.device.queue.execute<%device : !hal.device> affinity(%queue_affinity) wait(%0) signal(%fence) commands([%cmd]) + %status = hal.fence.await until([%fence]) timeout_millis(%c-1_i32) : i32 + util.status.check_ok %status, "failed to wait on timepoint" + util.return + } +} diff --git a/tuner/examples/test/tuner_test.py b/tuner/examples/test/tuner_test.py new file mode 100644 index 000000000..d8c35d60b --- /dev/null +++ b/tuner/examples/test/tuner_test.py @@ -0,0 +1,40 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from tuner import libtuner
+
+
+def main():
+    args = libtuner.parse_arguments()
+
+    path_config = libtuner.PathConfig()
+    path_config.base_dir.mkdir(parents=True, exist_ok=True)
+    path_config.output_unilog.touch()
+    candidate_trackers: list[libtuner.CandidateTracker] = []
+    stop_after_phase: str = args.stop_after
+
+    print("Setup logging")
+    libtuner.setup_logging(args, path_config)
+    print(path_config.run_log, end="\n\n")
+
+    if not args.dry_run:
+        print("Validating devices")
+        libtuner.validate_devices(args.devices)
+        print("Validation successful!\n")
+
+    print("Generating candidates...")
+    candidates = libtuner.generate_candidate_specs(
+        args, path_config, candidate_trackers
+    )
+    print(f"Stored candidate specs in {path_config.specs_dir}\n")
+    if stop_after_phase == libtuner.ExecutionPhases.generate_candidates:
+        return
+
+    print("Check the detailed execution logs in:")
+    print(path_config.run_log.resolve())
+
+    for candidate in candidate_trackers:
+        libtuner.logging.debug(candidate)
diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py
index ed150bfec..ed4d63f7d 100644
--- a/tuner/tuner/candidate_gen.py
+++ b/tuner/tuner/candidate_gen.py
@@ -35,6 +35,7 @@
 from .common import *
 from .dispatch_constraints import *
 from .dispatch_parser import *
+from .spec_builder import *
 
 tune_logger = logging.getLogger("tune")
 
@@ -106,6 +107,15 @@ def apply_params(
         """Apply parameter transformations to the operation."""
         pass
 
+    @abstractmethod
+    def get_td_spec(
+        self,
+        ir_module: ir.Module,
+        compilation_info: iree_codegen.CompilationInfoAttr,
+    ) -> ir.Module:
+        """Generate a transform dialect spec that applies the compilation info attr."""
+        pass
+
 
 class DispatchTunerRegistry:
     def __init__(self):
@@ -130,6 +140,68 @@ def find_handler(self, op_name: str) -> DispatchTuner:
         assert False, "Dispatch kind not supported"
 
 
+class ContractionOpInterfaceTuner(DispatchTuner, ContractionOpInterfaceParser):
+    def apply_params(
+        self,
+        problem_size: ProblemSize,
+        template: list[str],
+        compilation_info: iree_codegen.CompilationInfoAttr,
+    ) -> MLIRTransformation:
+        raise NotImplementedError
+
+    def get_td_spec(
+        self,
+        ir_module: ir.Module,
+        compilation_info: iree_codegen.CompilationInfoAttr,
+    ) -> ir.Module:
+        contraction_op: ir.Operation = self.get_contraction_operation(ir_module)
+        lhs_type = ir.ShapedType(contraction_op.operands[0].type)
+        rhs_type = ir.ShapedType(contraction_op.operands[1].type)
+        acc_type = ir.ShapedType(contraction_op.operands[2].type)
+        M = acc_type.get_dim_size(0)
+        N = acc_type.get_dim_size(1)
+        K = lhs_type.get_dim_size(1)
+        # TODO(Max191): Get the function name from the func.func in the input module.
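+        # For example, for the mmt_benchmark dispatch above (M = N = K = 2048,
+        # f16 inputs, f32 accumulator), this produces the matcher name
+        # "match_contraction_2048x2048x2048_f16xf16xf32".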
+ func_name = f"match_contraction_{M}x{N}x{K}_{lhs_type.element_type}x{rhs_type.element_type}x{acc_type.element_type}" + return build_td_spec( + ir_module.context, contraction_op, compilation_info, func_name + ) + + +class ConvolutionOpInterfaceTuner(DispatchTuner, ConvolutionOpInterfaceParser): + def apply_params( + self, + problem_size: ProblemSize, + template: list[str], + compilation_info: iree_codegen.CompilationInfoAttr, + ) -> MLIRTransformation: + raise NotImplementedError + + def get_td_spec( + self, + ir_module: ir.Module, + compilation_info: iree_codegen.CompilationInfoAttr, + ) -> ir.Module: + conv_op: ir.Operation = self.get_conv_operation(ir_module) + assert ( + conv_op.name == "linalg.conv_2d_nhwc_hwcf" + ), "expected linalg.conv_2d_nhwc_hwcf" + lhs_type = ir.ShapedType(conv_op.operands[0].type) + rhs_type = ir.ShapedType(conv_op.operands[1].type) + acc_type = ir.ShapedType(conv_op.operands[2].type) + N = acc_type.get_dim_size(0) + H = acc_type.get_dim_size(1) + W = acc_type.get_dim_size(2) + C = rhs_type.get_dim_size(2) + P = rhs_type.get_dim_size(0) + Q = rhs_type.get_dim_size(1) + F = rhs_type.get_dim_size(3) + conv_type = conv_op.name.split(".")[-1] + # TODO(Max191): Get the function name from the func.func in the input module. + func_name = f"match_{conv_type}_{N}x{H}x{W}x{C}x{P}x{Q}x{F}_{lhs_type.element_type}x{rhs_type.element_type}x{acc_type.element_type}" + return build_td_spec(ir_module.context, conv_op, compilation_info, func_name) + + class MmtTuner(DispatchTuner, MmtParser): def get_transform_function_mmt( self, @@ -174,6 +246,13 @@ def apply_params( ) return MLIRTransformation(template, modified, embeddable) + def get_td_spec( + self, + ir_module: ir.Module, + compilation_info: iree_codegen.CompilationInfoAttr, + ) -> ir.Module: + raise NotImplementedError + class ConvTuner(DispatchTuner, ConvParser): def get_transform_function_conv( @@ -235,6 +314,13 @@ def apply_params( ) return MLIRTransformation(template, modified, embeddable) + def get_td_spec( + self, + ir_module: ir.Module, + compilation_info: iree_codegen.CompilationInfoAttr, + ) -> ir.Module: + raise NotImplementedError + class ContractionTuner(DispatchTuner, ContractionParser): def get_transform_function_broadcast_rhs_mmt( @@ -306,6 +392,13 @@ def apply_params( "", ) + def get_td_spec( + self, + ir_module: ir.Module, + compilation_info: iree_codegen.CompilationInfoAttr, + ) -> ir.Module: + raise NotImplementedError + class BatchMmtTuner(DispatchTuner, BatchMmtParser): def get_transform_function_batch_mmt( @@ -353,6 +446,13 @@ def apply_params( ) return MLIRTransformation(template, modified, embeddable) + def get_td_spec( + self, + ir_module: ir.Module, + compilation_info: iree_codegen.CompilationInfoAttr, + ) -> ir.Module: + raise NotImplementedError + class BatchMatmulTuner(DispatchTuner, BatchMatmulParser): def get_transform_function_batch_matmul( @@ -409,6 +509,13 @@ def apply_params( ) return MLIRTransformation(template, modified, embeddable) + def get_td_spec( + self, + ir_module: ir.Module, + compilation_info: iree_codegen.CompilationInfoAttr, + ) -> ir.Module: + raise NotImplementedError + @dataclass class OpWalkResult: @@ -452,6 +559,7 @@ def get_default_output_dir() -> str: return "tuning_" + datetime.now().strftime("%Y_%m_%d_%H_%M") +# TODO(https://github.com/nod-ai/shark-ai/issues/453): Remove in favor of using tune_with_td. 
 def tune(
     input: str,  # Path to the mlir file to be tuned
     output: str = "",  # Path to the output directory, auto creates one if not given
@@ -527,6 +635,53 @@ def tune(
     tune_logger.info(f"Configurations .pkl is stored in {output}/configs.pkl")
 
 
+def generate_configs_and_td_specs(
+    input_module: ir.Module,  # MLIR module to be tuned
+    tuner_context: TunerContext,
+    limit: int = 4096,  # Max candidates to be generated
+    num_subgroups: int = 4,  # GPU spec, used to determine candidate generation constraints
+) -> list[ir.Module]:
+    dispatch_tuner_registry = DispatchTunerRegistry()
+    dispatch_tuner_registry.register(
+        [
+            ContractionOpInterfaceTuner(),
+            ConvolutionOpInterfaceTuner(),
+        ]
+    )
+
+    walk_result: OpWalkResult = walk_mlir_op(input_module, dispatch_tuner_registry)
+
+    dispatch_tuner = walk_result.dispatch_tuner
+    assert dispatch_tuner, "No suitable dispatch tuner found"
+    problem_size: ProblemSize = dispatch_tuner.get_shapes(
+        str(input_module).splitlines()
+    )
+    tune_logger.debug(str(problem_size))
+
+    # Index 0 is reserved for the default config, so it gets no td spec.
+    with ir.Location.unknown() as loc:
+        empty_module = ir.Module.create(loc)
+        config_specs: list[ir.Module] = [empty_module]
+
+    # Get the MMA intrinsic instructions supported by the target.
+    variant_op_list = iree_codegen.get_executable_variant_ops(input_module)
+    assert len(variant_op_list) == 1, "Expect one executable variant op"
+    variant_op = variant_op_list[0]
+    mma_list = iree_codegen.query_mma_intrinsics(variant_op)
+    for i, config in enumerate(
+        generate_solutions(tuner_context, problem_size, num_subgroups, mma_list)
+    ):
+        if i >= limit:
+            break
+        tune_logger.info(f"Solution #{i+1}: {config}")
+        td_spec_module = dispatch_tuner.get_td_spec(input_module, config)
+        assert td_spec_module, "Failed to generate transform dialect spec"
+        config_specs.append(td_spec_module)
+
+    tune_logger.info(f"Generated {len(config_specs)} tuning specs")
+    return config_specs
+
+
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("input", help="Input mlir file", type=str)
diff --git a/tuner/tuner/candidate_gen_test.py b/tuner/tuner/candidate_gen_test.py
index 0428ab7d2..d135a8502 100644
--- a/tuner/tuner/candidate_gen_test.py
+++ b/tuner/tuner/candidate_gen_test.py
@@ -15,9 +15,11 @@
 from iree.compiler import ir  # type: ignore
 from iree.compiler.dialects import iree_gpu  # type: ignore
 from iree.compiler.dialects import iree_codegen  # type: ignore
+from iree.compiler.dialects import transform  # type: ignore
 
 from . import candidate_gen
 from . import common
+from .
import op_matchers @pytest.fixture @@ -36,6 +38,183 @@ def remove_comments(mlir: str) -> str: ) +def test_get_td_spec_contraction(tuner_ctx: common.TunerContext) -> None: + context = tuner_ctx.mlir_ctx + module_str = """ + builtin.module{ + func.func @test(%arg0: tensor<2048x2048xf16>, %arg1: tensor<2048x2048xf16>) -> tensor<2048x2048xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : tensor<2048x2048xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2048x2048xf32>) -> tensor<2048x2048xf32> + %2 = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1)>], + iterator_types = ["parallel", "parallel", "reduction"]} + {root_op} + ins(%arg0, %arg1 : tensor<2048x2048xf16>, tensor<2048x2048xf16>) + outs(%1 : tensor<2048x2048xf32>) { + ^bb0(%in: f16, %in_0: f16, %out: f32): + %3 = arith.extf %in : f16 to f32 + %4 = arith.extf %in_0 : f16 to f32 + %5 = arith.mulf %3, %4 : f32 + %6 = arith.addf %out, %5 : f32 + linalg.yield %6 : f32 + } -> tensor<2048x2048xf32> + return %2 : tensor<2048x2048xf32> + } + }""" + + mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 + mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) + lowering_config = common.get_lowering_config( + tuner_ctx=tuner_ctx, + mma_kind=mma_attr, + workgroup=[8, 8, 0], + reduction=[0, 0, 8], + subgroup_m_count=16, + subgroup_n_count=16, + ) + pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get( + iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute + ) + pipeline_options = iree_gpu.PipelineOptionsAttr.get(prefetch_shared_memory=True) + config_dict = common.get_translation_info_config(pipeline_options, waves_per_eu=8) + translation_info = iree_codegen.TranslationInfoAttr.get( + pipeline_attr, None, [16, 16, 1], 16, config_dict + ) + compilation_info = iree_codegen.CompilationInfoAttr.get( + lowering_config, translation_info + ) + + ir_module = ir.Module.parse(module_str, context) + + tuner = candidate_gen.ContractionOpInterfaceTuner() + td_spec_module = tuner.get_td_spec(ir_module, compilation_info) + assert td_spec_module + + named_sequence_ops: list[ + transform.NamedSequenceOp + ] = op_matchers.get_ops_from_module( + module=td_spec_module, + fn=lambda op: isinstance(op.opview, transform.NamedSequenceOp), + ) + apply_config_sequence = None + matcher_sequence = None + entry_point = None + for op in named_sequence_ops: + if str(op.opview.sym_name) == '"apply_op_config"': + apply_config_sequence = op + elif str(op.opview.sym_name) == '"__kernel_config"': + entry_point = op + else: + matcher_sequence = op + + assert apply_config_sequence + assert matcher_sequence + assert entry_point + matcher_sequence_str = str(matcher_sequence) + + assert ( + "mma_kind = #iree_gpu.mma_layout" in matcher_sequence_str + ) + assert "subgroup_m_count = 16" in matcher_sequence_str + assert "subgroup_n_count = 16" in matcher_sequence_str + assert "pipeline = LLVMGPUVectorDistribute" in matcher_sequence_str + assert "workgroup_size = [16, 16, 1]" in matcher_sequence_str + assert "subgroup_size = 16" in matcher_sequence_str + assert "workgroup = [8, 8, 0]" in matcher_sequence_str + assert "reduction = [0, 0, 8]" in matcher_sequence_str + assert ( + "gpu_pipeline_options = #iree_gpu.pipeline_options" + in matcher_sequence_str + ) + assert 'llvm_func_attrs = {"amdgpu-waves-per-eu" = "8"}' in matcher_sequence_str + + +def test_get_td_spec_convolution(tuner_ctx: common.TunerContext) -> None: + context = 
tuner_ctx.mlir_ctx + module_str = """ + builtin.module{ + func.func @test(%arg0: tensor<2x34x34x2048xi8>, %arg1: tensor<3x3x2048x2048xi8>) -> tensor<2x32x32x2048xi32> { + %cst = arith.constant 0 : i32 + %0 = tensor.empty() : tensor<2x32x32x2048xi32> + %1 = linalg.fill ins(%cst : i32) outs(%0 : tensor<2x32x32x2048xi32>) -> tensor<2x32x32x2048xi32> + %2 = linalg.conv_2d_nhwc_hwcf {root_op} + ins(%arg0, %arg1 : tensor<2x34x34x2048xi8>, tensor<3x3x2048x2048xi8>) + outs(%1 : tensor<2x32x32x2048xi32>) -> tensor<2x32x32x2048xi32> + return %2 : tensor<2x32x32x2048xi32> + } + }""" + + mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 + mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) + lowering_config = common.get_lowering_config( + tuner_ctx=tuner_ctx, + mma_kind=mma_attr, + workgroup=[1, 1, 464, 320, 0, 0, 0], + reduction=[0, 0, 0, 0, 1, 1, 16], + subgroup_m_count=1, + subgroup_n_count=4, + ) + pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get( + iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute + ) + pipeline_options = iree_gpu.PipelineOptionsAttr.get(prefetch_shared_memory=False) + config_dict = common.get_translation_info_config(pipeline_options, waves_per_eu=2) + translation_info = iree_codegen.TranslationInfoAttr.get( + pipeline_attr, None, [256, 1, 1], 64, config_dict + ) + compilation_info = iree_codegen.CompilationInfoAttr.get( + lowering_config, translation_info + ) + + ir_module = ir.Module.parse(module_str, context) + + tuner = candidate_gen.ConvolutionOpInterfaceTuner() + td_spec_module = tuner.get_td_spec(ir_module, compilation_info) + assert td_spec_module + + named_sequence_ops: list[ + transform.NamedSequenceOp + ] = op_matchers.get_ops_from_module( + module=td_spec_module, + fn=lambda op: isinstance(op.opview, transform.NamedSequenceOp), + ) + apply_config_sequence = None + matcher_sequence = None + entry_point = None + for op in named_sequence_ops: + if str(op.opview.sym_name) == '"apply_op_config"': + apply_config_sequence = op + elif str(op.opview.sym_name) == '"__kernel_config"': + entry_point = op + else: + matcher_sequence = op + + assert apply_config_sequence + assert matcher_sequence + assert entry_point + + matcher_sequence_str = str(matcher_sequence) + + assert ( + "mma_kind = #iree_gpu.mma_layout" in matcher_sequence_str + ) + assert "subgroup_m_count = 1" in matcher_sequence_str + assert "subgroup_n_count = 4" in matcher_sequence_str + assert "pipeline = LLVMGPUVectorDistribute" in matcher_sequence_str + assert "workgroup_size = [256, 1, 1]" in matcher_sequence_str + assert "subgroup_size = 64" in matcher_sequence_str + assert "workgroup = [1, 1, 464, 320, 0, 0, 0]" in matcher_sequence_str + assert "reduction = [0, 0, 0, 0, 1, 1, 16]" in matcher_sequence_str + assert ( + "gpu_pipeline_options = #iree_gpu.pipeline_options" + in matcher_sequence_str + ) + + def test_apply_params_mmt(tuner_ctx: common.TunerContext) -> None: mlir_template = [ ", subgroup_m_count = 16, subgroup_n_count = 16>", diff --git a/tuner/tuner/common.py b/tuner/tuner/common.py index 5c79bd8dd..78e3a8e9d 100644 --- a/tuner/tuner/common.py +++ b/tuner/tuner/common.py @@ -78,6 +78,14 @@ class MatmulSize: B: int = 1 +@dataclass +class ContractionDimensions: + batch: list[int] + m: list[int] + n: list[int] + k: list[int] + + @dataclass class ProblemSize: matmul_size: MatmulSize @@ -98,13 +106,12 @@ def get_compatible_mfma_intrinsics( def is_comptible(mma_intrinsic: iree_gpu.MMAIntrinsic) -> bool: mma_attr = iree_gpu.MMAIntrinsicAttr.get(mma_intrinsic).mma a_type, 
b_type, c_type = mma_attr.abc_element_types - if problem_size.res_type.element_type != c_type: + if not isinstance(problem_size.res_type.element_type, type(c_type)): return False if problem_size.dispatch_kind != DispatchKind.batch_matmul: - if ( - problem_size.lhs_type.element_type != a_type - or problem_size.rhs_type.element_type != b_type - ): + if not isinstance( + problem_size.lhs_type.element_type, type(a_type) + ) or not isinstance(problem_size.rhs_type.element_type, type(b_type)): return False return True diff --git a/tuner/tuner/dispatch_constraints.py b/tuner/tuner/dispatch_constraints.py index 797c83534..914c04bbf 100644 --- a/tuner/tuner/dispatch_constraints.py +++ b/tuner/tuner/dispatch_constraints.py @@ -157,9 +157,9 @@ def getMMAAttr( a_type, b_type, c_type = mma_attr.abc_element_types mnk = mma_attr.mnk_shape if ( - a_type == lhs_type - and b_type == rhs_type - and c_type == output_type + isinstance(a_type, type(lhs_type)) + and isinstance(b_type, type(rhs_type)) + and isinstance(c_type, type(output_type)) and m == mnk[0] and n == mnk[1] and k == mnk[2] diff --git a/tuner/tuner/dispatch_parser.py b/tuner/tuner/dispatch_parser.py index fe95c52a6..e3fca244b 100644 --- a/tuner/tuner/dispatch_parser.py +++ b/tuner/tuner/dispatch_parser.py @@ -11,6 +11,7 @@ import re from abc import ABCMeta, abstractmethod +from .op_matchers import * from .common import * @@ -89,6 +90,101 @@ def get_shapes(self, template: list[str]) -> ProblemSize: pass +# TODO(Max191): Support linalg named op versions of contraction ops. The +# current matchers only work for linalg.generic ops. +class ContractionOpInterfaceParser(DispatchParser): + def supports(self, op_name: str) -> bool: + return ( + "matmul_like" in op_name + or "batch_matmul" in op_name + or "batch_matmul_transpose_b" in op_name + or "matmul_transpose_b" in op_name + ) + + def get_contraction_operation( + self, + ir_module: ir.Module, + ) -> Optional[ir.Operation]: + return match_root_op(ir_module, ContractionOpInterfaceMatcher()) + + # TODO(Max191): Pass the ir_module directly instead of the template str. 
+ def get_shapes(self, template: list[str]) -> ProblemSize: + matcher = ContractionOpInterfaceMatcher() + with ir.Context() as ctx: + ir_module = ir.Module.parse("\n".join(template), ctx) + contraction_op = match_root_op(ir_module, matcher) + if contraction_op is None: + assert False, f"contraction op not found" + cdims = matcher.contraction_dimensions + assert cdims, "no contraction dimensions" + assert matcher.lhs_dims, "no lhs dimensions" + assert matcher.rhs_dims, "no rhs dimensions" + assert matcher.res_dims, "no result dimensions" + assert len(cdims.batch) <= 1, f"must have at most 1 batch dimension" + assert len(cdims.m) == 1, f"must have a single m dimension" + assert len(cdims.n) == 1, f"must have a single n dimension" + assert len(cdims.k) == 1, f"must have a single k dimension" + lhs_type = ir.RankedTensorType(contraction_op.operands[0].type) + rhs_type = ir.RankedTensorType(contraction_op.operands[1].type) + res_type = ir.RankedTensorType(contraction_op.operands[2].type) + matmul_size = MatmulSize( + lhs_type.shape[matcher.lhs_dims.index(cdims.m[0])], + rhs_type.shape[matcher.rhs_dims.index(cdims.n[0])], + lhs_type.shape[matcher.lhs_dims.index(cdims.k[0])], + ) + if len(cdims.batch) == 1: + matmul_size.B = lhs_type.shape[matcher.lhs_dims.index(cdims.batch[0])] + return ProblemSize( + matmul_size, + lhs_type=ShapedType(lhs_type.shape, lhs_type.element_type), + rhs_type=ShapedType(rhs_type.shape, rhs_type.element_type), + res_type=ShapedType(res_type.shape, res_type.element_type), + dispatch_kind=DispatchKind.contraction, + ) + + +# TODO(Max191): Support more convolution types. Only NHWC convs are supported. +class ConvolutionOpInterfaceParser(DispatchParser): + def __init__(self): + self.supported_ops = ["linalg.conv_2d_nhwc_hwcf"] + + def supports(self, op_name: str) -> bool: + for supported_op_name in self.supported_ops: + if supported_op_name.split(".")[-1] in op_name: + return True + return False + + def get_conv_operation( + self, + ir_module: ir.Module, + ) -> Optional[ir.Operation]: + return match_root_op(ir_module, NamedOpMatcher(self.supported_ops)) + + # TODO(Max191): Pass the ir_module directly instead of the template str. 
+ def get_shapes(self, template: list[str]) -> ProblemSize: + with ir.Context() as ctx: + ir_module = ir.Module.parse("\n".join(template), ctx) + conv_op = match_root_op(ir_module, NamedOpMatcher(self.supported_ops)) + if conv_op is None: + assert False, f"convolution op not found" + lhs_type = ir.RankedTensorType(conv_op.operands[0].type) + rhs_type = ir.RankedTensorType(conv_op.operands[1].type) + res_type = ir.RankedTensorType(conv_op.operands[2].type) + dim_info = ConvDimInfo.from_rhs_res(rhs_type, res_type) + return ProblemSize( + MatmulSize( + M=dim_info.oh * dim_info.ow, + N=dim_info.oc, + K=dim_info.fh * dim_info.fw * dim_info.ic, + B=dim_info.n, + ), + lhs_type=ShapedType(lhs_type.shape, lhs_type.element_type), + rhs_type=ShapedType(rhs_type.shape, rhs_type.element_type), + res_type=ShapedType(res_type.shape, res_type.element_type), + dispatch_kind=DispatchKind.conv, + ) + + class MmtParser(DispatchParser): def supports(self, op_name: str) -> bool: return "matmul_transpose_b" in op_name diff --git a/tuner/tuner/dispatch_parser_test.py b/tuner/tuner/dispatch_parser_test.py index 9f4afbb19..0b87be659 100644 --- a/tuner/tuner/dispatch_parser_test.py +++ b/tuner/tuner/dispatch_parser_test.py @@ -16,6 +16,7 @@ from iree.compiler.dialects import func # type: ignore from iree.compiler.dialects import iree_gpu # type: ignore from iree.compiler.dialects import iree_codegen # type: ignore +from iree.compiler.dialects import linalg # type: ignore from . import common from . import dispatch_parser @@ -40,6 +41,103 @@ def test_parse_tensor_type(tuner_ctx: common.TunerContext) -> None: ) +CONTRACTION_TEMPLATE = r""" +builtin.module{{ + func.func @test(%arg0: {lhs_type}, %arg1: {rhs_type}) -> {res_type} {{ + %cst = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : {res_type} + %1 = linalg.fill ins(%cst : f32) outs(%0 : {res_type}) -> {res_type} + %2 = linalg.generic {{ + indexing_maps = [ + {lhs_map}, + {rhs_map}, + {res_map}], + iterator_types = {iterator_types}}} + {{root_op}} + ins(%arg0, %arg1 : {lhs_type}, {rhs_type}) + outs(%1 : {res_type}) {{ + ^bb0(%in: f16, %in_0: f16, %out: f32): + %3 = arith.extf %in : f16 to f32 + %4 = arith.extf %in_0 : f16 to f32 + %5 = arith.mulf %3, %4 : f32 + %6 = arith.addf %out, %5 : f32 + linalg.yield %6 : f32 + }} -> {res_type} + return %2 : {res_type} + }} +}}""" + + +def test_get_contraction_operation(tuner_ctx: common.TunerContext) -> None: + context = tuner_ctx.mlir_ctx + + with ir.Location.unknown(): + transpose_b_str = CONTRACTION_TEMPLATE.format( + lhs_type=ir.RankedTensorType.get([16, 64], ir.F16Type.get()), + rhs_type=ir.RankedTensorType.get([32, 64], ir.F16Type.get()), + res_type=ir.RankedTensorType.get([16, 32], ir.F32Type.get()), + lhs_map="affine_map<(d0, d1, d2) -> (d0, d2)>", + rhs_map="affine_map<(d0, d1, d2) -> (d1, d2)>", + res_map="affine_map<(d0, d1, d2) -> (d0, d1)>", + iterator_types='["parallel", "parallel", "reduction"]', + ) + module = ir.Module.parse(transpose_b_str, context) + parser = dispatch_parser.ContractionOpInterfaceParser() + mmt_op = parser.get_contraction_operation(module) + assert mmt_op is not None + assert isinstance(mmt_op.opview, linalg.GenericOp) + shapes: common.ProblemSize = parser.get_shapes(transpose_b_str.splitlines()) + assert shapes.matmul_size.B == 1 + assert shapes.matmul_size.M == 16 + assert shapes.matmul_size.N == 32 + assert shapes.matmul_size.K == 64 + assert shapes.lhs_type.shape == [16, 64] + assert isinstance(shapes.lhs_type.element_type, ir.F16Type) + assert shapes.rhs_type.shape == [32, 
64] + assert isinstance(shapes.rhs_type.element_type, ir.F16Type) + assert shapes.res_type.shape == [16, 32] + assert isinstance(shapes.res_type.element_type, ir.F32Type) + + with ir.Location.unknown(): + bmm_transposed_inputs_str = CONTRACTION_TEMPLATE.format( + lhs_type=ir.RankedTensorType.get([5, 8, 128], ir.F16Type.get()), + rhs_type=ir.RankedTensorType.get([128, 40, 5], ir.F16Type.get()), + res_type=ir.RankedTensorType.get([5, 40, 8], ir.F32Type.get()), + lhs_map="affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>", + rhs_map="affine_map<(d0, d1, d2, d3) -> (d3, d2, d0)>", + res_map="affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>", + iterator_types='["parallel", "parallel", "parallel", "reduction"]', + ) + module = ir.Module.parse(bmm_transposed_inputs_str, context) + mmt_op = parser.get_contraction_operation(module) + shapes = parser.get_shapes(bmm_transposed_inputs_str.splitlines()) + assert shapes.matmul_size.B == 5 + assert shapes.matmul_size.M == 8 + assert shapes.matmul_size.N == 40 + assert shapes.matmul_size.K == 128 + + +def test_get_conv_operation(tuner_ctx: common.TunerContext) -> None: + context = tuner_ctx.mlir_ctx + module_str = """ + builtin.module{ + func.func @test(%arg0: tensor<2x34x34x16xi8>, %arg1: tensor<3x3x16x16xi8>) -> tensor<2x32x32x16xi32> { + %cst = arith.constant 0 : i32 + %0 = tensor.empty() : tensor<2x32x32x16xi32> + %1 = linalg.fill ins(%cst : i32) outs(%0 : tensor<2x32x32x16xi32>) -> tensor<2x32x32x16xi32> + %2 = linalg.conv_2d_nhwc_hwcf {root_op} + ins(%arg0, %arg1 : tensor<2x34x34x16xi8>, tensor<3x3x16x16xi8>) + outs(%1 : tensor<2x32x32x16xi32>) -> tensor<2x32x32x16xi32> + return %2 : tensor<2x32x32x16xi32> + } + }""" + module = ir.Module.parse(module_str, context) + parser = dispatch_parser.ConvolutionOpInterfaceParser() + conv_op = parser.get_conv_operation(module) + assert conv_op is not None + assert isinstance(conv_op.opview, linalg.Conv2DNhwcHwcfOp) + + def test_get_mmt_tile_sizes(tuner_ctx: common.TunerContext) -> None: mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) diff --git a/tuner/tuner/libtuner.py b/tuner/tuner/libtuner.py index 3c195520c..6bece17f4 100644 --- a/tuner/tuner/libtuner.py +++ b/tuner/tuner/libtuner.py @@ -39,7 +39,10 @@ import json from abc import ABC, abstractmethod import iree.runtime as ireert # type: ignore +from iree.compiler import ir # type: ignore from . import candidate_gen +from . import dispatch_parser +from .common import * # Default values for num_candidates and devices, change it as needed @@ -62,6 +65,7 @@ @dataclass class CandidateTracker: candidate_id: int + mlir_path: Optional[Path] = None dispatch_mlir_path: Optional[Path] = None dispatch_config_path: Optional[Path] = None configuration: Optional[candidate_gen.iree_codegen.CompilationInfoAttr] = None @@ -746,6 +750,7 @@ def append_to_file(lines: list[str], filepath: Path, title: str = "") -> None: file.write("\n") +# TODO(Max191): Remove in favor of using generate_candidate_specs. def generate_candidates( args: argparse.Namespace, path_config: PathConfig, @@ -825,6 +830,66 @@ def generate_candidates( return candidates +def generate_candidate_specs( + args: argparse.Namespace, + path_config: PathConfig, + candidate_trackers: list[CandidateTracker], +) -> list[int]: + """Generate candidate transform dialect specs for tuning. 
Returns the list of candidate indexes""" + logging.debug("generate_candidate_specs()") + + path_config.specs_dir.mkdir(parents=True, exist_ok=True) + tune_logger = logging.getLogger("tune") + + # Generate transform dialect specs. + try: + with open(args.input_file, "r") as f: + mlir_text = f.read() + with ir.Context() as ctx: + tuner_context = TunerContext(ctx, tune_logger) + mlir_module = dispatch_parser.parse_mlir(mlir_text, tuner_context) + logging.debug("Captured messages from candidate_gen.py:") + config_specs: list[ir.Module] = candidate_gen.generate_configs_and_td_specs( + input_module=mlir_module, + tuner_context=tuner_context, + limit=args.num_candidates, + num_subgroups=args.num_subgroups, + ) + logging.debug("candidate_gen.py ends") + handle_error( + condition=(len(config_specs) <= 1), msg="Failed to generate any candidates" + ) + + # Create candidate trackers. + candidates = [] + for candidate_num, spec in enumerate(config_specs): + candidates.append(candidate_num) + # Move the specs to the canonical path_config location. + spec_path = path_config.specs_dir / path_config.get_candidate_spec_filename( + candidate_num + ) + with open(spec_path, "w") as f: + f.write(str(spec)) + new_candidate = CandidateTracker( + mlir_path=args.input_file, + candidate_id=candidate_num, + spec_path=spec_path, + ) + candidate_trackers.append(new_candidate) + except Exception as e: + logging.error("An error occurred during candidates generation: %s", str(e)) + # Capture and log debug messages from candidate_gen.py. + tune_logger = logging.getLogger("tune_with_td") + for handler in logging.getLogger().handlers: + if isinstance(handler, logging.FileHandler): + tune_logger.handlers.append(handler) + tune_logger.exception("Error in candidate_gen.py:") + raise + + logging.info(f"Generated [{len(candidates) - 1}] candidates") + return candidates + + def collision_handler(index_hash_list: list[tuple[int, str]]) -> tuple[bool, list[int]]: """If a collision is found, generate a list of new indexes. If no collision, `unique_indexes = []`""" # Check if candidate produces tbe same .vmfb diff --git a/tuner/tuner/op_matchers.py b/tuner/tuner/op_matchers.py new file mode 100644 index 000000000..1abdafd3d --- /dev/null +++ b/tuner/tuner/op_matchers.py @@ -0,0 +1,178 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# This code implements matcher functions for MLIR modules using python bindings. 
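+#
+# The matchers locate the op tagged with the `root_op` attribute (see
+# match_root_op below) and then structurally check its operands and
+# indexing maps.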
+
+from abc import ABCMeta, abstractmethod
+
+from .common import *
+from iree.compiler import ir  # type: ignore
+
+
+class OpMatcher(metaclass=ABCMeta):
+    @abstractmethod
+    def match(self, op: ir.Operation) -> bool:
+        """Check if the op passes the matching criteria."""
+        pass
+
+
+def walk_collect_ops(
+    op: ir.Operation,
+    ops: list[ir.Operation],
+    fn,
+) -> ir.WalkResult:
+    if fn(op):
+        ops.append(op)
+    return ir.WalkResult.ADVANCE
+
+
+def get_ops_from_module(module: ir.Module, fn):
+    ops: list[ir.Operation] = []
+    for op in module.body.operations:
+        op.walk(
+            lambda op: walk_collect_ops(op, ops, fn),
+            ir.WalkOrder.POST_ORDER,
+        )
+    return ops
+
+
+def is_root_op(op: ir.Operation) -> bool:
+    for attr in op.opview.attributes:
+        if attr.name == "root_op":
+            return True
+    return False
+
+
+def match_root_op(
+    ir_module: ir.Module,
+    matcher: OpMatcher,
+) -> Optional[ir.Operation]:
+    root_ops: list[ir.Operation] = get_ops_from_module(ir_module, is_root_op)
+    if len(root_ops) != 1:
+        return None
+    if not matcher.match(root_ops[0].operation):
+        return None
+    return root_ops[0]
+
+
+class NamedOpMatcher(OpMatcher):
+    def __init__(self, op_names: list[str]):
+        self.op_names = op_names
+
+    def match(self, op: ir.Operation) -> bool:
+        return op.name in self.op_names
+
+
+# TODO(Max191): Add logic to match the body of the generic op.
+class GenericOpMatcher(NamedOpMatcher):
+    def __init__(self):
+        super().__init__(["linalg.generic"])
+
+    @abstractmethod
+    def match_operands(self, operands: ir.OpOperandList) -> bool:
+        """Match the operands of the linalg op."""
+        pass
+
+    @abstractmethod
+    def match_indexing_maps(self, maps: list[ir.AffineMap]) -> bool:
+        """Match the indexing_maps of the linalg op."""
+        pass
+
+    def match(self, op: ir.Operation) -> bool:
+        if not super().match(op):
+            return False
+
+        if not self.match_operands(op.operands):
+            return False
+
+        maps_attr = None
+        for attr in op.opview.attributes:
+            if attr.name == "indexing_maps" and isinstance(attr.attr, ir.ArrayAttr):
+                maps_attr = attr.attr
+        if maps_attr is None:
+            return False
+
+        maps: list[ir.AffineMap] = []
+        for map in maps_attr:
+            maps.append(map.value)
+        if not self.match_indexing_maps(maps):
+            return False
+
+        return True
+
+
+def get_map_result_dim_positions(map: ir.AffineMap):
+    exprs = []
+    if not map.is_projected_permutation:
+        return None
+    for expr in map.results:
+        dim_str = str(expr)
+        if len(dim_str) < 1:
+            return None
+        if dim_str[0] != "d":
+            return None
+        if not dim_str[1:].isdigit():
+            return None
+        dim_position = int(dim_str[1:])
+        exprs.append(dim_position)
+    return exprs
+
+
+class ContractionOpInterfaceMatcher(GenericOpMatcher):
+    def __init__(self):
+        super().__init__()
+        self.contraction_dimensions: Optional[ContractionDimensions] = None
+        self.lhs_dims: Optional[list[int]] = None
+        self.rhs_dims: Optional[list[int]] = None
+        self.res_dims: Optional[list[int]] = None
+
+    def match_operands(self, operands: ir.OpOperandList) -> bool:
+        if len(operands) != 3:
+            return False
+        for operand in operands:
+            if not isinstance(operand.type, ir.ShapedType):
+                return False
+        return True
+
+    def match_indexing_maps(self, maps: list[ir.AffineMap]) -> bool:
+        if len(maps) != 3:
+            return False
+        lhs_dims = get_map_result_dim_positions(maps[0])
+        rhs_dims = get_map_result_dim_positions(maps[1])
+        res_dims = get_map_result_dim_positions(maps[2])
+        if lhs_dims is None or rhs_dims is None or res_dims is None:
+            return False
+
+        batch_dims = []
+        m_dims = []
+        n_dims = []
+        k_dims = []
+
+        for d in range(maps[0].n_dims):
+            if d in lhs_dims and d in rhs_dims and d in res_dims:
+                batch_dims.append(d)
+                continue
+            if d in lhs_dims and d in res_dims:
+                m_dims.append(d)
+                continue
+            if d in rhs_dims and d in res_dims:
+                n_dims.append(d)
+                continue
+            if d in lhs_dims and d in rhs_dims:
+                k_dims.append(d)
+                continue
+            return False
+
+        self.contraction_dimensions = ContractionDimensions(
+            batch=batch_dims,
+            m=m_dims,
+            n=n_dims,
+            k=k_dims,
+        )
+        self.lhs_dims = lhs_dims
+        self.rhs_dims = rhs_dims
+        self.res_dims = res_dims
+        return True
diff --git a/tuner/tuner/spec_builder.py b/tuner/tuner/spec_builder.py
new file mode 100644
index 000000000..a27bd072f
--- /dev/null
+++ b/tuner/tuner/spec_builder.py
@@ -0,0 +1,62 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# Given the root op of an input dispatch, this code builds a transform
+# dialect spec that annotates matching ops with a tuned compilation info.
+
+from iree.compiler import ir  # type: ignore
+from iree.compiler.dialects import iree_codegen  # type: ignore
+
+from .common import *
+from .dispatch_constraints import *
+from .dispatch_parser import *
+
+
+# TODO(Max191): Use python bindings to build the transform dialect spec module
+# instead of using string formatting.
+def build_td_spec(
+    context: ir.Context,
+    op: ir.Operation,
+    compilation_info: iree_codegen.CompilationInfoAttr,
+    func_name: str,
+) -> ir.Module:
+    bbargs = []
+    for operand in op.operands:
+        ssa_name = operand.get_name()
+        operand_type = operand.type
+        bbargs.append(f"{ssa_name}: {operand_type}")
+    bbargs_str = ", ".join(bbargs)
+    root_operation = str(op)
+    spec_text = f"""
+        module attributes {{ transform.with_named_sequence }} {{
+            // Annotation Transform
+            transform.named_sequence @apply_op_config(%op: !transform.any_op {{transform.readonly}},
+                                                      %config: !transform.any_param {{transform.readonly}}) {{
+                transform.annotate %op "compilation_info" = %config : !transform.any_op, !transform.any_param
+                transform.yield
+            }}
+
+            // Custom Op Matcher
+            transform.named_sequence @{func_name}(%cont: !transform.any_op {{transform.readonly}})
+                -> (!transform.any_op, !transform.any_param) {{
+                %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %cont {{
+                ^bb0({bbargs_str}):
+                {root_operation}
+                }} : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
+                %config = transform.param.constant {compilation_info} -> !transform.any_param
+                transform.yield %cont, %config : !transform.any_op, !transform.any_param
+            }}
+
+            // Entry Point
+            transform.named_sequence @__kernel_config(%variant_op: !transform.any_op {{transform.consumed}}) {{
+                transform.foreach_match in %variant_op
+                    @{func_name} -> @apply_op_config
+                    : (!transform.any_op) -> (!transform.any_op)
+                transform.yield
+            }}
+        }}
+        """
+    return ir.Module.parse(spec_text, context)
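
For reference, a minimal sketch of how the new pieces compose, mirroring what `generate_candidate_specs` in `libtuner.py` does. The input file name and logger name are illustrative, not part of the patch:

```python
import logging

from iree.compiler import ir  # type: ignore

from tuner import candidate_gen
from tuner import dispatch_parser
from tuner.common import TunerContext

# Illustrative input: a dispatch benchmark dumped by iree-compile, e.g. the
# mmt_benchmark.mlir file from the example above.
with open("mmt_benchmark.mlir", "r") as f:
    mlir_text = f.read()

with ir.Context() as ctx:
    tuner_context = TunerContext(ctx, logging.getLogger("tune"))
    mlir_module = dispatch_parser.parse_mlir(mlir_text, tuner_context)
    # Index 0 of the returned list is an empty module for the default config;
    # the rest are transform dialect specs, one per generated candidate.
    specs = candidate_gen.generate_configs_and_td_specs(
        input_module=mlir_module,
        tuner_context=tuner_context,
        limit=20,
        num_subgroups=4,
    )
    for i, spec in enumerate(specs):
        print(f"Candidate #{i}:\n{spec}\n")
```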