diff --git a/shark_turbine/kernel/ops/wave_ops.py b/shark_turbine/kernel/ops/wave_ops.py
index f1292e4d..2373feec 100644
--- a/shark_turbine/kernel/ops/wave_ops.py
+++ b/shark_turbine/kernel/ops/wave_ops.py
@@ -651,6 +651,13 @@ class IterArg(Placeholder):
     a reduction node.
     """
 
+    def parent_op(self):
+        return get_custom(self.graph.parent_op)
+
+    def get_iter_idx(self):
+        src_reduction = self.parent_op()
+        return src_reduction.iter_args(self.graph).index(self.fx_node)
+
 
 # Ops modeling TKW operations in the kernel language
 
@@ -847,6 +854,7 @@ def wrapper(f):
         node._add_proxy_to_graph(graph)
         node.fx_node.node.tkw_op = cls
         node.fx_node.node.tkw_op_name = cls.tkw_op_name
+        graph.subgraphs[subgraph_name].parent_op = node.fx_node.node
         return node.fx_node
 
     return wrapper
diff --git a/shark_turbine/kernel/wave/codegen.py b/shark_turbine/kernel/wave/codegen.py
index adcd69b8..313e72cb 100644
--- a/shark_turbine/kernel/wave/codegen.py
+++ b/shark_turbine/kernel/wave/codegen.py
@@ -436,8 +436,8 @@ def handle_register(emitter: WaveEmitter, node: fx.Node):
         shape, dtype, value = node.args
     except ValueError as e:
         raise ValidationError("Malformed arguments") from e
-    if hasattr(node, "thread_shape"):
-        shape = [node.thread_shape]
+    get_thread_shape = lambda index: max(x.size for x in index.values())
+    shape = [get_thread_shape(get_custom(node).index)]
     vector_shape = cast_py_literal(emitter, shape)
     element_type = IrType.parse(dtype.ir_type_asm())
     vector_type = VectorType.get(vector_shape, element_type)
diff --git a/shark_turbine/kernel/wave/decompose_reduce_ops.py b/shark_turbine/kernel/wave/decompose_reduce_ops.py
index 1dac06cc..9916bb50 100644
--- a/shark_turbine/kernel/wave/decompose_reduce_ops.py
+++ b/shark_turbine/kernel/wave/decompose_reduce_ops.py
@@ -20,9 +20,10 @@
     ShuffleOp,
     CustomOp,
     ExtractSlice,
+    Reduction,
 )
 
-from .utils import DCE
+from .utils import DCE, subs_idxc
 import torch.fx as fx
 import math
 from typing import Callable
@@ -103,9 +104,11 @@ def decompose_reduce_ops(
                 raise NotImplementedError(
                     "Only implemented reduction on fastest dimension."
                 )
-            reduction_block_size = constraint_tile_size[reduction_dim]
-            reduction_size = reduction_block_size.subs(index_map)
-            local_reduction_size = reduction_size / subgroup_size
+
+            get_thread_shape = lambda index: max(
+                subs_idxc(x.size) for x in index.values()
+            )
+            local_reduction_size = get_thread_shape(get_custom(custom.arg).index)
             local_reduction = emit_local_reduction(
                 binary_fn, reduction_src, custom.graph, local_reduction_size
             )
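Note: both hunks above replace the ad-hoc `thread_shape` attribute with the same pattern: read the per-thread vector shape off the op's index, which the new thread-shape analysis populates. A minimal sketch of the idea, using a hypothetical stand-in for `IndexSequence` (only its `size` field matters here):

    from dataclasses import dataclass

    @dataclass
    class IndexSequence:  # stand-in; the real class lives in _support.indexing
        size: int

    # An op's index maps each dimension symbol to an IndexSequence whose
    # `size` is the number of elements the current thread owns along it.
    index = {"M": IndexSequence(size=1), "N": IndexSequence(size=4)}

    # Same pattern as handle_register: the per-thread vector shape is the
    # largest per-dimension size in the op's index.
    get_thread_shape = lambda index: max(x.size for x in index.values())
    assert get_thread_shape(index) == 4  # e.g. lowered to vector<4xf16>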
diff --git a/shark_turbine/kernel/wave/register_analysis.py b/shark_turbine/kernel/wave/register_analysis.py
deleted file mode 100644
index cbd42fbe..00000000
--- a/shark_turbine/kernel/wave/register_analysis.py
+++ /dev/null
@@ -1,93 +0,0 @@
-# Copyright 2024 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from ..wave.constraints import Constraint
-from .._support.indexing import IndexingContext, IndexSequence, IndexSymbol, IndexExpr
-from .._support.tracing import CapturedTrace
-from ...support.logging import get_logger
-from ..ops.wave_ops import get_custom, NewRegister, CustomOp, MMA, Reduction, ReduceOp
-from .utils import get_hardware_vector_map
-import torch.fx as fx
-
-logger = get_logger("turbine.wave.register_analysis")
-
-
-def set_register_shape(
-    trace: CapturedTrace, custom: CustomOp, vector_map: dict[IndexSymbol, int]
-) -> None:
-    for custom_user in custom.users:
-        if isinstance(custom_user, MMA):
-            arg_index = custom_user.fx_node.args.index(custom.fx_node)
-            get_thread_shape = lambda index: max(x.size for x in index.values())
-            match arg_index:
-                case 0:
-                    custom.fx_node.thread_shape = get_thread_shape(
-                        custom_user.lhs_index
-                    )
-                case 1:
-                    custom.fx_node.thread_shape = get_thread_shape(
-                        custom_user.rhs_index
-                    )
-                case 2:
-                    custom.fx_node.thread_shape = get_thread_shape(
-                        custom_user.acc_index
-                    )
-            break
-
-        elif isinstance(custom_user, Reduction):
-            idx = custom_user.init_args.index(custom.fx_node)
-            iter_arg = get_custom(
-                custom_user.iter_args(trace.get_subgraph(custom_user.subgraph_name))[
-                    idx
-                ]
-            )
-            set_register_shape(trace, iter_arg, vector_map)
-            custom.fx_node.thread_shape = iter_arg.fx_node.thread_shape
-            break
-        elif isinstance(custom_user, ReduceOp):
-            # Check that dim is non-reduction && in hw_constraint.vector_shape.
-            is_parallel_dim = lambda dim: dim != custom_user.dim and dim in vector_map
-            # TODO: Modify num_reduction_dims once we add support for multi-dim reduction.
-            num_reduction_dims = 1
-            register_shape = [
-                vector_map[dim]
-                for dim in custom_user.type.symbolic_shape
-                if is_parallel_dim(dim)
-            ]
-            expected_result_rank = (
-                len(custom_user.type.symbolic_shape) - custom_user.num_reduction_dims
-            )
-            # If rank do not match => some dims not found in hw_constraint.vector_shape.
-            if len(register_shape) != expected_result_rank:
-                raise NotImplementedError(
-                    "NYI: Handling of dim not in vector_shapes during register analysis."
-                )
-            non_unit_dims = sum(1 for dim in register_shape if dim > 1)
-            if non_unit_dims > 1:
-                raise NotImplementedError(
-                    "NYI: Currently Register semantic only support 0-D vector."
-                )
-            custom.fx_node.thread_shape = max(register_shape)
-        else:
-            raise NotImplementedError(
-                f"Register shape propagation not implemented for {custom_user}"
-            )
-
-
-def determine_register_shape(
-    trace: CapturedTrace | fx.Graph, constraints: list[Constraint]
-) -> None:
-    """
-    Each register op is annotated with the wave shape of the register. This
-    function determines the thread shape of the register based on the uses
-    of the register in the graph.
-    """
-    register_nodes = trace.walk(lambda node: isinstance(get_custom(node), NewRegister))
-    if not register_nodes:
-        return
-    vector_map = get_hardware_vector_map(constraints)
-    for node in register_nodes:
-        set_register_shape(trace, get_custom(node), vector_map)
diff --git a/shark_turbine/kernel/wave/thread_shape_analysis.py b/shark_turbine/kernel/wave/thread_shape_analysis.py
new file mode 100644
index 00000000..572561fb
--- /dev/null
+++ b/shark_turbine/kernel/wave/thread_shape_analysis.py
@@ -0,0 +1,142 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from dataclasses import dataclass
+
+from ...support.logging import get_logger
+from shark_turbine.kernel._support.tracing import CapturedTrace
+import torch.fx as fx
+from ..ops.wave_ops import *
+from ..lang.global_symbols import *
+from .utils import capture_forward_slice, capture_backward_slice
+
+logger = get_logger("turbine.wave.thread_shape_analysis")
+
+
+@dataclass(order=True)
+class DimSize:
+    dim: IndexSymbol
+    size: int
+
+    def __hash__(self):
+        return hash((self.dim, self.size))
+
+
+def get_dim_sizes(indices: dict[IndexSymbol, IndexSequence]):
+    dims = frozenset([DimSize(dim, seq.size) for dim, seq in indices.items()])
+    return dims
+
+
+def get_custom_dim_sizes(custom: CustomOp):
+    return get_dim_sizes(custom.index)
+
+
+def set_index_size(custom: CustomOp, target_dim_sizes: list[DimSize]):
+    for target in target_dim_sizes:
+        if target.dim not in custom.index:
+            raise NotImplementedError(
+                "NYI: Handle when source index size is not found in target/user index."
+            )
+        custom.index[target.dim].size = target.size
+
+
+def determine_thread_shapes(trace: CapturedTrace):
+    """
+    Analyzes and propagates thread shapes through the graph:
+    1. Look for "anchor" ops whose elements-per-thread are already known.
+    2. Take a forward/backward slice from each anchor op to collect the ops
+       whose shapes depend on it.
+    3. Bucket these ops by their (index -> elem_per_thread) mapping.
+    4. For every (index -> elem_per_thread) bucket, apply that information
+       by updating the size of each op's IndexSequence.
+
+    The buckets are stored in a dict called `thread_size_to_ops`.
+
+    `thread_size_to_ops` is keyed by thread shape; each key maps to the set
+    of fx.Nodes that need to carry that thread shape in their IndexSequence.
+
+    A thread shape records the thread size for every dimension the op cares
+    about. We represent it as a frozenset[DimSize], where DimSize is
+    essentially a (dim, size) pair. A frozenset is used because the order of
+    dims does not matter for shape/size propagation.
+
+    We use set[CustomOp] for the values of `thread_size_to_ops` so that
+    conflicting index assignments can be detected and handled/resolved
+    cheaply with set operations.
+
+    For better illustration, here's an example:
+    Kernel:
+        imm = tkw.mul(x, y)
+        lhs = tkw.neg(imm)
+        a = tkw.mma(lhs, rhs, acc)
+        b = tkw.exp2(a)
+    Anchors:
+        mma.lhs: {DimSize(dim=M, size=1), DimSize(dim=K, size=4)}
+        mma.rhs: {DimSize(dim=K, size=4), DimSize(dim=N, size=1)}
+        mma.acc: {DimSize(dim=M, size=4), DimSize(dim=N, size=1)}
+    Bucket Entry:
+        thread_size_to_ops[frozenset({DimSize(dim=M, size=1), DimSize(dim=K, size=4)})] = set(lhs, imm, x, y)
+        thread_size_to_ops[frozenset({DimSize(dim=M, size=4), DimSize(dim=N, size=1)})] = set(acc, exp2_0)
+        thread_size_to_ops[frozenset({DimSize(dim=K, size=4), DimSize(dim=N, size=1)})] = set(rhs, ...)
+    """
+
+    # Anchor ops are ops whose thread shapes are predetermined.
+    anchorOpTypes = (Read, Write, MMA, ReduceOp)
+    noHandleTypes = (Placeholder, Output, ExtractSlice, Allocate)
+    nonPropagatableTypes = anchorOpTypes + noHandleTypes
+
+    def is_anchor_op(node: fx.Node):
+        return isinstance(get_custom(node), anchorOpTypes)
+
+    def propagatable_op(node: fx.Node):
+        return not isinstance(get_custom(node), nonPropagatableTypes)
+
+    anchor_ops = trace.walk(is_anchor_op)
+    thread_size_to_ops: dict[frozenset[DimSize], set[CustomOp]] = {}
+    for anchor_op in anchor_ops:
+        custom = get_custom(anchor_op)
+        index_sizes = get_custom_dim_sizes(custom)
+        if isinstance(custom, (Read, ReduceOp)):
+            fwd_slice = capture_forward_slice(custom.fx_node, propagatable_op)
+            thread_size_to_ops[index_sizes] = thread_size_to_ops.get(
+                index_sizes, set([])
+            ).union(fwd_slice)
+        elif isinstance(custom, Write):
+            bwd_slice = capture_backward_slice(custom.fx_node, propagatable_op)
+            thread_size_to_ops[index_sizes] = thread_size_to_ops.get(
+                index_sizes, set([])
+            ).union(bwd_slice)
+        elif isinstance(custom, MMA):
+            lhs_bwd_slice = capture_backward_slice(custom.lhs, propagatable_op)
+            rhs_bwd_slice = capture_backward_slice(custom.rhs, propagatable_op)
+            acc_slice = capture_forward_slice(custom.acc, propagatable_op)
+            acc_slice = acc_slice.union(
+                capture_backward_slice(custom.acc, propagatable_op)
+            )
+            acc_index = get_dim_sizes(custom.acc_index)
+            lhs_index = get_dim_sizes(custom.lhs_index)
+            rhs_index = get_dim_sizes(custom.rhs_index)
+            thread_size_to_ops[acc_index] = thread_size_to_ops.get(
+                acc_index, set([])
+            ).union(acc_slice)
+            thread_size_to_ops[lhs_index] = thread_size_to_ops.get(
+                lhs_index, set([])
+            ).union(lhs_bwd_slice)
+            thread_size_to_ops[rhs_index] = thread_size_to_ops.get(
+                rhs_index, set([])
+            ).union(rhs_bwd_slice)
+
+    # Go through each index-size bucket and apply the index size to the ops in the bucket.
+    cumulative_set = set()
+    for target_index_size, target_ops in thread_size_to_ops.items():
+        # Ensure that we do not have any conflicts.
+        if not cumulative_set.isdisjoint(target_ops):
+            raise NotImplementedError("NYI: Handling of conflicting thread shape.")
+        cumulative_set = cumulative_set.union(target_ops)
+        for user in target_ops:
+            custom_user = get_custom(user)
+            set_index_size(custom_user, target_index_size)
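Note: to see the bucketing scheme in isolation, here is a small self-contained sketch. Ops are plain strings instead of fx.Nodes, and this DimSize is frozen for hashability instead of defining __hash__ manually as the pass does:

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class DimSize:
        dim: str
        size: int

    # Thread shapes are frozensets, so dim order is irrelevant.
    lhs_shape = frozenset({DimSize("M", 1), DimSize("K", 4)})
    acc_shape = frozenset({DimSize("M", 4), DimSize("N", 1)})

    # Bucket: thread shape -> ops that must adopt it.
    thread_size_to_ops: dict[frozenset, set[str]] = {}
    thread_size_to_ops.setdefault(lhs_shape, set()).update({"lhs", "imm", "x", "y"})
    thread_size_to_ops.setdefault(acc_shape, set()).update({"acc", "exp2_0"})

    # The cumulative-set walk mirrors the conflict check in the pass: no op
    # may end up assigned two different thread shapes.
    seen: set[str] = set()
    for shape, ops in thread_size_to_ops.items():
        assert seen.isdisjoint(ops), "conflicting thread shape"
        seen |= ops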
""" inputs = [] - for input in node.all_input_nodes: - custom = get_custom(input) - if isinstance(custom, GetResult): - reduction = custom.value - assert isinstance( - reduction, Reduction - ), "GetResult must be used by a Reduction" - # Map get result to output - inputs.append(reduction.outputs[custom.res_idx]) - continue - if isinstance(custom, IterArg): - # Map iter args to init args - iter_arg_idx = reduction.iter_args.index(node) - inputs.append(reduction.init_args[iter_arg_idx]) - continue - inputs.append(input) + custom = get_custom(node) + if isinstance(custom, IterArg): + # Map iter args to init args + local_reduction = reduction + if reduction is None: + local_reduction = custom.parent_op() + iter_arg_idx = custom.get_iter_idx() + inputs.append(local_reduction.init_args[iter_arg_idx]) + elif isinstance(custom, GetResult): + reduction = get_custom(custom.value) + assert isinstance( + get_custom(reduction), Reduction + ), "GetResult must be used by a Reduction" + # Map get result to output + reduction_subgraph = reduction.graph.subgraphs[reduction.subgraph_name] + inputs.append(reduction.outputs(reduction_subgraph)[custom.res_idx]) + else: + # Default handling for other ops. + for input in node.all_input_nodes: + inputs.append(input) return inputs, reduction def bfs( node: fx.Node, get_neighbors: Callable[[fx.Node, fx.Node], list[fx.Node]], + filter_fn: Callable[[fx.node], bool], ) -> set[fx.Node]: """ Run BFS on the graph to capture the forward slice of a node. @@ -529,25 +534,29 @@ def bfs( s = queue.pop(0) neighbors, reduction = get_neighbors(s, reduction) for neighbor in neighbors: - if neighbor not in visited: + if neighbor not in visited and filter_fn(neighbor): visited.add(neighbor) queue.append(neighbor) return visited -def capture_forward_slice(node: fx.Node) -> set[fx.Node]: +def capture_forward_slice( + node: fx.Node, filter_fn: Callable[[fx.node], bool] = lambda x: True +) -> set[fx.Node]: """ Run BFS on the graph to capture the forward slice of a node. """ - return bfs(node, lambda x, y: get_users(x, y)) + return bfs(node, lambda x, y: get_users(x, y), filter_fn) -def capture_backward_slice(node: fx.Node) -> set[fx.Node]: +def capture_backward_slice( + node: fx.Node, filter_fn: Callable[[fx.node], bool] = lambda x: True +) -> set[fx.Node]: """ Capture backward slice from a node and return the tree. Assumes graph is directed. """ - return bfs(node, lambda x, y: get_inputs(x, y)) + return bfs(node, lambda x, y: get_inputs(x, y), filter_fn) def capture_mma_slices(mma_nodes: list[MMA]) -> dict[IndexSymbol, list[fx.Node]]: diff --git a/shark_turbine/kernel/wave/wave.py b/shark_turbine/kernel/wave/wave.py index fde0c792..202cdd92 100644 --- a/shark_turbine/kernel/wave/wave.py +++ b/shark_turbine/kernel/wave/wave.py @@ -39,7 +39,7 @@ from ..ops.wave_ops import Reduction, CustomOp, get_custom from .index_sequence_analysis import partition_strided_operators from .shared_memory_indexing import apply_shared_memory_indexing_corrections -from .register_analysis import determine_register_shape +from .thread_shape_analysis import determine_thread_shapes from .scheduling.schedule import schedule_graph from .._support.indexing import IndexingContext, IndexExpr import shark_turbine.kernel.lang as tkl @@ -227,9 +227,6 @@ def _trace_and_get_kernel_signature( # Clean up chains of GetResults remove_chained_getresult(graph) - # Register analysis to determine register shapes. - determine_register_shape(graph, self.constraints) - # Optimizations. 
diff --git a/shark_turbine/kernel/wave/wave.py b/shark_turbine/kernel/wave/wave.py
index fde0c792..202cdd92 100644
--- a/shark_turbine/kernel/wave/wave.py
+++ b/shark_turbine/kernel/wave/wave.py
@@ -39,7 +39,7 @@
 from ..ops.wave_ops import Reduction, CustomOp, get_custom
 from .index_sequence_analysis import partition_strided_operators
 from .shared_memory_indexing import apply_shared_memory_indexing_corrections
-from .register_analysis import determine_register_shape
+from .thread_shape_analysis import determine_thread_shapes
 from .scheduling.schedule import schedule_graph
 from .._support.indexing import IndexingContext, IndexExpr
 import shark_turbine.kernel.lang as tkl
@@ -227,9 +227,6 @@ def _trace_and_get_kernel_signature(
         # Clean up chains of GetResults
         remove_chained_getresult(graph)
 
-        # Register analysis to determine register shapes.
-        determine_register_shape(graph, self.constraints)
-
         # Optimizations.
         minimize_global_loads(graph, self.constraints)
 
@@ -239,6 +236,9 @@ def _trace_and_get_kernel_signature(
         # Partition strided operators.
         partition_strided_operators(graph, self.constraints)
 
+        # Analyze thread shapes per op.
+        determine_thread_shapes(graph)
+
         # Decompose reduce Ops.
         decompose_reduce_ops(graph, self.constraints, idxc.subs)