address comments, use root_op attribute in matching
Signed-off-by: Max Dawkins <[email protected]>
Max191 committed Dec 12, 2024
1 parent 52d04e2 commit 119833b
Showing 9 changed files with 60 additions and 61 deletions.
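Before the per-file diffs, a brief sketch of what the change enables: benchmark inputs now tag the operation to tune with a root_op unit attribute, and the tuner locates that op directly instead of re-matching it inside a func.func. This is a minimal sketch, assuming the iree-compiler Python bindings are installed and that the matchers are importable as tuner.op_matchers (the import path is an assumption based on the repository layout); the MLIR text is a reduced, hypothetical example, not one of the benchmark files.

from iree.compiler import ir  # type: ignore

# Assumed import path based on the repository layout (tuner/tuner/op_matchers.py).
from tuner.op_matchers import NamedOpMatcher, match_root_op

# Reduced, hypothetical input: the op to tune carries the `root_op` unit attribute.
MLIR = """
module {
  func.func @main(%arg0: tensor<4x8xf32>, %arg1: tensor<8x4xf32>) -> tensor<4x4xf32> {
    %cst = arith.constant 0.0 : f32
    %0 = tensor.empty() : tensor<4x4xf32>
    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<4x4xf32>) -> tensor<4x4xf32>
    %2 = linalg.matmul {root_op}
        ins(%arg0, %arg1 : tensor<4x8xf32>, tensor<8x4xf32>)
        outs(%1 : tensor<4x4xf32>) -> tensor<4x4xf32>
    return %2 : tensor<4x4xf32>
  }
}
"""

with ir.Context() as ctx:
    module = ir.Module.parse(MLIR, ctx)
    # match_root_op (added in this commit) finds the single op tagged `root_op`
    # and then checks it against the given matcher.
    root = match_root_op(module, NamedOpMatcher(["linalg.matmul"]))
    print(root.name if root is not None else "no root op found")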
tuner/examples/test/conv_benchmark.mlir (2 changes: 1 addition & 1 deletion)
@@ -18,7 +18,7 @@ module {
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [3, 3, 1280, 1280], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<3x3x1280x1280xi8>> -> tensor<3x3x1280x1280xi8>
%5 = tensor.empty() : tensor<2x32x32x1280xi32>
%6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<2x32x32x1280xi32>) -> tensor<2x32x32x1280xi32>
-%7 = linalg.conv_2d_nhwc_hwcf {lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_I32_16x16x32_I8>, promote_operands = [0, 1], reduction = [0, 0, 0, 0, 1, 1, 64], subgroup_m_count = 1 : i64, subgroup_n_count = 4 : i64, workgroup = [1, 1, 32, 256, 0, 0, 0]}>} ins(%3, %4 : tensor<2x34x34x1280xi8>, tensor<3x3x1280x1280xi8>) outs(%6 : tensor<2x32x32x1280xi32>) -> tensor<2x32x32x1280xi32>
+%7 = linalg.conv_2d_nhwc_hwcf {lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_I32_16x16x32_I8>, promote_operands = [0, 1], reduction = [0, 0, 0, 0, 1, 1, 64], subgroup_m_count = 1 : i64, subgroup_n_count = 4 : i64, workgroup = [1, 1, 32, 256, 0, 0, 0]}>, root_op} ins(%3, %4 : tensor<2x34x34x1280xi8>, tensor<3x3x1280x1280xi8>) outs(%6 : tensor<2x32x32x1280xi32>) -> tensor<2x32x32x1280xi32>
flow.dispatch.tensor.store %7, %2, offsets = [0, 0, 0, 0], sizes = [2, 32, 32, 1280], strides = [1, 1, 1, 1] : tensor<2x32x32x1280xi32> -> !flow.dispatch.tensor<writeonly:tensor<2x32x32x1280xi32>>
return
}
tuner/examples/test/mmt_benchmark.mlir (2 changes: 1 addition & 1 deletion)
@@ -18,7 +18,7 @@ module {
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [2048, 2048], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<2048x2048xf16>> -> tensor<2048x2048xf16>
%5 = tensor.empty() : tensor<2048x2048xf32>
%6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<2048x2048xf32>) -> tensor<2048x2048xf32>
-%7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%3, %4 : tensor<2048x2048xf16>, tensor<2048x2048xf16>) outs(%6 : tensor<2048x2048xf32>) attrs = {lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, promote_operands = [0, 1], reduction = [0, 0, 64], subgroup_m_count = 2 : i64, subgroup_n_count = 2 : i64, workgroup = [64, 128, 0]}>} {
+%7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%3, %4 : tensor<2048x2048xf16>, tensor<2048x2048xf16>) outs(%6 : tensor<2048x2048xf32>) attrs = {lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, promote_operands = [0, 1], reduction = [0, 0, 64], subgroup_m_count = 2 : i64, subgroup_n_count = 2 : i64, workgroup = [64, 128, 0]}>, root_op} {
^bb0(%in: f16, %in_0: f16, %out: f32):
%8 = arith.extf %in : f16 to f32
%9 = arith.extf %in_0 : f16 to f32
tuner/tuner/candidate_gen.py (8 changes: 5 additions & 3 deletions)
@@ -113,7 +113,7 @@ def get_td_spec(
ir_module: ir.Module,
compilation_info: iree_codegen.CompilationInfoAttr,
) -> ir.Module:
"""Generate a transform dialect spec module for the funcOp."""
"""Generate a transform dialect spec that applies the compilation info attr."""
pass


@@ -162,7 +162,8 @@ def get_td_spec(
M = acc_type.get_dim_size(0)
N = acc_type.get_dim_size(1)
K = lhs_type.get_dim_size(1)
func_name = f"match_mmt_{M}x{N}x{K}_{lhs_type.element_type}x{rhs_type.element_type}x{acc_type.element_type}"
# TODO(Max191): Get the function name from the func.func in the input module.
func_name = f"match_contraction_{M}x{N}x{K}_{lhs_type.element_type}x{rhs_type.element_type}x{acc_type.element_type}"
return build_td_spec(
ir_module.context, contraction_op, compilation_info, func_name
)
@@ -195,6 +196,7 @@ def get_td_spec(
Q = rhs_type.get_dim_size(1)
F = rhs_type.get_dim_size(3)
conv_type = conv_op.name.split(".")[-1]
+# TODO(Max191): Get the function name from the func.func in the input module.
func_name = f"match_{conv_type}_{N}x{H}x{W}x{C}x{P}x{Q}x{F}_{lhs_type.element_type}x{rhs_type.element_type}x{acc_type.element_type}"
return build_td_spec(
ir_module.context, conv_op, compilation_info, func_name
@@ -558,7 +560,7 @@ def get_default_output_dir() -> str:
return "tuning_" + datetime.now().strftime("%Y_%m_%d_%H_%M")


-# TODO(Max191): Remove in favor of using tune_with_td.
+# TODO(https://github.com/nod-ai/shark-ai/issues/453): Remove in favor of using tune_with_td.
def tune(
input: str, # Path to the mlir file to be tuned
output: str = "", # Path to the output directory, auto creates one if not given
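As an aside on the matcher-name scheme changed above, the generated entry-point name is derived purely from the problem sizes and element types for now (hence the TODOs about reading the name from the func.func instead). An illustrative computation with assumed values:

# Illustrative only: in candidate_gen.py the sizes and element types come from
# the matched contraction's operand tensors, not from hard-coded values.
M, N, K = 2048, 2048, 2048
lhs_elem, rhs_elem, acc_elem = "f16", "f16", "f32"
func_name = f"match_contraction_{M}x{N}x{K}_{lhs_elem}x{rhs_elem}x{acc_elem}"
print(func_name)  # match_contraction_2048x2048x2048_f16xf16xf32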
tuner/tuner/candidate_gen_test.py (5 changes: 2 additions & 3 deletions)
@@ -52,6 +52,7 @@ def test_get_td_spec_contraction(tuner_ctx: common.TunerContext) -> None:
affine_map<(d0, d1, d2) -> (d1, d2)>,
affine_map<(d0, d1, d2) -> (d0, d1)>],
iterator_types = ["parallel", "parallel", "reduction"]}
+{root_op}
ins(%arg0, %arg1 : tensor<2048x2048xf16>, tensor<2048x2048xf16>)
outs(%1 : tensor<2048x2048xf32>) {
^bb0(%in: f16, %in_0: f16, %out: f32):
@@ -101,7 +102,6 @@ def test_get_td_spec_contraction(tuner_ctx: common.TunerContext) -> None:
matcher_sequence = None
entry_point = None
for op in named_sequence_ops:
-print(op.opview.sym_name)
if str(op.opview.sym_name) == "\"apply_op_config\"":
apply_config_sequence = op
elif str(op.opview.sym_name) == "\"__kernel_config\"":
@@ -139,7 +139,7 @@ def test_get_td_spec_convolution(tuner_ctx: common.TunerContext) -> None:
%cst = arith.constant 0 : i32
%0 = tensor.empty() : tensor<2x32x32x2048xi32>
%1 = linalg.fill ins(%cst : i32) outs(%0 : tensor<2x32x32x2048xi32>) -> tensor<2x32x32x2048xi32>
-%2 = linalg.conv_2d_nhwc_hwcf
+%2 = linalg.conv_2d_nhwc_hwcf {root_op}
ins(%arg0, %arg1 : tensor<2x34x34x2048xi8>, tensor<3x3x2048x2048xi8>)
outs(%1 : tensor<2x32x32x2048xi32>) -> tensor<2x32x32x2048xi32>
return %2 : tensor<2x32x32x2048xi32>
@@ -182,7 +182,6 @@ def test_get_td_spec_convolution(tuner_ctx: common.TunerContext) -> None:
matcher_sequence = None
entry_point = None
for op in named_sequence_ops:
-print(op.opview.sym_name)
if str(op.opview.sym_name) == "\"apply_op_config\"":
apply_config_sequence = op
elif str(op.opview.sym_name) == "\"__kernel_config\"":
tuner/tuner/dispatch_parser.py (25 changes: 6 additions & 19 deletions)
@@ -78,19 +78,6 @@ def parse_mlir(mlir_text: str, ctx: TunerContext) -> ir.Module:
return mlir_module


-def find_root_op(
-    ir_module: ir.Module,
-    matcher: NamedOpMatcher,
-) -> Optional[ir.Operation]:
-    func_ops: list[ir.Operation] = get_named_ops(ir_module, "func.func")
-    if len(func_ops) != 1:
-        return None
-    matched_ops = matcher.get_matched_ops(func_ops[0].operation)
-    if len(matched_ops) != 1:
-        return None
-    return matched_ops[0]


class DispatchParser(metaclass=ABCMeta):
@abstractmethod
def supports(self, op_name: str) -> bool:
@@ -118,14 +105,14 @@ def get_contraction_operation(
self,
ir_module: ir.Module,
) -> Optional[ir.Operation]:
-return find_root_op(ir_module, ContractionOpInterfaceMatcher())
+return match_root_op(ir_module, ContractionOpInterfaceMatcher())

# TODO(Max191): Pass the ir_module directly instead of the template str.
def get_shapes(self, template: list[str]) -> ProblemSize:
matcher = ContractionOpInterfaceMatcher()
with ir.Context() as ctx:
ir_module = ir.Module.parse("".join(template), ctx)
contraction_op = find_root_op(ir_module, matcher)
ir_module = ir.Module.parse("\n".join(template), ctx)
contraction_op = match_root_op(ir_module, matcher)
if contraction_op is None:
assert False, f"contraction op not found"
cdims = matcher.contraction_dimensions
@@ -167,13 +154,13 @@ def get_conv_operation(
self,
ir_module: ir.Module,
) -> Optional[ir.Operation]:
-return find_root_op(ir_module, NamedOpMatcher(self.supported_ops))
+return match_root_op(ir_module, NamedOpMatcher(self.supported_ops))

# TODO(Max191): Pass the ir_module directly instead of the template str.
def get_shapes(self, template: list[str]) -> ProblemSize:
with ir.Context() as ctx:
ir_module = ir.Module.parse("".join(template), ctx)
conv_op = find_root_op(ir_module, NamedOpMatcher(self.supported_ops))
ir_module = ir.Module.parse("\n".join(template), ctx)
conv_op = match_root_op(ir_module, NamedOpMatcher(self.supported_ops))
if conv_op is None:
assert False, f"convolution op not found"
lhs_type = ir.RankedTensorType(conv_op.operands[0].type)
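For context, a minimal sketch of the matching path that get_shapes now relies on, assuming the matchers are importable as tuner.op_matchers (an assumed import path) and that mlir_text contains exactly one op tagged with root_op, as in the tests below:

from iree.compiler import ir  # type: ignore

from tuner.op_matchers import ContractionOpInterfaceMatcher, match_root_op  # assumed import path


def contraction_dims_from_text(mlir_text: str):
    matcher = ContractionOpInterfaceMatcher()
    with ir.Context() as ctx:
        module = ir.Module.parse(mlir_text, ctx)
        # Find the single op tagged with `root_op`, then check that it is a contraction.
        contraction_op = match_root_op(module, matcher)
        assert contraction_op is not None, "no root_op-tagged contraction found"
        # The matcher records which iteration dimensions are batch/m/n/k while matching.
        return matcher.contraction_dimensions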
tuner/tuner/dispatch_parser_test.py (3 changes: 2 additions & 1 deletion)
@@ -55,6 +55,7 @@ def test_get_contraction_operation(tuner_ctx: common.TunerContext) -> None:
affine_map<(d0, d1, d2) -> (d1, d2)>,
affine_map<(d0, d1, d2) -> (d0, d1)>],
iterator_types = ["parallel", "parallel", "reduction"]}
+{root_op}
ins(%arg0, %arg1 : tensor<4x4xf16>, tensor<4x4xf16>)
outs(%1 : tensor<4x4xf32>) {
^bb0(%in: f16, %in_0: f16, %out: f32):
@@ -82,7 +83,7 @@ def test_get_conv_operation(tuner_ctx: common.TunerContext) -> None:
%cst = arith.constant 0 : i32
%0 = tensor.empty() : tensor<2x32x32x16xi32>
%1 = linalg.fill ins(%cst : i32) outs(%0 : tensor<2x32x32x16xi32>) -> tensor<2x32x32x16xi32>
-%2 = linalg.conv_2d_nhwc_hwcf
+%2 = linalg.conv_2d_nhwc_hwcf {root_op}
ins(%arg0, %arg1 : tensor<2x34x34x16xi8>, tensor<3x3x16x16xi8>)
outs(%1 : tensor<2x32x32x16xi32>) -> tensor<2x32x32x16xi32>
return %2 : tensor<2x32x32x16xi32>
tuner/tuner/libtuner.py (2 changes: 1 addition & 1 deletion)
@@ -878,7 +878,7 @@ def generate_candidate_specs(
candidate_trackers.append(new_candidate)
except Exception as e:
logging.error("An error occurred during candidates generation: %s", str(e))
-# Capture and log debug messages from candidate_gen.py
+# Capture and log debug messages from candidate_gen.py.
tune_logger = logging.getLogger("tune_with_td")
for handler in logging.getLogger().handlers:
if isinstance(handler, logging.FileHandler):
tuner/tuner/op_matchers.py (72 changes: 41 additions & 31 deletions)
@@ -6,12 +6,19 @@

# This code implements matcher functions for MLIR modules using python bindings.

-from abc import abstractmethod
+from abc import ABCMeta, abstractmethod

from .common import *
from iree.compiler import ir # type: ignore


+class OpMatcher(metaclass=ABCMeta):
+    @abstractmethod
+    def match(self, op: ir.Operation) -> bool:
+        """Check if the op passes the matching criteria."""
+        pass


def walk_collect_ops(
op: ir.Operation,
ops: list[ir.Operation],
@@ -22,15 +29,6 @@ def walk_collect_ops(
return ir.WalkResult.ADVANCE


-def get_ops(op: ir.Operation, fn):
-    ops: list[ir.Operation] = []
-    op.opview.walk(
-        lambda op: walk_collect_ops(op, ops, fn),
-        ir.WalkOrder.POST_ORDER,
-    )
-    return ops


def get_ops_from_module(module: ir.Module, fn):
ops: list[ir.Operation] = []
for op in module.body.operations:
@@ -41,37 +39,32 @@ def get_ops_from_module(module: ir.Module, fn):
return ops


-def get_named_ops(module: ir.Module, name: str):
-    return get_ops_from_module(module, lambda op: op.name == name)
+def is_root_op(op: ir.Operation) -> bool:
+    for attr in op.opview.attributes:
+        if attr.name == "root_op":
+            return True
+    return False


-def get_map_result_dim_positions(map: ir.AffineMap):
-    exprs = []
-    if not map.is_projected_permutation:
-        return None
-    for expr in map.results:
-        dim_str = str(expr)
-        if len(dim_str) < 1:
-            return None
-        if dim_str[0] != "d":
-            return None
-        if not dim_str[1:].isdigit():
-            return None
-        dim_position = int(dim_str[1:])
-        exprs.append(dim_position)
-    return exprs
+def match_root_op(
+    ir_module: ir.Module,
+    matcher: OpMatcher,
+) -> Optional[ir.Operation]:
+    root_ops: list[ir.Operation] = get_ops_from_module(ir_module, is_root_op)
+    if len(root_ops) != 1:
+        return None
+    if not matcher.match(root_ops[0].operation):
+        return None
+    return root_ops[0]


-class NamedOpMatcher:
+class NamedOpMatcher(OpMatcher):
def __init__(self, op_names: list[str]):
self.op_names = op_names

def match(self, op: ir.Operation) -> bool:
return op.name in self.op_names

-    def get_matched_ops(self, op: ir.Operation):
-        return get_ops(op, lambda nestedOp: self.match(nestedOp))


# TODO(Max191): Add logic to match the body of the generic op.
class GenericOpMatcher(NamedOpMatcher):
@@ -111,6 +104,23 @@ def match(self, op: ir.Operation) -> bool:
return True


+def get_map_result_dim_positions(map: ir.AffineMap):
+    exprs = []
+    if not map.is_projected_permutation:
+        return None
+    for expr in map.results:
+        dim_str = str(expr)
+        if len(dim_str) < 1:
+            return None
+        if dim_str[0] != "d":
+            return None
+        if not dim_str[1:].isdigit():
+            return None
+        dim_position = int(dim_str[1:])
+        exprs.append(dim_position)
+    return exprs


class ContractionOpInterfaceMatcher(GenericOpMatcher):
def __init__(self):
super().__init__()
@@ -152,7 +162,7 @@ def match_indexing_maps(self, maps: list[ir.AffineMap]) -> bool:
k_dims.append(d)
continue
return False

self.contraction_dimensions = ContractionDimensions(
batch=batch_dims,
m=m_dims,
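To illustrate the OpMatcher interface introduced above: a matcher only has to implement match(), and match_root_op handles narrowing the module to the single root_op-tagged candidate. The matcher below is a hypothetical example written for this sketch, not tuner code, and the import path is assumed:

from iree.compiler import ir  # type: ignore

from tuner.op_matchers import OpMatcher, match_root_op  # assumed import path


class HasLoweringConfigMatcher(OpMatcher):
    """Hypothetical matcher: accept the root op only if it already carries a lowering_config."""

    def match(self, op: ir.Operation) -> bool:
        return any(attr.name == "lowering_config" for attr in op.opview.attributes)


# Usage: match_root_op finds the op tagged `root_op` and defers the final
# decision to this matcher.
# root = match_root_op(ir_module, HasLoweringConfigMatcher())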
tuner/tuner/spec_builder.py (2 changes: 1 addition & 1 deletion)
@@ -8,7 +8,7 @@
# in the code and runs it.

from iree.compiler import ir # type: ignore
-from iree.compiler.dialects import iree_codegen # type: ignore
+from iree.compiler.dialects import iree_codegen  # type: ignore

from .common import *
from .dispatch_constraints import *
