State of affairs for NestedTensor (NJT) inference? #4234

Open
vadimkantorov opened this issue Nov 2, 2024 · 5 comments
Labels
triaged Issue has been triaged by maintainers

Comments

@vadimkantorov

PyTorch now has some support for representing varlen sequences (NestedTensor / NJT), and it is supported to some extent by HF.

This is useful, e.g., for saving compute on padding tokens during BERT inference. Does TRT have kernels for such NJT SDPA ops, and can they be executed via CUDA graphs? If so, how can one benefit from them? Is there an example?
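For context, a minimal sketch of what this looks like on the PyTorch side (assuming a recent PyTorch with the jagged layout and a CUDA device for the flash path; names and shapes are illustrative, nothing TRT-specific):

```python
import torch
import torch.nn.functional as F

# Two sequences of different lengths batched without padding (jagged layout).
n_heads, head_dim = 8, 16
seqs = [torch.randn(s, n_heads, head_dim, device="cuda", dtype=torch.float16)
        for s in (5, 3)]
q = torch.nested.nested_tensor(seqs, layout=torch.jagged)  # (B, S*, H, D)

# SDPA expects (B, H, S*, D); on CUDA this can dispatch to a varlen
# flash-attention kernel without materializing any padding.
q = q.transpose(1, 2)
out = F.scaled_dot_product_attention(q, q, q)  # shared q/k/v for brevity
```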

Thank you!

@vadimkantorov vadimkantorov changed the title State of affairs for NestedTensor (NJT) inference State of affairs for NestedTensor (NJT) inference? Nov 2, 2024
@lix19937

lix19937 commented Nov 5, 2024

In my opinion, TRT has no such op (NJT); you can write a custom one.

If you want to use CUDA graphs, you need to set the max sequence length (use fixed addresses for the sequence buffers to build the static graph) and set the min, opt, and max shapes for this input.
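For reference, setting min/opt/max shapes for a dynamic sequence-length input looks roughly like this with the TensorRT Python API (a minimal sketch; the input name "input_ids" and the shape ranges are made up):

```python
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
config = builder.create_builder_config()

# Dynamic batch and sequence dimensions on the input.
network.add_input("input_ids", trt.int32, (-1, -1))

# One optimization profile covering batch 1..8 and sequence length 1..512.
profile = builder.create_optimization_profile()
profile.set_shape("input_ids", min=(1, 1), opt=(4, 128), max=(8, 512))
config.add_optimization_profile(profile)
```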

@vadimkantorov
Author

vadimkantorov commented Nov 5, 2024

A key component of NJT support for SDPA is block-diagonal attention masks. Does TRT have support/examples for block-diagonal attn masks?

One would want proper FlashAttention kernels in this setup; otherwise the speedups are unlikely to be realized.
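To illustrate, a block-diagonal mask over sequences packed into one long row can be built like this (a naive dense sketch; a real varlen kernel would take cumulative sequence lengths instead of materializing the mask):

```python
import torch

def block_diag_attn_mask(seq_lens):
    """Boolean mask that only allows attention within each packed sequence."""
    total = sum(seq_lens)
    mask = torch.zeros(total, total, dtype=torch.bool)
    offset = 0
    for n in seq_lens:
        mask[offset:offset + n, offset:offset + n] = True
        offset += n
    return mask

# Three sequences of lengths 3, 5 and 2 packed into one length-10 row.
mask = block_diag_attn_mask([3, 5, 2])  # usable as attn_mask in SDPA
```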

@poweiw
Collaborator

poweiw commented Nov 5, 2024

@zhenhuaw-me Can you take a look?

@poweiw poweiw added the triaged Issue has been triaged by maintainers label Nov 5, 2024
@vadimkantorov
Author

The relevant Triton Inference Server documentation on ragged batch support: https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/user_guide/ragged_batching.html

So it would be good to have end-to-end examples of optimized attention modules for transformer inference on such varlen sequences in TRT, starting from a PyTorch model, then exporting and configuring the TRT engine file, with trex visualizations.

@vadimkantorov
Author

vadimkantorov commented Nov 8, 2024

These kernels appear to be available in the older FasterTransformer (https://github.com/NVIDIA/FasterTransformer/blob/main/docs/bert_guide.md#model-architecture) or in https://github.com/bytedance/effective_transformer.

It would be good to upstream these "EffectiveTransformer kernels with TensorRT" given that FasterTransformer has reached end of life.
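For context, the core trick in those kernels is removing padding tokens before the transformer layers and restoring them afterwards; a rough PyTorch sketch of the idea (illustrative only, not the FasterTransformer API):

```python
import torch

def remove_padding(x, attn_mask):
    """x: (batch, max_len, hidden); attn_mask: (batch, max_len) bool.
    Returns packed non-padding tokens (total_tokens, hidden) and their flat indices."""
    idx = attn_mask.flatten().nonzero(as_tuple=True)[0]
    return x.reshape(-1, x.size(-1))[idx], idx

def restore_padding(packed, idx, batch, max_len):
    """Scatter packed tokens back into a zero-padded (batch, max_len, hidden) tensor."""
    out = packed.new_zeros(batch * max_len, packed.size(-1))
    out[idx] = packed
    return out.reshape(batch, max_len, -1)

# Example: batch of 2, max_len 4, with 3 and 2 valid tokens respectively.
x = torch.randn(2, 4, 8)
attn_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.bool)
packed, idx = remove_padding(x, attn_mask)    # shape (5, 8)
restored = restore_padding(packed, idx, 2, 4)  # back to (2, 4, 8)
```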
