TKWave: Add support for workgroup constraints to determine grid size (#34)

This adds support for specifying how the input is distributed across workgroups.

From the patch:
```
A constraint of the form `tkw.WorkgroupConstraint(M, BLOCK_M, 0)`
specifies that we want to distribute dimension M along workgroup dim 0
with a tile size of BLOCK_M resulting in M // BLOCK_M workgroups along that
dimension. This translates to an index constraint for all tensors of the
shape [M, ?] -> index += (workgroup_id_0 * BLOCK_M, 0)
```
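
For reference, a minimal usage sketch of the new constraints (mirroring the constraints test added in this patch; symbol names are taken from that test):

```python
from shark_turbine.kernel.lang import sym
from shark_turbine.kernel.wave.constraints import WorkgroupConstraint, get_grid_shape

M, N = sym.M, sym.N
BLOCK_M, BLOCK_N = sym.BLOCK_M, sym.BLOCK_N

# Distribute M along workgroup dim 0 and N along workgroup dim 1.
constraints = [
    WorkgroupConstraint(M, BLOCK_M, 0),
    WorkgroupConstraint(N, BLOCK_N, 1),
]

# One workgroup per tile in each distributed dimension.
assert get_grid_shape(constraints) == [M // BLOCK_M, N // BLOCK_N]

# Each constraint contributes a per-workgroup index offset, e.g. WG0 * BLOCK_M.
print(constraints[0].apply())
```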

---------

Signed-off-by: Martin Lücke <[email protected]>
martin-luecke authored Jun 26, 2024
1 parent 67e337b commit 8ac7aa0
Showing 6 changed files with 139 additions and 13 deletions.
6 changes: 3 additions & 3 deletions shark_turbine/kernel/_support/indexing.py
@@ -38,8 +38,8 @@ class NotSetType:
# These are just light-weight helpers around sympy symbols and expressions.
###############################################################################

IndexSymbol = sympy.core.Symbol
IndexExpr = sympy.core.Expr
IndexSymbol = sympy.Symbol
IndexExpr = sympy.Expr


def index_symbol(name: str) -> IndexSymbol:
@@ -53,7 +53,7 @@ def index_expr(value: Any) -> IndexExpr:


class _IndexSymbolExpando:
def __getattr__(self, n):
def __getattr__(self, n) -> IndexSymbol:
return index_symbol(n)


1 change: 1 addition & 0 deletions shark_turbine/kernel/wave/__init__.py
@@ -1,2 +1,3 @@
from ..ops.wave_ops import *
from .constraints import *
from .wave import *
62 changes: 62 additions & 0 deletions shark_turbine/kernel/wave/constraints.py
@@ -0,0 +1,62 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from ..lang import sym
from .._support.indexing import IndexExpr


@dataclass
class Constraint(ABC):
"""
Base class for constraints. Every constraint reduces to
the following form:
Variables: [x0, x1, ...., xN]
Bounds: [lb0 <= x0 <= ub0, ..., lbN <= xN <= ubN]
Equality Constraints: [f0(x0, ..., xN) = 0, f1(x0, ..., xN) = 0, ...]
Inequality Constraints: [g0(x0, ..., xN) <= 0, g1(x0, ..., xN) <= 0, ...]
"""

@abstractmethod
def apply(self) -> IndexExpr:
"""Apply the constraint and get the resulting index expression."""
...


@dataclass
class WorkgroupConstraint(Constraint):
"""
A constraint of the form `tkw.WorkgroupConstraint(M, BLOCK_M, 0)`
specifies that we want to distribute dimension M along workgroup dim 0
with a tile size of BLOCK_M resulting in M // BLOCK_M workgroups along that
dimension. This translates to an index constraint for all tensors of the
shape [M, ?] -> index += (workgroup_id_0 * BLOCK_M, 0)
"""

dim: IndexExpr
tile_size: IndexExpr
workgroup_dim: int

def apply(self) -> IndexExpr:
match self.workgroup_dim:
case 0:
wg_dim = sym.WG0
case 1:
wg_dim = sym.WG1
case _:
raise ValueError("Invalid workgroup dimension. Expected 0 or 1.")
return wg_dim * self.tile_size


def get_grid_shape(wg_constraints: list[WorkgroupConstraint]) -> list[IndexExpr]:
sorted_constraints = sorted(wg_constraints, key=lambda x: x.workgroup_dim)
# Currently, at most one constraint per workgroup dimension is supported.
if any(
sorted_constraints[i].workgroup_dim == sorted_constraints[i + 1].workgroup_dim
for i in range(len(sorted_constraints) - 1)
):
raise ValueError(
"Multiple constraints in the same workgroup dimension are currently not supported."
)
grid: list[IndexExpr] = [
constraint.dim // constraint.tile_size for constraint in wg_constraints
]
return grid
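
Since `Constraint` is an abstract dataclass, further constraint kinds can be layered on later simply by implementing `apply()`. A purely hypothetical sketch (the `ThreadConstraint` name and the `T0` symbol are illustrative only and not part of this patch):

```python
from dataclasses import dataclass

from shark_turbine.kernel.lang import sym
from shark_turbine.kernel._support.indexing import IndexExpr
from shark_turbine.kernel.wave.constraints import Constraint


@dataclass
class ThreadConstraint(Constraint):
    """Hypothetical: distribute `dim` over threads within a workgroup,
    with `tile_size` elements per thread (not part of this commit)."""

    dim: IndexExpr
    tile_size: IndexExpr

    def apply(self) -> IndexExpr:
        # Offset each thread by its id times the per-thread tile size.
        return sym.T0 * self.tile_size
```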
29 changes: 22 additions & 7 deletions shark_turbine/kernel/wave/wave.py
@@ -1,10 +1,14 @@
from typing import Any, Callable, Optional, Type
from typing import Any, Callable, Optional
import inspect
import os

from ..compiler import builder, dispatch_codegen, kernel_codegen
from ..compiler.ir import Context, Operation
from .codegen import WaveEmitter
from .constraints import (
Constraint,
WorkgroupConstraint,
get_grid_shape,
)
from ..lang import Grid
from ..ops import wave_ops
from .._support.tracing import (
@@ -17,16 +21,16 @@
__all__ = ["wave", "wave_trace_only"]


def wave():
def wave(constraints: Optional[list[Constraint]] = None):
def decorator(f: Callable[..., Any]) -> "LaunchableWave":
return LaunchableWave(f.__name__, f)
return LaunchableWave(constraints, f.__name__, f)

return decorator


def wave_trace_only():
def wave_trace_only(constraints: Optional[list[Constraint]] = None):
def decorator(f: Callable[..., Any]) -> "Callable[[], CapturedTrace]":
wave = LaunchableWave(f.__name__, f)
wave = LaunchableWave(constraints, f.__name__, f)
return wave._trace # type: ignore

return decorator
@@ -35,16 +39,27 @@ def decorator(f: Callable[..., Any]) -> "Callable[[], CapturedTrace]":
class LaunchableWave(Launchable):
def __init__(
self,
constraints: Optional[list[Constraint]],
name: str,
eager_function: Callable[[Any], Any],
):
super().__init__(eager_function)

self.grid_type = Grid[None, None]
self.constraints = constraints if constraints else []
self._name = name
self._f = eager_function
self._sig = inspect.signature(eager_function)

self.grid_type = Grid[tuple(get_grid_shape(self.workgroup_constraints))]

@property
def workgroup_constraints(self) -> list[WorkgroupConstraint]:
return [
constraint
for constraint in self.constraints
if isinstance(constraint, WorkgroupConstraint)
]

def _trace(self) -> CapturedTrace:
region_graph = KernelRegionGraph()
with CompiledContext(region_graph, grid_type=self.grid_type) as context:
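
Taken together with the gemm test below, usage at the decorator level looks roughly like this (a sketch only; the `tkl`/`tkw` import aliases and the `Memory` annotation pattern are assumed to match that test):

```python
import shark_turbine.kernel.lang as tkl
import shark_turbine.kernel.wave as tkw

M, N = tkl.sym.M, tkl.sym.N
BLOCK_M, BLOCK_N = tkl.sym.BLOCK_M, tkl.sym.BLOCK_N
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE

constraints: list[tkw.Constraint] = [
    tkw.WorkgroupConstraint(M, BLOCK_M, 0),
    tkw.WorkgroupConstraint(N, BLOCK_N, 1),
]

# The decorator now accepts constraints; the launch grid is derived from them
# when the LaunchableWave is constructed, before any tracing or codegen.
@tkw.wave(constraints)
def kernel(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]):
    ...

assert kernel.grid_type.symbolic_shape == (M // BLOCK_M, N // BLOCK_N)
```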
38 changes: 38 additions & 0 deletions tests/kernel/wave/constraints_test.py
@@ -0,0 +1,38 @@
import logging
import pytest
import unittest
from shark_turbine.kernel.lang import sym
from shark_turbine.kernel.wave.constraints import WorkgroupConstraint, get_grid_shape

M = sym.M
N = sym.N
BLOCK_N = sym.BLOCK_N
BLOCK_M = sym.BLOCK_M


class ConstraintsTest(unittest.TestCase):
def testWorkgroupConstraint(self):
constraints: list[WorkgroupConstraint] = [WorkgroupConstraint(M, BLOCK_M, 0)]
constraints.append(WorkgroupConstraint(N, BLOCK_N, 1))

assert get_grid_shape(constraints) == [M // BLOCK_M, N // BLOCK_N]

# Multiple constraints in the same dimension are not supported.
constraints += [WorkgroupConstraint(N, BLOCK_N, 1)]
with pytest.raises(
ValueError,
match="Multiple constraints in the same workgroup dimension are currently not supported.",
):
get_grid_shape(constraints)

# Checking invalid workgroup dimension
with pytest.raises(
ValueError,
match="Invalid workgroup dimension. Expected 0 or 1.",
):
WorkgroupConstraint(N, BLOCK_N, 2).apply()


if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
unittest.main()
16 changes: 13 additions & 3 deletions tests/kernel/wave/wave_gemm_test.py
@@ -1,5 +1,3 @@
# RUN: python %s

import logging
import pytest
import torch
@@ -26,13 +24,17 @@ def testGemm(self):
LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD
STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD

# Expose user-constraints
constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]

# Wave-level micro-kernel.
# Since warps are not directly addressable, there is no
# explicit notion of a warp id (like a workgroup or thread id).
# This kernel uses the input sizes M, N, K throughout, as the tiling
# and data movement strategy is determined during the compilation process.
# These can be influenced by introducing constraints.
@tkw.wave()
@tkw.wave(constraints)
def gemm(
a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16],
b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16],
@@ -75,6 +77,14 @@ def repeat(acc: tkl.Register) -> tkl.Register[M, N, tkl.f32]:
c = torch.zeros(64, 128, dtype=torch.float32)
gemm(a, b, c)

# TODO: Note this is currently not triggered as the stub exception
# is raised first. Remove this note when this successfully runs
# through codegen.
assert gemm.grid_type.symbolic_shape == (
M // BLOCK_M,
N // BLOCK_N,
)


if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
