Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add device affinities for arguments in AOT #231

Merged
merged 1 commit into from
Oct 23, 2024

Conversation

sogartar
Copy link
Contributor

We don't have support for providing device affinities for function arguments, which need to end up as MLIR function argument attributes.

This change adds a class DeviceAffinity and provides the ability to supply affinities when exporting Torch functions/modules or when tracing in IREE-Trubine itself.

@sogartar sogartar force-pushed the affinity-attributes-in-api branch 7 times, most recently from ce0da33 to 00d8857 Compare October 18, 2024 19:24
@sogartar sogartar changed the title WIP Add device affinities for arguments in AOT Add device affinities for arguments in AOT Oct 18, 2024
@sogartar sogartar marked this pull request as ready for review October 18, 2024 19:36
@@ -107,12 +112,27 @@ def __call__(self, *args, **kwargs):
return self.py_value(*args, **kwargs)


class ExportTargetDef:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure if I should drop the ExportTargetDef and use half-initialized ExportProcDef and ExportedProgramDef directly.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the benefit to having this separate class?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the place where we need a structure to store this data we are not ready to construct ExportProcDef or ExportedProgramDef. If we are to use them directly they will have a more complicated multi-step initialization.

@sogartar
Copy link
Contributor Author

This PR removes the need for #220, which kind of abuses the function generation mechanism.

@stellaraccident
Copy link
Collaborator

Thanks. I will review this in a few minutes

argument_device_affinities: dict[int, "DeviceAffinity"] | None = None,
):
self.target = target
self.argument_device_affinities = argument_device_affinities
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make the name more succinct. Instead of argument_device_affinities just arg_device. It should focus on where the argument is placed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@@ -207,6 +240,21 @@ def globals_defs(self) -> Generator[Tuple[str, GlobalsDef], None, None]:
) # type: ignore

def def_attribute(self, key, value):
if isinstance(value, ExportTargetDef):
if isinstance(value.target, ExportedProgram):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Swap the if else and remove the else: part. You can do

if not isinstance(value.target, ExportedProgram):
                # We expect exported function.
                assert callable(value.target) and inspect.isfunction(value.target)
                return self.def_export_proc(
                    key, value.target, value.argument_device_affinities
                )

And given the return it would exit right away anyway.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@@ -633,6 +708,7 @@ def __new__(
ep_def.exported_program,
symbol_name=ep_def.export_name or "main",
symbol_visibility=None if ep_def.public else "private",
argument_device_affinities=ep_def.argument_device_affinities or {},
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should just support whatever the default is for ep_def.argument_device_affinities rather than using or

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do not wipe out the existing test. Create a separate fx_programs_test_device.py that tests the device affinity work. We should try to guarantee the old patch works for as long as possible

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I moved the new test to a new file as you suggested.

@@ -207,6 +240,21 @@ def globals_defs(self) -> Generator[Tuple[str, GlobalsDef], None, None]:
) # type: ignore

def def_attribute(self, key, value):
if isinstance(value, ExportTargetDef):
if isinstance(value.target, ExportedProgram):
value = ExportedProgramDef(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be returned as well? It is setting the value but this falls through. I see it is handled further down, if so we should handle it right before the follow up case.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I moved this case down before the handling of ExportedProgramDef.

class DeviceAffinity:
"""This is used to provide device affinities to exported function arguments."""

def __init__(self, moniker: str):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make the moniker an int. All cases around just specify it that way anyway.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed it to ordinal: int.

@@ -568,6 +627,22 @@ def save_mlir(inst: "CompiledModule", path: Union[Path, str]):

jittable = staticmethod(builtins.jittable)

@staticmethod
def annotate(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Name this something better than annotate. Even signature_info would be sufficient to explain that its adding additional information to the function signature.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@@ -61,7 +68,7 @@ class FxPrograms:
"""

def __init__(self):
self.programs: dict[str, torch.export.ExportedProgram] = {}
self.programs: dict[str, ExportTargetDef] = {}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rsuderman This is something I forgot to bring to attention, but I am not sure if self.programs should be a part of the interface. This changes it and I also changed one test that specifically used it.

Copy link
Collaborator

@stellaraccident stellaraccident left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. A minor comment on code organization.

@@ -49,6 +50,21 @@
SaveableTarget = Union[str, Path, None, Output]


class DeviceAffinity:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Put this in tensor_traits.py. Then you won't need to work around weird circular references and can just cleanly import it from anywhere needed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I moved it there. Why I did not see this.

We don't have support for providing device affinities for function
arguments, which need to end up as MLIR function argument attributes.

This change adds a class DeviceAffinity and provides the ability to
supply affinities when exporting Torch functions/modules or when
tracing in IREE-Trubine itself.

Signed-off-by: Boian Petkantchin <[email protected]>
@sogartar
Copy link
Contributor Author

I squashed and rebased to prepare for merging.

@sogartar sogartar merged commit ae9a51c into iree-org:main Oct 23, 2024
7 of 8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants