
⚡️ Add torch.compile to PatchPredictor #776

Merged · 33 commits · Mar 19, 2024

Conversation

@Abdol (Collaborator) commented Jan 26, 2024

This mini-PR adds torch.compile functionality to PatchPredictor.

@Abdol Abdol self-assigned this Jan 26, 2024
@Abdol Abdol added this to the Release v2.0.0 milestone Jan 26, 2024
@Abdol Abdol added the enhancement New feature or request label Jan 26, 2024
@Abdol Abdol marked this pull request as ready for review January 30, 2024 14:40
@measty (Collaborator) commented Feb 1, 2024

At the moment the model is compiled before it is sent to the GPU (if GPU is being used). I think at least some of what torch.compile does is device-aware, so it may be better to compile after it is sent to device. Have you tried testing if the ordering makes a difference?

@shaneahmed (Member) commented

> At the moment the model is compiled before it is sent to the GPU (if GPU is being used). I think at least some of what torch.compile does is device-aware, so it may be better to compile after it is sent to device. Have you tried testing if the ordering makes a difference?

Agree, this should be checked.
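The ordering being discussed can be sketched as follows. This is a minimal standalone example, not the PR's code; `backend="eager"` is used only to keep the sketch light (it skips the Inductor toolchain):

```python
import torch

model = torch.nn.Linear(8, 2)
device = "cuda" if torch.cuda.is_available() else "cpu"

model = model.to(device)  # 1. move the model to the target device first
# 2. compile afterwards, so tracing sees device-resident parameters
compiled = torch.compile(model, backend="eager")

out = compiled(torch.randn(4, 8, device=device))
```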

@shaneahmed (Member) commented

Please enable tests for this branch.

codecov bot commented Feb 16, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 99.89%. Comparing base (b2f57ee) to head (150678b).

Additional details and impacted files
@@                  Coverage Diff                   @@
##           enhance-torch-compile     #776   +/-   ##
======================================================
  Coverage                  99.89%   99.89%           
======================================================
  Files                         69       69           
  Lines                       8578     8589   +11     
  Branches                    1641     1642    +1     
======================================================
+ Hits                        8569     8580   +11     
  Misses                         1        1           
  Partials                       8        8           


@Abdol (Collaborator, Author) commented Feb 23, 2024

> At the moment the model is compiled before it is sent to the GPU (if GPU is being used). I think at least some of what torch.compile does is device-aware, so it may be better to compile after it is sent to device. Have you tried testing if the ordering makes a difference?

@measty I believe the model is compiled only when the forward function is called. See link 1 and link 2.
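This laziness can be checked directly: `torch.compile` only wraps the callable, and the backend is not invoked until the first forward call. A small sketch using a custom counting backend (hypothetical names, not the PR's code):

```python
import torch

compile_calls = {"n": 0}

def counting_backend(gm, example_inputs):
    # Custom Dynamo backend: only runs when compilation actually happens.
    compile_calls["n"] += 1
    return gm.forward  # execute the captured graph eagerly

fn = torch.compile(lambda x: x * 2, backend=counting_backend)
assert compile_calls["n"] == 0  # wrapping alone compiles nothing

y = fn(torch.ones(3))           # first call triggers compilation
```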

@Abdol (Collaborator, Author) commented Feb 27, 2024

@shaneahmed torch.compile is not compatible with Python 3.12 (see here). This has triggered an error when running CI with Python 3.12:

pytorch/pytorch#120233

```
    if sys.version_info >= (3, 12):
        raise RuntimeError("Dynamo is not supported on Python 3.12+")
E   RuntimeError: Dynamo is not supported on Python 3.12+
```

Should we disable torch.compile for this version in the PR?
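One way to do that, sketched here as a hypothetical helper (not the PR's actual code), is to fold a version guard into the compile step itself. The 3.12 cutoff reflects the Dynamo limitation at the time of this PR:

```python
import sys

def maybe_compile(model, *, disable=False):
    """Return a compiled model, or the model unchanged where Dynamo is unavailable."""
    if disable or sys.version_info >= (3, 12):
        return model  # skip torch.compile on Python 3.12+ (Dynamo unsupported)
    import torch
    return torch.compile(model)
```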

@Abdol Abdol merged commit 252c7f9 into enhance-torch-compile Mar 19, 2024
15 checks passed
@Abdol Abdol deleted the enhance-torch-compile-patch-predictor branch March 19, 2024 17:50
@shaneahmed shaneahmed restored the enhance-torch-compile-patch-predictor branch March 22, 2024 14:08
@shaneahmed shaneahmed deleted the enhance-torch-compile-patch-predictor branch March 22, 2024 14:08
Code under review:

```
    Compiled model.

    """
    if disable:
```
@shaneahmed (Member) commented

I do not think we need this variable. I think we should only call this function if not disable

@Abdol (Collaborator, Author) replied

Thank you @shaneahmed. That could be done, too. However, I'm mirroring the PyTorch implementation, which includes a disable flag in the function (torch.compile).
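For reference, `torch.compile` does expose this flag; with `disable=True` the call becomes a no-op and the wrapped function simply runs eagerly:

```python
import torch

# disable=True turns torch.compile into a no-op (the docs suggest it for
# testing), which is the interface being mirrored here
fn = torch.compile(lambda x: x + 1, disable=True)
result = fn(torch.tensor(1.0))
```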

Code under review:

```
self.model = (
    compile_model(  # for runtime, such as after wrapping with nn.DataParallel
        model,
        mode=rcParam["torch_compile_mode"],
```
@shaneahmed (Member) commented

Do we need rcParam for this? We could just set this as a kwargs argument in the engines.

@Abdol (Collaborator, Author) replied

Thank you @shaneahmed. Having a kwarg for torch_compile_mode would work, too. I would suggest keeping rcParam for now until we implement it in the new engine design. Happy to discuss it in our next meeting.
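The rcParam approach under discussion can be sketched as a module-level config dict supplying the default compile mode, overridable per call. This is a hypothetical mirror of the pattern, not tiatoolbox's actual code:

```python
rcParam = {"torch_compile_mode": "default"}  # global default, user-overridable

def compile_model(model=None, *, mode=None, disable=False):
    """Compile a model with torch.compile, defaulting the mode from rcParam."""
    if disable or model is None:
        return model
    import torch
    # valid torch.compile modes include "default", "reduce-overhead",
    # and "max-autotune"
    return torch.compile(model, mode=mode or rcParam["torch_compile_mode"])
```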
