
How to remove signal and wait layer in the engine? #4232

Open
lijinghaooo opened this issue Oct 31, 2024 · 5 comments
Labels
triaged Issue has been triaged by maintainers

Comments

@lijinghaooo

Description

I am using TensorRT-LLM to build an engine for a LLaMA classification model. I have two similar scripts to generate the engine: the first is my own raw script, and the second is based on the example/llama/build.sh script.
However, the engine from the second script is slower than the first, so I dumped the engine layers and found many signal and wait layers in the second engine (as the images below show). They seem to appear around type casts.
Any idea why these signal and wait layers are generated and how to work around them?

[Images: layer graphs dumped from the second engine, showing Signal and Wait layers inserted around type-cast nodes]
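For reference, the following is a minimal sketch of one way to dump an engine's per-layer names with the TensorRT Python EngineInspector and filter for signal/wait layers. The engine path is an assumption based on the build command later in this thread; detailed layer names are only available if the engine was built with detailed profiling verbosity, and any plugin libraries the engine uses must be loaded before deserialization.

import tensorrt as trt

# Sketch only: "engines/rank0.engine" is an assumed path from the trtllm-build
# command below, not confirmed by the reporter.
logger = trt.Logger(trt.Logger.INFO)
runtime = trt.Runtime(logger)
with open("engines/rank0.engine", "rb") as f:
    engine = runtime.deserialize_cuda_engine(f.read())

inspector = engine.create_engine_inspector()
info = inspector.get_engine_information(trt.LayerInformationFormat.ONELINE)

# Print only the layers whose names mention signal/wait.
for line in info.splitlines():
    if "signal" in line.lower() or "wait" in line.lower():
        print(line)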

Environment

TensorRT Version:
9.3.0
NVIDIA GPU:
L20
NVIDIA Driver Version:
535.161.08
CUDA Version:
12.2
CUDNN Version:
8.9.6

Operating System:
Ubuntu 22.04.3 LTS
Python Version (if applicable):
3.10.12
Tensorflow Version (if applicable):
no
PyTorch Version (if applicable):
2.2.2
Baremetal or Container (if so, version):
container nvidia/cuda:12.1.0-devel-ubuntu22.04

Relevant Files

Model link:

Steps To Reproduce

Commands or scripts:

Have you tried the latest release?: No; upgrading the TRT version would be too costly.

Can this model run on other frameworks? For example run ONNX model with ONNXRuntime (polygraphy run <model.onnx> --onnxrt):

@lix19937

lix19937 commented Nov 5, 2024

It seems that the signal and wait layers come from Myelin. Can you upload your two scripts?

@lijinghaooo
Author

@lix19937
The first script, which does not produce wait and signal layers, is here: scripts.tar.gz

The second script is the official example script (https://github.com/NVIDIA/TensorRT-LLM/blob/v0.9.0/examples/llama/convert_checkpoint.py); there are some differences in the output head because the model is a LLaMA SequenceClassification model. The following is the diff against the LLaMA model (a conceptual sketch of the head change follows the diff):

diff --git a/tensorrt_llm/models/llama/model.py b/tensorrt_llm/models/llama/model.py
index dc6ed439..1e620675 100644
--- a/tensorrt_llm/models/llama/model.py
+++ b/tensorrt_llm/models/llama/model.py
@@ -20,7 +20,7 @@ from ..._utils import pad_vocab_size
 from ...functional import Tensor, recv, send
 from ...layers import (MOE, Attention, AttentionMaskType, ColumnLinear,
                        Embedding, GatedMLP, MoeConfig, PositionEmbeddingType,
                        RmsNorm)
 from ...lora_manager import LoraBuildConfig, use_lora
 from ...mapping import Mapping
 from ...module import Module
@@ -50,7 +50,7 @@ class LLaMADecoderLayer(Module):
             num_kv_heads=config.num_key_value_heads,
             max_position_embeddings=config.max_position_embeddings,
             dtype=config.dtype,
-            attention_mask_type=AttentionMaskType.causal,
+            attention_mask_type=AttentionMaskType.bidirectional,
             bias=config.attn_bias,
             position_embedding_type=PositionEmbeddingType.rope_gpt_neox,
             rotary_embedding_base=config.rotary_base,
@@ -60,6 +60,21 @@ class LLaMADecoderLayer(Module):
             tp_rank=config.mapping.tp_rank,
             quant_mode=config.quant_mode)

         mlp_hidden_size = config.hidden_size * 4 if config.intermediate_size is None else config.intermediate_size

         ClsMLP = GatedMLP
@@ -104,7 +119,8 @@ class LLaMADecoderLayer(Module):
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)

         attention_output = self.attention(
             hidden_states,
             attention_mask=attention_mask,
             medusa_packed_mask=medusa_packed_mask,  # For Medusa support
@@ -113,6 +129,8 @@ class LLaMADecoderLayer(Module):
             kv_cache_params=kv_cache_params,
             attention_params=attention_params,
             lora_layer_params=lora_layer_params)


         if use_cache:
             attention_output, presents = attention_output
@@ -153,7 +171,6 @@ class LLaMAModel(Module):
     def forward(
             self,
             input_ids,
-            position_ids=None,
             use_cache=False,
             attention_mask=None,
             medusa_position_offsets=None,  # For Medusa support
@@ -205,9 +222,10 @@ class LLaMAForCausalLM(DecoderModelForCausalLM):
         transformer = LLaMAModel(config)
         vocab_size_padded = pad_vocab_size(config.vocab_size,
                                            config.mapping.tp_size)
+        label_num = 5
         if config.mapping.is_last_pp_rank():
             lm_head = ColumnLinear(config.hidden_size,
-                                   vocab_size_padded,
+                                   label_num,
                                    bias=False,
                                    dtype=config.dtype,
                                    tp_group=config.mapping.tp_group,

The following is the build command:

python3 convert_checkpoint.py \
  --model_dir pytorch_model/ \
  --output_dir checkpoints/ \
  --dtype float16

trtllm-build \
 --checkpoint_dir checkpoints/ \
 --output_dir engines/ \
 --gpt_attention_plugin disable \
 --gemm_plugin disable \
 --remove_input_padding disable \
 --paged_kv_cache disable \
 --max_batch_size 2 \
 --max_input_len 1300 \
 --max_output_len 1 \
 --gpus_per_node 1 \
 --profiling_verbosity

@poweiw
Collaborator

poweiw commented Nov 5, 2024

@zerollzeng Can you take a look?

@poweiw added the triaged label on Nov 5, 2024
@lix19937

lix19937 commented Nov 7, 2024

@lijinghaooo A SequenceClassification model needs the last results, which are then used as the current input. Can you modify the second script and use the same model to check?

@lijinghaooo
Author

lijinghaooo commented Nov 11, 2024

@lix19937 Thank you for your reply!
In this case, the first token is needed rather than the last token, so I added select and some other layers at the end of class DecoderModelForCausalLM, but the signal and wait layers still exist.

Are there any other insights on how to work around these layers? The change is shown in the diff below (a rough torch analogue of the new tail follows it).

--- a/tensorrt_llm/models/modeling_utils.py
+++ b/tensorrt_llm/models/modeling_utils.py
@@ -12,7 +12,7 @@ import torch
 from .._common import default_net
 from .._utils import (numpy_to_torch, release_gc, str_dtype_to_torch,
                       str_dtype_to_trt, trt_dtype_to_torch)
-from ..functional import PositionEmbeddingType, Tensor, gather_last_token_logits
+from ..functional import PositionEmbeddingType, Tensor, gather_first_token_logits, softmax, select
 from ..layers import (AttentionParams, Embedding, FusedGatedMLP, GatedMLP,
                       KeyValueCacheParams, LoraParams, PromptTuningEmbedding)
 from ..layers.attention import Attention, BertAttention
@@ -490,35 +490,26 @@ class PretrainedModel(Module,
             max_draft_len=max_draft_len,
             lora_target_modules=lora_target_modules,
             multiple_profiles=multiple_profiles,
-            streamingllm=streamingllm)
+            streamingllm=streamingllm,
+            use_cache=use_cache)

         result = {
             'input_ids':
             model_inputs['input_ids'],
-            'position_ids':
-            model_inputs['position_ids'],
             'use_cache':
-            True,
+            use_cache,
             'last_token_ids':
             model_inputs['last_token_ids'],
             'attention_mask':
             model_inputs['attention_mask'],
             'kv_cache_params':
             KeyValueCacheParams(
-                past_key_value=model_inputs['past_key_value'],
-                host_past_key_value_lengths=model_inputs[
-                    'host_past_key_value_lengths'],
                 host_max_attention_window_sizes=model_inputs[
                     'host_max_attention_window_sizes'],
                 host_sink_token_length=model_inputs['host_sink_token_length'],
-                kv_cache_block_pointers=model_inputs['kv_cache_block_pointers'],
-                host_kv_cache_block_pointers=model_inputs[
-                    'host_kv_cache_block_pointers'],
-                cache_indirection=model_inputs['cache_indirection'],
             ),
             'attention_params':
             AttentionParams(
-                sequence_length=model_inputs['sequence_length'],
                 context_lengths=model_inputs['context_lengths'],
                 host_context_lengths=model_inputs['host_context_lengths'],
                 max_context_length=max_input_len,
@@ -590,7 +581,6 @@ class DecoderModelForCausalLM(PretrainedModel):

     def forward(self,
                 input_ids: Tensor,
-                position_ids=None,
                 use_cache=False,
                 last_token_ids=None,
                 attention_mask=None,
@@ -605,7 +595,6 @@ class DecoderModelForCausalLM(PretrainedModel):
                 medusa_packed_mask=None):
         kwargs = {
             'input_ids': input_ids,
-            'position_ids': position_ids,
             'use_cache': use_cache,
             'attention_mask': attention_mask,
             'kv_cache_params': kv_cache_params,
@@ -633,12 +622,17 @@ class DecoderModelForCausalLM(PretrainedModel):
             hidden_states, presents = hidden_states

         if self.config.mapping.is_last_pp_rank():
-            hidden_states = gather_last_token_logits(
-                hidden_states, last_token_ids,
-                default_net().plugin_config.remove_input_padding)
+            if default_net().plugin_config.remove_input_padding:
+                hidden_states = gather_first_token_logits(
+                    hidden_states, last_token_ids,
+                    default_net().plugin_config.remove_input_padding)
+            else:
+                hidden_states = select(hidden_states, 1, 0)

             # [batch_size, hidden_size] -> [batch_size, vocab_size]
             lm_logits = self.lm_head(hidden_states)
+            lm_logits = softmax(lm_logits, dim = -1)
+            lm_logits = select(lm_logits, 1, 0)
             lm_logits.mark_output('logits', self.config.logits_dtype)
         else:
             hidden_states.mark_output('hidden_states_output', self.config.dtype)
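Below is a rough torch analogue of the tail of the modified forward() above, assuming padded inputs (remove_input_padding disabled) so hidden_states has shape [batch, seq_len, hidden]. The batch, sequence, hidden, and label sizes are illustrative only.

import torch

batch, seq_len, hidden_size, label_num = 2, 1300, 4096, 5    # illustrative sizes
hidden_states = torch.randn(batch, seq_len, hidden_size)
lm_head = torch.nn.Linear(hidden_size, label_num, bias=False)

pooled = torch.select(hidden_states, 1, 0)    # select(hidden_states, 1, 0): first token -> [batch, hidden]
lm_logits = lm_head(pooled)                   # [batch, label_num]
lm_logits = torch.softmax(lm_logits, dim=-1)  # softmax over the 5 labels
out = torch.select(lm_logits, 1, 0)           # select(lm_logits, 1, 0): class-0 probability -> [batch]
print(out.shape)                              # torch.Size([2])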
