[TKW] Thread Shape analysis (#186)
The motivation of this pass is to generalize the register analysis pass,
which is used to determine the thread shape of TKW.Register, to all
other operations.

One main use case is to allow reduction, and later on "broadcast", to
use thread shape information from the kernel as opposed to relying on
vector_shape, which may not always be valid.

We generalize the register analysis method by finding a few anchor ops
whose thread shape information is already determined, and then
propagating it to their successors and ancestors.
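
Conceptually, the propagation looks like this (a minimal sketch, not the
actual pass; `buckets` is a stand-in for the pass's thread_size_to_ops
dict, and the helper names mirror the new thread_shape_analysis.py below):

    # For each anchor op (Read/Write/MMA/ReduceOp), record its per-dim
    # thread sizes for every op reachable in its slice, stopping the walk
    # at other anchors.
    for anchor in trace.walk(is_anchor_op):
        sizes = get_custom_dim_sizes(get_custom(anchor))
        for op in capture_forward_slice(anchor, propagatable_op):
            buckets.setdefault(sizes, set()).add(op)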

In addition, we implemented a few helper functions/attributes (see the
sketch after this list):

1. A filter_fn on BFS, capture_forward_slice, and capture_backward_slice.
This makes it easier to control/stop the search when we hit ops we do
not want to explore. In this case, we do not want to explore/propagate
onto other anchor ops and their children.

2. Introducing parent_op to IterArg and to the Reduction's subgraph, for
developer ergonomics.

3. Moving the handling of IterArg and GetResult in BackwardSlice/BFS's
get_inputs exploration phase so each is handled individually, as opposed
to being handled when its consumer is explored. Previously, to
explore/propagate an IterArg/GetResult we needed to explore its
consumer; exploring the IterArg/GetResult alone would not be handled
correctly. This is useful when we want to propagate/explore mma.acc
(usually an IterArg) directly.
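
A rough sketch of how these helpers compose (an illustration only;
`iter_arg_node` is an assumed fx.Node, the other identifiers come from
the diffs below):

    # (1) filter_fn stops the BFS at ops we do not want to propagate through:
    slice_ops = capture_forward_slice(node, filter_fn=propagatable_op)
    # (2) parent_op lets an IterArg reach its enclosing Reduction directly:
    reduction = get_custom(iter_arg_node).parent_op()
    # (3) exploring an IterArg by itself now resolves to the matching init arg:
    inputs, _ = get_inputs(iter_arg_node, None)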

---------

Signed-off-by: Stanley Winata <[email protected]>
raikonenfnu authored Oct 3, 2024
1 parent 0f00c6d commit e0a8fdf
Showing 7 changed files with 193 additions and 124 deletions.
8 changes: 8 additions & 0 deletions shark_turbine/kernel/ops/wave_ops.py
@@ -651,6 +651,13 @@ class IterArg(Placeholder):
     a reduction node.
     """
 
+    def parent_op(self):
+        return get_custom(self.graph.parent_op)
+
+    def get_iter_idx(self):
+        src_reduction = self.parent_op()
+        return src_reduction.iter_args(self.graph).index(self.fx_node)
+
 
 # Ops modeling TKW operations in the kernel language

@@ -847,6 +854,7 @@ def wrapper(f):
             node._add_proxy_to_graph(graph)
             node.fx_node.node.tkw_op = cls
             node.fx_node.node.tkw_op_name = cls.tkw_op_name
+            graph.subgraphs[subgraph_name].parent_op = node.fx_node.node
             return node.fx_node
 
         return wrapper
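
For illustration, a minimal sketch of what these accessors enable
(`node` is assumed to be an IterArg's fx.Node inside a Reduction
subgraph; this mirrors the get_inputs change in utils.py below):

    custom = get_custom(node)
    if isinstance(custom, IterArg):
        reduction = custom.parent_op()   # enclosing Reduction, via the subgraph's parent_op
        idx = custom.get_iter_idx()      # this iter arg's position among the iter_args
        init = reduction.init_args[idx]  # the matching init arg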
4 changes: 2 additions & 2 deletions shark_turbine/kernel/wave/codegen.py
@@ -436,8 +436,8 @@ def handle_register(emitter: WaveEmitter, node: fx.Node):
         shape, dtype, value = node.args
     except ValueError as e:
         raise ValidationError("Malformed arguments") from e
-    if hasattr(node, "thread_shape"):
-        shape = [node.thread_shape]
+    get_thread_shape = lambda index: max(x.size for x in index.values())
+    shape = [get_thread_shape(get_custom(node).index)]
     vector_shape = cast_py_literal(emitter, shape)
     element_type = IrType.parse(dtype.ir_type_asm())
     vector_type = VectorType.get(vector_shape, element_type)
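
The register's vector size is now derived from its index instead of a
thread_shape attribute (previously set by the deleted register_analysis
pass). For example (a sketch; the dims and sizes are made up):

    # index maps each dim to an IndexSequence carrying a per-thread size, e.g.
    # {M: IndexSequence(m_start, size=1), K: IndexSequence(k_start, size=4)}
    get_thread_shape = lambda index: max(x.size for x in index.values())
    get_thread_shape(index)  # -> 4, so the register lowers to e.g. vector<4xf16>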
11 changes: 7 additions & 4 deletions shark_turbine/kernel/wave/decompose_reduce_ops.py
@@ -20,9 +20,10 @@
     ShuffleOp,
     CustomOp,
     ExtractSlice,
+    Reduction,
 )
 
-from .utils import DCE
+from .utils import DCE, subs_idxc
 import torch.fx as fx
 import math
 from typing import Callable
@@ -103,9 +104,11 @@ def decompose_reduce_ops(
             raise NotImplementedError(
                 "Only implemented reduction on fastest dimension."
             )
-        reduction_block_size = constraint_tile_size[reduction_dim]
-        reduction_size = reduction_block_size.subs(index_map)
-        local_reduction_size = reduction_size / subgroup_size
+
+        get_thread_shape = lambda index: max(
+            subs_idxc(x.size) for x in index.values()
+        )
+        local_reduction_size = get_thread_shape(get_custom(custom.arg).index)
         local_reduction = emit_local_reduction(
             binary_fn, reduction_src, custom.graph, local_reduction_size
         )
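
The local reduction size now comes from the reduced argument's thread
shape instead of tile-size/subgroup arithmetic. Assuming subs_idxc
substitutes symbolic sizes from the index context, the intent is roughly:

    # If custom.arg's index is {K: IndexSequence(k_start, size=ELEMS)} and the
    # index context binds ELEMS to 4, each thread reduces 4 elements locally
    # before the cross-thread shuffle stage.
    local_reduction_size = get_thread_shape(get_custom(custom.arg).index)  # -> 4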
93 changes: 0 additions & 93 deletions shark_turbine/kernel/wave/register_analysis.py

This file was deleted.

142 changes: 142 additions & 0 deletions shark_turbine/kernel/wave/thread_shape_analysis.py
@@ -0,0 +1,142 @@
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from ...support.logging import get_logger
from shark_turbine.kernel._support.tracing import CapturedTrace
import torch.fx as fx
from ..ops.wave_ops import *
from ..lang.global_symbols import *
from .utils import capture_forward_slice, capture_backward_slice

logger = get_logger("turbine.wave.thread_shape_analysis")


@dataclass(order=True)
class DimSize:
    dim: IndexSymbol
    size: int

    def __hash__(self):
        return hash((self.dim, self.size))


def get_dim_sizes(indices: dict[IndexSymbol, IndexSequence]):
    dims = frozenset([DimSize(dim, seq.size) for dim, seq in indices.items()])
    return dims


def get_custom_dim_sizes(custom: CustomOp):
    return get_dim_sizes(custom.index)


def set_index_size(custom: CustomOp, target_dim_sizes: frozenset[DimSize]):
    for target in target_dim_sizes:
        if target.dim not in custom.index:
            raise NotImplementedError(
                "NYI: Handle when source target index size is not found in target/user index."
            )
        custom.index[target.dim].size = target.size
def determine_thread_shapes(trace: CapturedTrace):
    """
    This function analyzes and propagates thread shapes. It works as follows:
    1. Look for "anchor" ops that carry information about their elems_per_thread.
    2. Take a forward/backward slice from each anchor op to collect the ops
       whose shapes depend on it.
    3. Bucket these ops into an (index -> elems_per_thread) mapping.
    4. For every (index -> elems_per_thread) bucket, apply the information by
       updating the ops' IndexSequence sizes.

    The buckets are stored in a dict called `thread_size_to_ops`, which maps
    each thread_shape key to the set of fx.Nodes that need that thread_shape
    in their IndexSequence.
    A thread_shape stores the thread size for every dimension the op cares
    about. We represent it as a frozenset[DimSize], where DimSize is
    essentially a pair<dimension: IndexSymbol, thread_size: int>. We use a
    frozenset since we do not care about the order of dims for shape/size
    propagation.
    We use set[CustomOp] for the values of `thread_size_to_ops` so that we can
    easily detect conflicting indices with set operations and handle/resolve
    them if required.

    For better illustration, here's an example:
    Kernel:
        imm = tkw.mul(x, y)
        lhs = tkw.neg(imm)
        a = tkw.mma(lhs, rhs, acc)
        b = tkw.exp2(a)
    Anchors:
        mma.lhs: {IndexSize(index=M, size=1), IndexSize(index=K, size=4)}
        mma.rhs: {IndexSize(index=K, size=4), IndexSize(index=N, size=1)}
        mma.acc: {IndexSize(index=M, size=4), IndexSize(index=N, size=1)}
    Bucket Entries:
        thread_size_to_ops[frozenset({IndexSize(index=M, size=1), IndexSize(index=K, size=4)})] = set(lhs, imm, x, y)
        thread_size_to_ops[frozenset({IndexSize(index=M, size=4), IndexSize(index=N, size=1)})] = set(acc, exp2_0)
        thread_size_to_ops[frozenset({IndexSize(index=K, size=4), IndexSize(index=N, size=1)})] = set(rhs, ...)
    """

    # Anchor ops are ops whose thread shapes are predetermined.
    anchorOpTypes = (Read, Write, MMA, ReduceOp)
    noHandleTypes = (Placeholder, Output, ExtractSlice, Allocate)
    nonPropagatableTypes = anchorOpTypes + noHandleTypes

    def is_anchor_op(node: fx.Node):
        return isinstance(get_custom(node), anchorOpTypes)

    def propagatable_op(node: fx.Node):
        return not isinstance(get_custom(node), nonPropagatableTypes)

    anchor_ops = trace.walk(is_anchor_op)
    thread_size_to_ops: dict[frozenset[DimSize], set[CustomOp]] = {}
    for anchor_op in anchor_ops:
        custom = get_custom(anchor_op)
        index_sizes = get_custom_dim_sizes(custom)
        if isinstance(custom, (Read, ReduceOp)):
            fwd_slice = capture_forward_slice(custom.fx_node, propagatable_op)
            thread_size_to_ops[index_sizes] = thread_size_to_ops.get(
                index_sizes, set([])
            ).union(fwd_slice)
        elif isinstance(custom, Write):
            bwd_slice = capture_backward_slice(custom.fx_node, propagatable_op)
            thread_size_to_ops[index_sizes] = thread_size_to_ops.get(
                index_sizes, set([])
            ).union(bwd_slice)
        elif isinstance(custom, MMA):
            lhs_bwd_slice = capture_backward_slice(custom.lhs, propagatable_op)
            rhs_bwd_slice = capture_backward_slice(custom.rhs, propagatable_op)
            acc_slice = capture_forward_slice(custom.acc, propagatable_op)
            acc_slice = acc_slice.union(
                capture_backward_slice(custom.acc, propagatable_op)
            )
            acc_index = get_dim_sizes(custom.acc_index)
            lhs_index = get_dim_sizes(custom.lhs_index)
            rhs_index = get_dim_sizes(custom.rhs_index)
            thread_size_to_ops[acc_index] = thread_size_to_ops.get(
                acc_index, set([])
            ).union(acc_slice)
            thread_size_to_ops[lhs_index] = thread_size_to_ops.get(
                lhs_index, set([])
            ).union(lhs_bwd_slice)
            thread_size_to_ops[rhs_index] = thread_size_to_ops.get(
                rhs_index, set([])
            ).union(rhs_bwd_slice)

    # Go through each index-size bucket and apply the index size to the ops in it.
    cumulative_set = set()
    for target_index_size, target_ops in thread_size_to_ops.items():
        # Ensure that we do not have any conflicts.
        if not cumulative_set.isdisjoint(target_ops):
            raise NotImplementedError("NYI: Handling of conflicting thread shape.")
        cumulative_set = cumulative_set.union(target_ops)
        for user in target_ops:
            custom_user = get_custom(user)
            set_index_size(custom_user, target_index_size)
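
As a usage note, the pass runs on a captured trace whose indices are
already populated, e.g. determine_thread_shapes(trace). The repeated
bucket-update idiom above could also be written with setdefault; a
behavior-preserving variant (not part of this diff):

    # equivalent to: thread_size_to_ops[k] = thread_size_to_ops.get(k, set()).union(s)
    thread_size_to_ops.setdefault(index_sizes, set()).update(fwd_slice)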
51 changes: 30 additions & 21 deletions shark_turbine/kernel/wave/utils.py
@@ -494,28 +494,33 @@ def get_inputs(
     Return the inputs of a node, propagating through reductions.
     """
     inputs = []
-    for input in node.all_input_nodes:
-        custom = get_custom(input)
-        if isinstance(custom, GetResult):
-            reduction = custom.value
-            assert isinstance(
-                reduction, Reduction
-            ), "GetResult must be used by a Reduction"
-            # Map get result to output
-            inputs.append(reduction.outputs[custom.res_idx])
-            continue
-        if isinstance(custom, IterArg):
-            # Map iter args to init args
-            iter_arg_idx = reduction.iter_args.index(node)
-            inputs.append(reduction.init_args[iter_arg_idx])
-            continue
-        inputs.append(input)
+    custom = get_custom(node)
+    if isinstance(custom, IterArg):
+        # Map iter args to init args
+        local_reduction = reduction
+        if reduction is None:
+            local_reduction = custom.parent_op()
+        iter_arg_idx = custom.get_iter_idx()
+        inputs.append(local_reduction.init_args[iter_arg_idx])
+    elif isinstance(custom, GetResult):
+        reduction = get_custom(custom.value)
+        assert isinstance(
+            get_custom(reduction), Reduction
+        ), "GetResult must be used by a Reduction"
+        # Map get result to output
+        reduction_subgraph = reduction.graph.subgraphs[reduction.subgraph_name]
+        inputs.append(reduction.outputs(reduction_subgraph)[custom.res_idx])
+    else:
+        # Default handling for other ops.
+        for input in node.all_input_nodes:
+            inputs.append(input)
     return inputs, reduction


 def bfs(
     node: fx.Node,
     get_neighbors: Callable[[fx.Node, fx.Node], list[fx.Node]],
+    filter_fn: Callable[[fx.Node], bool],
 ) -> set[fx.Node]:
     """
     Run BFS on the graph to capture the forward slice of a node.
@@ -529,25 +534,29 @@
         s = queue.pop(0)
         neighbors, reduction = get_neighbors(s, reduction)
         for neighbor in neighbors:
-            if neighbor not in visited:
+            if neighbor not in visited and filter_fn(neighbor):
                 visited.add(neighbor)
                 queue.append(neighbor)
     return visited


-def capture_forward_slice(node: fx.Node) -> set[fx.Node]:
+def capture_forward_slice(
+    node: fx.Node, filter_fn: Callable[[fx.Node], bool] = lambda x: True
+) -> set[fx.Node]:
     """
     Run BFS on the graph to capture the forward slice of a node.
     """
-    return bfs(node, lambda x, y: get_users(x, y))
+    return bfs(node, lambda x, y: get_users(x, y), filter_fn)


-def capture_backward_slice(node: fx.Node) -> set[fx.Node]:
+def capture_backward_slice(
+    node: fx.Node, filter_fn: Callable[[fx.Node], bool] = lambda x: True
+) -> set[fx.Node]:
     """
     Capture backward slice from a node and return the tree.
     Assumes graph is directed.
     """
-    return bfs(node, lambda x, y: get_inputs(x, y))
+    return bfs(node, lambda x, y: get_inputs(x, y), filter_fn)


def capture_mma_slices(mma_nodes: list[MMA]) -> dict[IndexSymbol, list[fx.Node]]:
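
Since filter_fn defaults to lambda x: True, existing callers are
unchanged; the new pass supplies a predicate to keep the walk inside
propagatable ops, as in thread_shape_analysis.py above:

    # stop the slice at anchor ops and no-handle ops:
    fwd_slice = capture_forward_slice(custom.fx_node, propagatable_op)
    bwd_slice = capture_backward_slice(custom.fx_node, propagatable_op)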