Skip to content

Commit

Permalink
[tuner]: use compilation_info binding (#678)
Browse files Browse the repository at this point in the history
This PR addresses part of the task in
#453: use the IREE bindings for
compilation info (including `lowering_config` and `translation_info`).

Retire the `Configuration` data class and use `CompilationInfoAttr` from the
IREE Python bindings instead.

Signed-off-by: Bangtian Liu <[email protected]>
  • Loading branch information
bangtianliu authored Dec 12, 2024
1 parent ffb870f commit d279aff
Show file tree
Hide file tree
Showing 8 changed files with 157 additions and 129 deletions.
105 changes: 50 additions & 55 deletions tuner/tuner/candidate_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,23 +41,23 @@

def apply_configuration(
template: list[str],
configuration: Configuration,
compilation_info: iree_codegen.CompilationInfoAttr,
) -> str:
lowering_config = configuration.lowering_config
lowering_config = compilation_info.lowering_config
intrinsic = lowering_config.mma_kind
(
subgroup_m_count,
subgroup_n_count,
) = lowering_config.subgroup_count_mn
workgroup_sizes = lowering_config.workgroup_tile_sizes
reduction_sizes = lowering_config.reduction_tile_sizes
gpu_pipeline_options = configuration.translation_info.configuration[
gpu_pipeline_options = compilation_info.translation_info.configuration[
GPU_PIPELINE_OPTIONS_KEY
]
waves_per_eu = configuration.translation_info.configuration[LLVM_FUNC_ATTRS_KEY][
waves_per_eu = compilation_info.translation_info.configuration[LLVM_FUNC_ATTRS_KEY][
WAVES_PER_EU_KEY
]
tune_logger.info(f"Applying: {configuration}")
tune_logger.info(f"Applying: {compilation_info}")
expr0 = re.compile(
r"<intrinsic = #iree_gpu\.mma_layout<(.+)>, subgroup_m_count = ([0-9]+), subgroup_n_count = ([0-9]+)>"
)
Expand All @@ -69,7 +69,7 @@ def apply_configuration(
expr4 = re.compile(r"gpu_pipeline_options = #iree_gpu\.pipeline_options<([^>]*)>")
expr5 = re.compile(r"\"amdgpu-waves-per-eu\" = \"([0-9])\"")
repl0 = f"<intrinsic = {intrinsic}, subgroup_m_count = {subgroup_m_count}, subgroup_n_count = {subgroup_n_count}>"
repl1 = f'LLVMGPUVectorDistribute workgroup_size = [{", ".join(map(str, configuration.translation_info.workgroup_size))}] subgroup_size = {configuration.translation_info.subgroup_size},'
repl1 = f'LLVMGPUVectorDistribute workgroup_size = [{", ".join(map(str, compilation_info.translation_info.workgroup_size))}] subgroup_size = {compilation_info.translation_info.subgroup_size},'
repl2 = f"workgroup = {workgroup_sizes}"
repl3 = f"reduction = {reduction_sizes}"
repl4 = f"gpu_pipeline_options = {gpu_pipeline_options}"
Expand Down Expand Up @@ -101,7 +101,7 @@ def apply_params(
self,
problem_size: ProblemSize,
template: list[str],
configuration: Configuration,
compilation_info: iree_codegen.CompilationInfoAttr,
) -> MLIRTransformation:
"""Apply parameter transformations to the operation."""
pass
Expand Down Expand Up @@ -132,7 +132,10 @@ def find_handler(self, op_name: str) -> DispatchTuner:

class MmtTuner(DispatchTuner, MmtParser):
def get_transform_function_mmt(
self, problem_size: ProblemSize, functionName: str, configuration: Configuration
self,
problem_size: ProblemSize,
functionName: str,
compilation_info: iree_codegen.CompilationInfoAttr,
) -> str:
return f"""
transform.named_sequence @{functionName}(%matmul: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{
Expand All @@ -141,10 +144,7 @@ def get_transform_function_mmt(
%rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value
transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value
transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value
%config = transform.param.constant #iree_codegen.compilation_info<
lowering_config = {configuration.lowering_config},
translation_info = {configuration.translation_info}
> -> !transform.any_param
%config = transform.param.constant {compilation_info} -> !transform.any_param
transform.yield %matmul, %config : !transform.any_op, !transform.any_param
}}
"""
Expand All @@ -153,29 +153,34 @@ def apply_params(
self,
problem_size: ProblemSize,
template: list[str],
configuration: Configuration,
compilation_info: iree_codegen.CompilationInfoAttr,
) -> MLIRTransformation:
M, N, K = problem_size.MNK
modified = indent(
self.get_transform_function_mmt(
problem_size, f"match_mmt_{M}x{N}x{K}", configuration
problem_size, f"match_mmt_{M}x{N}x{K}", compilation_info
),
"// ",
)
modified += apply_configuration(
template,
configuration,
compilation_info,
)
embeddable = indent(
self.get_transform_function_mmt(problem_size, f"match_op", configuration),
self.get_transform_function_mmt(
problem_size, f"match_op", compilation_info
),
" ",
)
return MLIRTransformation(template, modified, embeddable)


class ConvTuner(DispatchTuner, ConvParser):
def get_transform_function_conv(
self, problem_size: ProblemSize, functionName: str, configuration: Configuration
self,
problem_size: ProblemSize,
functionName: str,
compilation_info: iree_codegen.CompilationInfoAttr,
) -> str:
dynamic_batch_input_ty = problem_size.lhs_type
dynamic_batch_input_ty.shape = dynamic_batch_input_ty.shape.copy()
Expand All @@ -198,10 +203,7 @@ def get_transform_function_conv(
ins(%lhs, %rhs : {input}, {filter})
outs(%out : {output}) -> {output}
}} : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
%config = transform.param.constant #iree_codegen.compilation_info<
lowering_config = {configuration.lowering_config},
translation_info = {configuration.translation_info}
> -> !transform.any_param
%config = transform.param.constant {compilation_info} -> !transform.any_param
transform.yield %conv, %config : !transform.any_op, !transform.any_param
}}
"""
Expand All @@ -210,23 +212,25 @@ def apply_params(
self,
problem_size: ProblemSize,
template: list[str],
configuration: Configuration,
compilation_info: iree_codegen.CompilationInfoAttr,
) -> MLIRTransformation:
conv_dims = ConvDimInfo.from_problem_size(problem_size)
modified = indent(
self.get_transform_function_conv(
problem_size,
f"match_conv_2d_nhwc_hwcf_Bx{conv_dims.oh}x{conv_dims.ow}x{conv_dims.oc}x{conv_dims.fh}x{conv_dims.fw}x{conv_dims.ic}",
configuration,
compilation_info,
),
"// ",
)
modified += apply_configuration(
template,
configuration,
compilation_info,
)
embeddable = indent(
self.get_transform_function_conv(problem_size, f"match_op", configuration),
self.get_transform_function_conv(
problem_size, f"match_op", compilation_info
),
" ",
)
return MLIRTransformation(template, modified, embeddable)
Expand All @@ -237,7 +241,7 @@ def get_transform_function_broadcast_rhs_mmt(
self,
problem_size: ProblemSize,
functionName: str,
configuration: Configuration,
compilation_info: iree_codegen.CompilationInfoAttr,
) -> str:
lhs_dynamic_batch = problem_size.lhs_type
lhs_dynamic_batch.shape = lhs_dynamic_batch.shape.copy()
Expand All @@ -250,10 +254,7 @@ def get_transform_function_broadcast_rhs_mmt(
%rhs = transform.get_operand %generic[1] : (!transform.any_op) -> !transform.any_value
transform.iree.match.cast_compatible_type %lhs = tensor<{lhs_dynamic_batch}> : !transform.any_value
transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value
%config = transform.param.constant #iree_codegen.compilation_info<
lowering_config = {configuration.lowering_config},
translation_info = {configuration.translation_info}
> -> !transform.any_param
%config = transform.param.constant {compilation_info} -> !transform.any_param
transform.yield %generic, %config : !transform.any_op, !transform.any_param
}}
"""
Expand All @@ -262,23 +263,23 @@ def apply_params_broadcast_rhs_mmt(
self,
problem_size: ProblemSize,
template: list[str],
configuration: Configuration,
compilation_info: iree_codegen.CompilationInfoAttr,
) -> MLIRTransformation:
M, N, K = problem_size.MNK
modified = indent(
self.get_transform_function_broadcast_rhs_mmt(
problem_size, f"match_broadcast_rhs_mmt_Bx{M}x{N}x{K}", configuration
problem_size, f"match_broadcast_rhs_mmt_Bx{M}x{N}x{K}", compilation_info
),
"// ",
)
modified += apply_configuration(
template,
configuration,
compilation_info,
)

embeddable = indent(
self.get_transform_function_broadcast_rhs_mmt(
problem_size, f"match_op", configuration
problem_size, f"match_op", compilation_info
),
" ",
)
Expand All @@ -288,19 +289,19 @@ def apply_params(
self,
problem_size: ProblemSize,
template: list[str],
configuration: Configuration,
compilation_info: iree_codegen.CompilationInfoAttr,
) -> MLIRTransformation:
if self.is_broadcast_rhs_mmt(template):
return self.apply_params_broadcast_rhs_mmt(
problem_size, template, configuration
problem_size, template, compilation_info
)

# TODO: Generate transform function.
return MLIRTransformation(
template,
apply_configuration(
template,
configuration,
compilation_info,
),
"",
)
Expand All @@ -311,7 +312,7 @@ def get_transform_function_batch_mmt(
self,
problem_size: ProblemSize,
functionName: str,
configuration: Configuration,
compilation_info: iree_codegen.CompilationInfoAttr,
) -> str:
return f"""
transform.named_sequence @{functionName}(%generic: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{
Expand All @@ -320,10 +321,7 @@ def get_transform_function_batch_mmt(
%rhs = transform.get_operand %generic[1] : (!transform.any_op) -> !transform.any_value
transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value
transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value
%config = transform.param.constant #iree_codegen.compilation_info<
lowering_config = {configuration.lowering_config},
translation_info = {configuration.translation_info}
> -> !transform.any_param
%config = transform.param.constant {compilation_info} -> !transform.any_param
transform.yield %generic, %config : !transform.any_op, !transform.any_param
}}
"""
Expand All @@ -332,24 +330,24 @@ def apply_params(
self,
problem_size: ProblemSize,
template: list[str],
configuration: Configuration,
compilation_info: iree_codegen.CompilationInfoAttr,
) -> MLIRTransformation:
M, N, K = problem_size.MNK
B = problem_size.matmul_size.B
modified = indent(
self.get_transform_function_batch_mmt(
problem_size, f"match_batch_mmt_{B}x{M}x{N}x{K}", configuration
problem_size, f"match_batch_mmt_{B}x{M}x{N}x{K}", compilation_info
),
"// ",
)
modified += apply_configuration(
template,
configuration,
compilation_info,
)

embeddable = indent(
self.get_transform_function_batch_mmt(
problem_size, f"match_op", configuration
problem_size, f"match_op", compilation_info
),
" ",
)
Expand All @@ -362,7 +360,7 @@ def get_transform_function_batch_matmul(
problem_size: ProblemSize,
tile_dims: str,
functionName: str,
configuration: Configuration,
compilation_info: iree_codegen.CompilationInfoAttr,
) -> str:
input0 = f"tensor<{problem_size.lhs_type}>"
input1 = f"tensor<{problem_size.rhs_type}>"
Expand All @@ -377,10 +375,7 @@ def get_transform_function_batch_matmul(
ins(%lhs, %rhs : {input0}, {input1})
outs(%out : {output}) -> {output}
}} : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
%config = transform.param.constant #iree_codegen.compilation_info<
lowering_config = {configuration.lowering_config},
translation_info = {configuration.translation_info}
> -> !transform.any_param
%config = transform.param.constant {compilation_info} -> !transform.any_param
transform.yield %batch_matmul, %config : !transform.any_op, !transform.any_param
}}
"""
Expand All @@ -389,26 +384,26 @@ def apply_params(
self,
problem_size: ProblemSize,
template: list[str],
configuration: Configuration,
compilation_info: iree_codegen.CompilationInfoAttr,
) -> MLIRTransformation:
M, N, K = problem_size.MNK
modified = indent(
self.get_transform_function_batch_matmul(
problem_size,
self.tile_dims,
f"match_batch_matmul_{problem_size.matmul_size.B}x{M}x{N}x{K}",
configuration,
compilation_info,
),
"// ",
)
modified += apply_configuration(
template,
configuration,
compilation_info,
)

embeddable = indent(
self.get_transform_function_batch_matmul(
problem_size, self.tile_dims, f"match_op", configuration
problem_size, self.tile_dims, f"match_op", compilation_info
),
" ",
)
Expand Down
Loading

0 comments on commit d279aff

Please sign in to comment.