-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
TKWave: Add support for workgroup constraints to determine grid size (#…
…34) This adds support for specifying the distribution of the input to workgroups. From the patch: ``` A constraint of the form `tkw.WorkgroupConstraint(M, BLOCK_M, 0)` specifies that we want to distribute dimension M along workgroup dim 0 with a tile size of BLOCK_M resulting in M // BLOCK_M workgroups along that dimension. This translates to an index constraint for all tensors of the shape [M, ?] -> index += (workgroup_id_0 * BLOCK_M, 0) ``` --------- Signed-off-by: Martin Lücke <[email protected]>
- Loading branch information
1 parent
67e337b
commit 8ac7aa0
Showing
6 changed files
with
139 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
from ..ops.wave_ops import * | ||
from .constraints import * | ||
from .wave import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
from abc import ABC, abstractmethod | ||
from dataclasses import dataclass | ||
from ..lang import sym | ||
from .._support.indexing import IndexExpr | ||
|
||
|
||
@dataclass | ||
class Constraint(ABC): | ||
""" | ||
Base class for constraints. Every constraint reduces to | ||
the following form: | ||
Variables: [x0, x1, ...., xN] | ||
Bounds: [lb0 <= x0 <= ub0, ..., lbN <= xN <= ubN] | ||
Equality Constraints: [f0(x0, ..., xN) = 0, f1(x0, ..., xN) = 0, ...] | ||
Inequality Constraints: [g0(x0, ..., xN) <= 0, g1(x0, ..., xN) <= 0, ...] | ||
""" | ||
|
||
@abstractmethod | ||
def apply(self) -> IndexExpr: | ||
"""Apply the constraint and get the resulting index expression.""" | ||
... | ||
|
||
|
||
@dataclass | ||
class WorkgroupConstraint(Constraint): | ||
""" | ||
A constraint of the form `tkw.WorkgroupConstraint(M, BLOCK_M, 0)` | ||
specifies that we want to distribute dimension M along workgroup dim 0 | ||
with a tile size of BLOCK_M resulting in M // BLOCK_M workgroups along that | ||
dimension. This translates to an index constraint for all tensors of the | ||
shape [M, ?] -> index += (workgroup_id_0 * BLOCK_M, 0) | ||
""" | ||
|
||
dim: IndexExpr | ||
tile_size: IndexExpr | ||
workgroup_dim: int | ||
|
||
def apply(self) -> IndexExpr: | ||
match self.workgroup_dim: | ||
case 0: | ||
wg_dim = sym.WG0 | ||
case 1: | ||
wg_dim = sym.WG1 | ||
case _: | ||
raise ValueError("Invalid workgroup dimension. Expected 0 or 1.") | ||
return wg_dim * self.tile_size | ||
|
||
|
||
def get_grid_shape(wg_constraints: list[WorkgroupConstraint]) -> list[IndexExpr]: | ||
sorted_constraints = sorted(wg_constraints, key=lambda x: x.workgroup_dim) | ||
# Currently not more than one constraint in each dimension supported. | ||
if any( | ||
sorted_constraints[i].workgroup_dim == sorted_constraints[i + 1].workgroup_dim | ||
for i in range(len(sorted_constraints) - 1) | ||
): | ||
raise ValueError( | ||
"Multiple constraints in the same workgroup dimension are currently not supported." | ||
) | ||
grid: list[IndexExpr] = [ | ||
constraint.dim // constraint.tile_size for constraint in wg_constraints | ||
] | ||
return grid |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import logging | ||
import pytest | ||
import unittest | ||
from shark_turbine.kernel.lang import sym | ||
from shark_turbine.kernel.wave.constraints import WorkgroupConstraint, get_grid_shape | ||
|
||
M = sym.M | ||
N = sym.N | ||
BLOCK_N = sym.BLOCK_N | ||
BLOCK_M = sym.BLOCK_K | ||
|
||
|
||
class ConstraintsTest(unittest.TestCase): | ||
def testWorkgroupConstraint(self): | ||
constraints: list[WorkgroupConstraint] = [WorkgroupConstraint(M, BLOCK_M, 0)] | ||
constraints.append(WorkgroupConstraint(N, BLOCK_N, 1)) | ||
|
||
assert get_grid_shape(constraints) == [M // BLOCK_M, N // BLOCK_N] | ||
|
||
# Checking multiple constraints in the same dimension not supported | ||
constraints += [WorkgroupConstraint(N, BLOCK_N, 1)] | ||
with pytest.raises( | ||
ValueError, | ||
match="Multiple constraints in the same workgroup dimension are currently not supported.", | ||
): | ||
get_grid_shape(constraints) | ||
|
||
# Checking invalid workgroup dimension | ||
with pytest.raises( | ||
ValueError, | ||
match="Invalid workgroup dimension. Expected 0 or 1.", | ||
): | ||
WorkgroupConstraint(N, BLOCK_N, 2).apply() | ||
|
||
|
||
if __name__ == "__main__": | ||
logging.basicConfig(level=logging.DEBUG) | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters