Fix contraction matching bug, add tests
Signed-off-by: Max Dawkins <[email protected]>
Max191 committed Dec 13, 2024
1 parent 2876528 commit 41bb86d
Showing 3 changed files with 83 additions and 29 deletions.
tuner/tuner/dispatch_parser.py (12 changes: 9 additions & 3 deletions)
@@ -117,17 +117,23 @@ def get_shapes(self, template: list[str]) -> ProblemSize:
         assert False, f"contraction op not found"
         cdims = matcher.contraction_dimensions
         assert cdims, "no contraction dimensions"
+        assert matcher.lhs_dims, "no lhs dimensions"
+        assert matcher.rhs_dims, "no rhs dimensions"
+        assert matcher.res_dims, "no result dimensions"
+        assert len(cdims.batch) <= 1, f"must have at most 1 batch dimension"
         assert len(cdims.m) == 1, f"must have a single m dimension"
         assert len(cdims.n) == 1, f"must have a single n dimension"
         assert len(cdims.k) == 1, f"must have a single k dimension"
         lhs_type = ir.RankedTensorType(contraction_op.operands[0].type)
         rhs_type = ir.RankedTensorType(contraction_op.operands[1].type)
         res_type = ir.RankedTensorType(contraction_op.operands[2].type)
         matmul_size = MatmulSize(
-            lhs_type.shape[0],
-            rhs_type.shape[0],
-            lhs_type.shape[1],
+            lhs_type.shape[matcher.lhs_dims.index(cdims.m[0])],
+            rhs_type.shape[matcher.rhs_dims.index(cdims.n[0])],
+            lhs_type.shape[matcher.lhs_dims.index(cdims.k[0])],
         )
+        if len(cdims.batch) == 1:
+            matmul_size.B = lhs_type.shape[matcher.lhs_dims.index(cdims.batch[0])]
         return ProblemSize(
             matmul_size,
             lhs_type=ShapedType(lhs_type.shape, lhs_type.element_type),
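In brief, the fix: get_shapes previously read M, N, and K from fixed operand-shape positions, which is only correct for one operand layout. It now resolves each contraction dimension through the operand's dimension list. A minimal standalone sketch, with values borrowed from the transpose-B test added below (lhs_dims and rhs_dims stand in for what ContractionOpInterfaceMatcher records from the indexing maps):

lhs_shape = [16, 64]  # tensor<16x64xf16>, indexing map (d0, d1, d2) -> (d0, d2)
rhs_shape = [32, 64]  # tensor<32x64xf16>, indexing map (d0, d1, d2) -> (d1, d2)
lhs_dims = [0, 2]     # iteration dims appearing in the lhs map, in order
rhs_dims = [1, 2]     # iteration dims appearing in the rhs map, in order
m, n, k = 0, 1, 2     # dimension classes assigned by the matcher

# New lookup: find where each contraction dim sits in the operand's shape.
M = lhs_shape[lhs_dims.index(m)]  # 16
N = rhs_shape[rhs_dims.index(n)]  # 32
K = lhs_shape[lhs_dims.index(k)]  # 64
assert (M, N, K) == (16, 32, 64)
# The old fixed lookup (lhs_shape[0], rhs_shape[0], lhs_shape[1]) agrees on
# this layout but breaks for permuted or batched operands.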
tuner/tuner/dispatch_parser_test.py (94 changes: 68 additions & 26 deletions)
@@ -41,38 +41,80 @@ def test_parse_tensor_type(tuner_ctx: common.TunerContext) -> None:
 )


+CONTRACTION_TEMPLATE = r"""
+builtin.module{{
+func.func @test(%arg0: {lhs_type}, %arg1: {rhs_type}) -> {res_type} {{
+%cst = arith.constant 0.000000e+00 : f32
+%0 = tensor.empty() : {res_type}
+%1 = linalg.fill ins(%cst : f32) outs(%0 : {res_type}) -> {res_type}
+%2 = linalg.generic {{
+indexing_maps = [
+{lhs_map},
+{rhs_map},
+{res_map}],
+iterator_types = {iterator_types}}}
+{{root_op}}
+ins(%arg0, %arg1 : {lhs_type}, {rhs_type})
+outs(%1 : {res_type}) {{
+^bb0(%in: f16, %in_0: f16, %out: f32):
+%3 = arith.extf %in : f16 to f32
+%4 = arith.extf %in_0 : f16 to f32
+%5 = arith.mulf %3, %4 : f32
+%6 = arith.addf %out, %5 : f32
+linalg.yield %6 : f32
+}} -> {res_type}
+return %2 : {res_type}
+}}
+}}"""
+
+
 def test_get_contraction_operation(tuner_ctx: common.TunerContext) -> None:
     context = tuner_ctx.mlir_ctx
-    module_str = """
-builtin.module{
-func.func @test(%arg0: tensor<4x4xf16>, %arg1: tensor<4x4xf16>) -> tensor<4x4xf32> {
-%cst = arith.constant 0.000000e+00 : f32
-%0 = tensor.empty() : tensor<4x4xf32>
-%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<4x4xf32>) -> tensor<4x4xf32>
-%2 = linalg.generic {
-indexing_maps = [
-affine_map<(d0, d1, d2) -> (d0, d2)>,
-affine_map<(d0, d1, d2) -> (d1, d2)>,
-affine_map<(d0, d1, d2) -> (d0, d1)>],
-iterator_types = ["parallel", "parallel", "reduction"]}
-{root_op}
-ins(%arg0, %arg1 : tensor<4x4xf16>, tensor<4x4xf16>)
-outs(%1 : tensor<4x4xf32>) {
-^bb0(%in: f16, %in_0: f16, %out: f32):
-%3 = arith.extf %in : f16 to f32
-%4 = arith.extf %in_0 : f16 to f32
-%5 = arith.mulf %3, %4 : f32
-%6 = arith.addf %out, %5 : f32
-linalg.yield %6 : f32
-} -> tensor<4x4xf32>
-return %2 : tensor<4x4xf32>
-}
-}"""
-    module = ir.Module.parse(module_str, context)

+    with ir.Location.unknown():
+        transpose_b_str = CONTRACTION_TEMPLATE.format(
+            lhs_type=ir.RankedTensorType.get([16, 64], ir.F16Type.get()),
+            rhs_type=ir.RankedTensorType.get([32, 64], ir.F16Type.get()),
+            res_type=ir.RankedTensorType.get([16, 32], ir.F32Type.get()),
+            lhs_map="affine_map<(d0, d1, d2) -> (d0, d2)>",
+            rhs_map="affine_map<(d0, d1, d2) -> (d1, d2)>",
+            res_map="affine_map<(d0, d1, d2) -> (d0, d1)>",
+            iterator_types='["parallel", "parallel", "reduction"]',
+        )
+    module = ir.Module.parse(transpose_b_str, context)
     parser = dispatch_parser.ContractionOpInterfaceParser()
     mmt_op = parser.get_contraction_operation(module)
     assert mmt_op is not None
     assert isinstance(mmt_op.opview, linalg.GenericOp)
+    shapes: common.ProblemSize = parser.get_shapes(transpose_b_str.splitlines())
+    assert shapes.matmul_size.B == 1
+    assert shapes.matmul_size.M == 16
+    assert shapes.matmul_size.N == 32
+    assert shapes.matmul_size.K == 64
+    assert shapes.lhs_type.shape == [16, 64]
+    assert isinstance(shapes.lhs_type.element_type, ir.F16Type)
+    assert shapes.rhs_type.shape == [32, 64]
+    assert isinstance(shapes.rhs_type.element_type, ir.F16Type)
+    assert shapes.res_type.shape == [16, 32]
+    assert isinstance(shapes.res_type.element_type, ir.F32Type)
+
+    with ir.Location.unknown():
+        bmm_transposed_inputs_str = CONTRACTION_TEMPLATE.format(
+            lhs_type=ir.RankedTensorType.get([5, 8, 128], ir.F16Type.get()),
+            rhs_type=ir.RankedTensorType.get([128, 40, 5], ir.F16Type.get()),
+            res_type=ir.RankedTensorType.get([5, 40, 8], ir.F32Type.get()),
+            lhs_map="affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>",
+            rhs_map="affine_map<(d0, d1, d2, d3) -> (d3, d2, d0)>",
+            res_map="affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>",
+            iterator_types='["parallel", "parallel", "parallel", "reduction"]',
+        )
+    module = ir.Module.parse(bmm_transposed_inputs_str, context)
+    mmt_op = parser.get_contraction_operation(module)
+    shapes = parser.get_shapes(bmm_transposed_inputs_str.splitlines())
+    assert shapes.matmul_size.B == 5
+    assert shapes.matmul_size.M == 8
+    assert shapes.matmul_size.N == 40
+    assert shapes.matmul_size.K == 128
+


 def test_get_conv_operation(tuner_ctx: common.TunerContext) -> None:
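A worked check of the batch-matmul case above, as a standalone sketch (dimension lists read off the two indexing maps in the test):

lhs_shape = [5, 8, 128]   # indexing map (d0, d1, d2, d3) -> (d0, d1, d3)
rhs_shape = [128, 40, 5]  # indexing map (d0, d1, d2, d3) -> (d3, d2, d0)
lhs_dims = [0, 1, 3]
rhs_dims = [3, 2, 0]
batch, m, n, k = 0, 1, 2, 3

assert lhs_shape[lhs_dims.index(batch)] == 5  # B
assert lhs_shape[lhs_dims.index(m)] == 8      # M
assert rhs_shape[rhs_dims.index(n)] == 40     # N
assert lhs_shape[lhs_dims.index(k)] == 128    # K
# The pre-fix lookup would have reported N = rhs_shape[0] = 128 here,
# which is exactly the mismatch the test asserts against.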
tuner/tuner/op_matchers.py (6 changes: 6 additions & 0 deletions)
@@ -125,6 +125,9 @@ class ContractionOpInterfaceMatcher(GenericOpMatcher):
     def __init__(self):
         super().__init__()
         self.contraction_dimensions: Optional[ContractionDimensions] = None
+        self.lhs_dims: Optional[list[int]] = None
+        self.rhs_dims: Optional[list[int]] = None
+        self.res_dims: Optional[list[int]] = None

     def match_operands(self, operands: ir.OpOperandList) -> bool:
         if len(operands) != 3:
@@ -169,4 +172,7 @@ def match_indexing_maps(self, maps: list[ir.AffineMap]) -> bool:
             n=n_dims,
             k=k_dims,
         )
+        self.lhs_dims = lhs_dims
+        self.rhs_dims = rhs_dims
+        self.res_dims = res_dims
         return True
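For context, the new lhs_dims, rhs_dims, and res_dims fields record which iteration-space dimensions appear in each operand's indexing map, in result order; this is what makes the index() lookups in get_shapes work. A minimal sketch of the standard contraction-dimension classification these lists feed into (hypothetical standalone code, not the repo's matcher, shown for the transpose-B maps):

lhs_dims = [0, 2]  # (d0, d1, d2) -> (d0, d2)
rhs_dims = [1, 2]  # (d0, d1, d2) -> (d1, d2)
res_dims = [0, 1]  # (d0, d1, d2) -> (d0, d1)

# batch dims appear in lhs, rhs, and res; m only in lhs and res;
# n only in rhs and res; k only in lhs and rhs.
batch = [d for d in res_dims if d in lhs_dims and d in rhs_dims]
m = [d for d in res_dims if d in lhs_dims and d not in rhs_dims]
n = [d for d in res_dims if d in rhs_dims and d not in lhs_dims]
k = [d for d in lhs_dims if d in rhs_dims and d not in res_dims]
assert (batch, m, n, k) == ([], [0], [1], [2])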
