Add exporting with argument device affinities (#300)
We don't handle device affinities for our sharded tensor types when they
appear as arguments of exported functions. Affinities can be specified
explicitly when exporting, but this change adds the ability to hide that
detail and deduce the affinities from the tensor types.
sogartar authored Oct 23, 2024
1 parent ad0ac57 commit 3b5dc8a
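
In short: when a sharded tensor is passed as an argument of an exported function, its per-shard device affinities can now be deduced automatically instead of being specified by hand. A minimal usage sketch, mirroring the new test file below and assuming sharktank and iree.turbine are installed:

```
import torch
from iree.turbine import aot
from iree.turbine.aot import FxProgramsBuilder
from sharktank.types import ReplicatedTensor
from sharktank.export import export, get_argument_flat_device_affinities

# One replicated (single-shard) tensor and one plain tensor as arguments.
args = (ReplicatedTensor(ts=[torch.tensor([1])]), torch.tensor([[2]]))

# The shard of the ReplicatedTensor gets an affinity, the plain tensor does not:
# {0: DeviceAffinity("0")}
print(get_argument_flat_device_affinities(*args))

class Module(torch.nn.Module):
    def f(self, a, b):
        return a, b

fxb = FxProgramsBuilder(Module())
# No arg_device is passed; affinities are deduced from the sharded argument type.
export(Module.f, fx_builder=fxb, args=args, strict=False)
print(str(aot.export(fxb).mlir_module))
```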
Showing 4 changed files with 309 additions and 4 deletions.
156 changes: 156 additions & 0 deletions sharktank/sharktank/export.py
@@ -0,0 +1,156 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Callable, Any
import torch
from iree.turbine.aot import DeviceAffinity, FxProgramsBuilder
from torch.utils._pytree import tree_structure, tree_unflatten, tree_flatten
from .types.tensors import ShardedTensor
from torch.utils._pytree import PyTree, _is_leaf


def flatten_signature(
*sample_args: list[PyTree],
) -> Callable[[Callable], Any]:
"""Decorator that flattens the signature of a function using PyTorch's type
registration.
It flattens the same way PyTorch does, returning a function that accepts
and returns a flat list of torch.Tensor.
The decorator requires sample arguments of the unflattened function.
```
@flatten_signature(
{
"a1": SplitPrimitiveTensor(ts=[torch.tensor([1])], shard_dim=0),
"a2": torch.tensor([2]),
},
[DefaultPrimitiveTensor(data=torch.tensor([3]))]
)
def f(a, b):
return a["a1"], b
```
will result in a function with signature
```
(
torch.Tensor with value [1],
torch.Tensor with value [2],
torch.Tensor with value [3],
) -> (
torch.Tensor with value [1],
torch.Tensor with value [3],
)
```
"""
tree_spec = tree_structure(sample_args)

def _decorator(f: Callable) -> Callable:
def _wrapper(*flat_args: list[Any]) -> list[Any]:
unflattened_args = tree_unflatten(flat_args, tree_spec)
return tree_flatten(f(*unflattened_args))[0]

return _wrapper

return _decorator


def get_argument_flat_device_affinities(
*args: list[PyTree],
) -> dict[int, DeviceAffinity]:
"""Return the flat device affinities for unflattened arguments.
ShardedTensor types have their device affinities assigned.
All other arguments are left unassigned.
```
get_argument_flat_device_affinities(
torch.Tensor([1]),
[ReplicatedTensor(ts=[torch.tensor([2]), torch.tensor([3])])]
)
```
returns
```
{
1: DeviceAffinity("0"),
2: DeviceAffinity("1"),
}
```
"""

def is_leaf(v: PyTree) -> bool:
if isinstance(v, ShardedTensor):
return True
# TODO: It is sad that _is_leaf is private. Find a way not to use it.
from torch.utils._pytree import _is_leaf

return _is_leaf(v)

# Flatten the arguments, treating ShardedTensor as a leaf.
flat_args_up_to_sharded_tensor = tree_flatten(args, is_leaf=is_leaf)[0]
nested_device_affinities: list[list[DeviceAffinity | None]] = [
[DeviceAffinity(f"{shard_idx}") for shard_idx in range(len(arg.shards))]
if isinstance(arg, ShardedTensor)
else [None]
for arg in flat_args_up_to_sharded_tensor
]
flat_device_affinities: list[DeviceAffinity | None] = [
affinity
for affinity_list in nested_device_affinities
for affinity in affinity_list
]
return {
arg_idx: affinity
for arg_idx, affinity in enumerate(flat_device_affinities)
if affinity is not None
}


def export(
f: Callable,
/,
fx_builder: FxProgramsBuilder | None = None,
args: tuple[PyTree] | None = None,
arg_device: dict[int, DeviceAffinity] | None = None,
*transitive_args,
**transitive_kwargs,
) -> torch.export.ExportedProgram:
"""Wrapper around FxProgramsBuilder.export_program that handles
the sharktank custom tensor types.
If `arg_device` is not specified, the affinities are deduced from the passed `args`.
`arg_device` must provide the affinities for the flattened arguments,
i.e. those that correspond to individual torch.Tensors.
For example, a sharded tensor with 2 shards results in 2 arguments in the MLIR
signature."""
if args is None:
args = []
if arg_device is None:
arg_device = get_argument_flat_device_affinities(*args)
flat_args = tree_flatten(args)[0]
if fx_builder is not None:
# Flatten the signature of the function.
# Technically this is done during export, but we want the signature to match
# the flat device affinities.
def module_fn_with_flat_signature(module, *flat_args):
@flatten_signature(*args)
def flat_fn(*args):
return f(module, *args)

return flat_fn(*flat_args)

amended_kwargs = dict(**transitive_kwargs)
if "name" not in amended_kwargs or amended_kwargs["name"] is None:
amended_kwargs["name"] = f.__name__
return fx_builder.export_program(
module_fn_with_flat_signature,
*transitive_args,
args=flat_args,
arg_device=arg_device,
**amended_kwargs,
)

assert False, "TODO: implement the case when not using an FxProgramsBuilder"
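
As an aside, a small sketch of what the flatten_signature helper above does to a function's calling convention, assuming the sharktank tensor types flatten via pytree as in the docstring example (names here are hypothetical):

```
import torch
from sharktank.types import SplitPrimitiveTensor
from sharktank.export import flatten_signature

sample = {"x": SplitPrimitiveTensor(ts=[torch.tensor([1])], shard_dim=0)}

@flatten_signature(sample)
def take_dict(d):
    return d["x"]

# The decorated function now accepts and returns flat torch.Tensors; the dict
# and the sharded tensor are reconstructed internally from the sample arguments.
flat_result = take_dict(torch.tensor([1]))  # -> [torch.tensor([1])]
```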
16 changes: 12 additions & 4 deletions sharktank/sharktank/utils/misc.py
@@ -4,9 +4,9 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Any, List
from typing import Any, Callable, List
from collections.abc import Iterable
from itertools import zip_longest
from operator import eq


def longest_equal_range(l1: List[Any], l2: List[Any]) -> int:
@@ -18,5 +18,13 @@ def longest_equal_range(l1: List[Any], l2: List[Any]) -> int:
return min(len(l1), len(l2))


def iterables_equal(iterable1: Iterable, iterable2: Iterable) -> bool:
return all(v1 == v2 for v1, v2 in zip_longest(iterable1, iterable2))
def iterables_equal(
iterable1: Iterable,
iterable2: Iterable,
*,
elements_equal: Callable[[Any, Any], bool] | None = None
) -> bool:
elements_equal = elements_equal or eq
return all(
elements_equal(v1, v2) for v1, v2 in zip(iterable1, iterable2, strict=True)
)
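
A quick illustration of the extended helper; torch.equal stands in here as a custom element predicate (the values are hypothetical):

```
import torch
from sharktank.utils.misc import iterables_equal

assert iterables_equal([1, 2, 3], [1, 2, 3])
# `==` on tensors is element-wise, so pass a scalar-returning predicate instead:
assert iterables_equal(
    [torch.tensor([1, 2])], [torch.tensor([1, 2])], elements_equal=torch.equal
)
```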
40 changes: 40 additions & 0 deletions sharktank/sharktank/utils/testing.py
@@ -11,6 +11,9 @@
import tempfile
import unittest
import torch
from typing import Any, Callable
from operator import eq
from collections.abc import Iterable

from ..types import *

@@ -99,6 +102,30 @@ def get_best_torch_device() -> str:
return "cpu"


def assert_dicts_equal(
dict1: dict, dict2: dict, *, values_equal: Callable[[Any, Any], bool] | None = None
) -> None:
values_equal = values_equal or eq
assert len(dict1) == len(
dict2
), f"Dictionaries not equal. {dict1} and {dict2} have different number of elements {len(dict1)} != {len(dict2)}"
for k, v1 in dict1.items():
assert (
k in dict2
), f"Dictionaries {dict1} and {dict2} not equal. Key {k} not found in {dict2}"
v2 = dict2[k]
assert values_equal(
v1, v2
), f"Dictionaries {dict1} and {dict2} not equal for key {k}. Values {v1} and {v2} not equal"


def assert_equal(
a: Any, b: Any, *, equal: Callable[[Any, Any], bool] | None = None
) -> None:
equal = equal or eq
assert equal(a, b), f"{a} and {b} are not equal"


def assert_golden_safetensors(actual_path, ref_path):
"""Asserts that actual and reference safetensors files are within tolerances."""
from safetensors import safe_open
@@ -133,3 +160,16 @@ def print_stats(label, t):
actual = actual_f.get_tensor(name)
ref = ref_f.get_tensor(name)
torch.testing.assert_close(actual, ref, msg=name)


def assert_iterables_equal(
iterable1: Iterable,
iterable2: Iterable,
*,
elements_equal: Callable[[Any, Any], bool] | None = None,
) -> None:
elements_equal = elements_equal or eq
for i, (v1, v2) in enumerate(zip(iterable1, iterable2, strict=True)):
assert elements_equal(
v1, v2
), f"Iterables not equal at index {i} for elements {v1} and {v2}"
101 changes: 101 additions & 0 deletions sharktank/tests/export_test.py
@@ -0,0 +1,101 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from sharktank.types import (
ReplicatedTensor,
SplitPrimitiveTensor,
DefaultPrimitiveTensor,
unbox_tensor,
)
from sharktank.export import (
export,
flatten_signature,
get_argument_flat_device_affinities,
)
from sharktank import ops
from sharktank.utils.testing import (
assert_equal,
assert_iterables_equal,
assert_dicts_equal,
)
from iree.turbine.aot import DeviceAffinity, FxProgramsBuilder
from iree.turbine import aot
from unittest import TestCase
import torch


class ExportTest(TestCase):
def testFlattenSignature(self):
expected_a = [SplitPrimitiveTensor(ts=[torch.tensor([1])], shard_dim=0)]
expected_b = {"element": DefaultPrimitiveTensor(data=torch.tensor([2]))}
expected_c = torch.tensor([3])

@flatten_signature(expected_a, expected_b, expected_c)
def f(
a: list[SplitPrimitiveTensor],
b: dict[str, DefaultPrimitiveTensor],
c: torch.Tensor,
):
assert_iterables_equal(a, expected_a, elements_equal=ops.equal)
assert_dicts_equal(b, expected_b, values_equal=ops.equal)
assert_equal(c, expected_c, equal=ops.equal)

f(
unbox_tensor(expected_a[0].shards[0]),
expected_b["element"].as_torch(),
expected_c,
)

def testGetFlatArgumentDeviceAffinities(self):
args = [
{
"a": [
SplitPrimitiveTensor(
ts=[torch.tensor([1]), torch.tensor([2])], shard_dim=0
)
]
},
torch.tensor([3]),
ReplicatedTensor(ts=[torch.tensor([4]), torch.tensor([5])]),
]
affinities = get_argument_flat_device_affinities(*args)
expected_affinities = {
0: DeviceAffinity("0"),
1: DeviceAffinity("1"),
3: DeviceAffinity("0"),
4: DeviceAffinity("1"),
}
assert_dicts_equal(affinities, expected_affinities)

def testExportWithArgumentDeviceAffinities(self):
args = (ReplicatedTensor(ts=[torch.tensor([1])]), torch.tensor([[2]]))

class Module(torch.nn.Module):
def f(self, a, b):
return a, b

module = Module()
fxb = FxProgramsBuilder(module)
export(
Module.f,
fx_builder=fxb,
args=args,
strict=False,
)
export_output = aot.export(
fxb,
)
asm = str(export_output.mlir_module)
print(asm)
self.assertRegex(
asm,
expected_regex=(
"func.func @f\\("
"%.+: !torch.vtensor<\\[1\\],si64> "
"{iree.abi.affinity = #hal.device.promise<@__device_0>}, "
"%.+: !torch.vtensor<\\[1,1\\],si64>\\)"
),
)
