diff --git a/docs/source/nlp/quantization.rst b/docs/source/nlp/quantization.rst index 747938bebedd..500c37dcfb26 100644 --- a/docs/source/nlp/quantization.rst +++ b/docs/source/nlp/quantization.rst @@ -103,7 +103,7 @@ The TensorRT-LLM engine can be conveniently built and run using ``TensorRTLLM`` .. code-block:: python - from nemo.export import TensorRTLLM + from nemo.export.tensorrt_llm import TensorRTLLM trt_llm_exporter = TensorRTLLM(model_dir="/path/to/trt_llm_engine_folder") diff --git a/nemo/deploy/deploy_pytriton.py b/nemo/deploy/deploy_pytriton.py index 25e09cf3eacc..1e1333f03b55 100644 --- a/nemo/deploy/deploy_pytriton.py +++ b/nemo/deploy/deploy_pytriton.py @@ -29,7 +29,7 @@ class DeployPyTriton(DeployBase): Example: from nemo.deploy import DeployPyTriton, NemoQueryLLM - from nemo.export import TensorRTLLM + from nemo.export.tensorrt_llm import TensorRTLLM trt_llm_exporter = TensorRTLLM(model_dir="/path/for/model/files") trt_llm_exporter.export( diff --git a/nemo/deploy/nlp/__init__.py b/nemo/deploy/nlp/__init__.py index ae4db1ce6f2a..a2110931c6df 100644 --- a/nemo/deploy/nlp/__init__.py +++ b/nemo/deploy/nlp/__init__.py @@ -19,4 +19,8 @@ except Exception: use_query_llm = False -from nemo.deploy.nlp.megatronllm_deployable import MegatronLLMDeployable +use_megatron_llm = True +try: + from nemo.deploy.nlp.megatronllm_deployable import MegatronLLMDeployable +except Exception: + use_megatron_llm = False diff --git a/nemo/export/__init__.py b/nemo/export/__init__.py index 55712d98852c..d9155f923f18 100644 --- a/nemo/export/__init__.py +++ b/nemo/export/__init__.py @@ -11,15 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - -import logging - -LOGGER = logging.getLogger("NeMo") - - -use_TensorRTLLM = True -try: - from nemo.export.tensorrt_llm import TensorRTLLM -except Exception as e: - LOGGER.warning("TensorRTLLM could not be imported.") diff --git a/nemo/export/trt_llm/nemo_ckpt_loader/sentencepiece_tokenizer.py b/nemo/export/sentencepiece_tokenizer.py similarity index 93% rename from nemo/export/trt_llm/nemo_ckpt_loader/sentencepiece_tokenizer.py rename to nemo/export/sentencepiece_tokenizer.py index 1f86c5887a5e..e47b1c665af5 100644 --- a/nemo/export/trt_llm/nemo_ckpt_loader/sentencepiece_tokenizer.py +++ b/nemo/export/sentencepiece_tokenizer.py @@ -22,7 +22,7 @@ class SentencePieceTokenizer: """ - Sentencepiecetokenizer https://github.com/google/sentencepiece + SentencePieceTokenizer https://github.com/google/sentencepiece Args: model_path: path to sentence piece tokenizer model. 
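The import-path changes above mean downstream code now pulls ``TensorRTLLM`` from ``nemo.export.tensorrt_llm`` and can no longer rely on a re-export from ``nemo/export/__init__.py``. A minimal sketch of the guarded-import pattern this change standardizes on (mirroring ``scripts/deploy/nlp/deploy_triton.py`` later in this diff; the surrounding script is illustrative):

.. code-block:: python

    import logging

    LOGGER = logging.getLogger("NeMo")

    # TensorRT-LLM is treated as an optional dependency: record availability
    # instead of failing at import time.
    trt_llm_supported = True
    try:
        from nemo.export.tensorrt_llm import TensorRTLLM
    except Exception as e:
        LOGGER.warning(f"Cannot import the TensorRTLLM exporter. {type(e).__name__}: {e}")
        trt_llm_supported = False

    if trt_llm_supported:
        trt_llm_exporter = TensorRTLLM(model_dir="/path/to/trt_llm_engine_folder")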
@@ -247,3 +247,21 @@ def vocab(self): for i in range(self.vocab_size - self.original_vocab_size) ] return main_vocab + special_tokens + + ### Below are a few methods that mimic transformers.PreTrainedTokenizer for vLLM + + def convert_ids_to_tokens(self, ids, skip_special_tokens: bool = False): + return self.ids_to_tokens(ids) # TODO: support skip_special_tokens + + def convert_tokens_to_string(self, tokens: List[str]): + return self.tokens_to_text(tokens) + + def __len__(self): + return self.vocab_size + + @property + def is_fast(self): + return True + + def get_added_vocab(self): + return None diff --git a/nemo/export/tensorrt_llm.py b/nemo/export/tensorrt_llm.py index 7cc92f0ca588..d03617fc2c3b 100644 --- a/nemo/export/tensorrt_llm.py +++ b/nemo/export/tensorrt_llm.py @@ -68,7 +68,7 @@ class TensorRTLLM(ITritonDeployable): Exports nemo checkpoints to TensorRT-LLM and run fast inference. Example: - from nemo.export import TensorRTLLM + from nemo.export.tensorrt_llm import TensorRTLLM trt_llm_exporter = TensorRTLLM(model_dir="/path/for/model/files") trt_llm_exporter.export( diff --git a/nemo/export/trt_llm/nemo_ckpt_loader/__init__.py b/nemo/export/trt_llm/nemo_ckpt_loader/__init__.py index c9c6f65d27e0..d9155f923f18 100644 --- a/nemo/export/trt_llm/nemo_ckpt_loader/__init__.py +++ b/nemo/export/trt_llm/nemo_ckpt_loader/__init__.py @@ -11,6 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - -from nemo.export.trt_llm.nemo_ckpt_loader.sentencepiece_tokenizer import SentencePieceTokenizer diff --git a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py index 09eae628999a..1d473f497f51 100644 --- a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py +++ b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py @@ -28,8 +28,8 @@ from torch.distributed.checkpoint import FileSystemReader from transformers import AutoTokenizer, PreTrainedTokenizer +from nemo.export.sentencepiece_tokenizer import SentencePieceTokenizer from nemo.export.tarutils import TarPath, ZarrPathStore -from nemo.export.trt_llm.nemo_ckpt_loader.sentencepiece_tokenizer import SentencePieceTokenizer LOGGER = logging.getLogger("NeMo") diff --git a/nemo/export/trt_llm/qnemo/tokenizer_utils.py b/nemo/export/trt_llm/qnemo/tokenizer_utils.py index 4b0775a0aa2a..c3dd5c2befc9 100644 --- a/nemo/export/trt_llm/qnemo/tokenizer_utils.py +++ b/nemo/export/trt_llm/qnemo/tokenizer_utils.py @@ -17,7 +17,7 @@ from omegaconf import OmegaConf from transformers import AutoTokenizer -from nemo.export.trt_llm.nemo_ckpt_loader.sentencepiece_tokenizer import SentencePieceTokenizer +from nemo.export.sentencepiece_tokenizer import SentencePieceTokenizer # TODO: use get_nmt_tokenizer helper below to instantiate tokenizer once environment / dependencies get stable # from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer diff --git a/nemo/export/vllm/__init__.py b/nemo/export/vllm/__init__.py new file mode 100644 index 000000000000..d9155f923f18 --- /dev/null +++ b/nemo/export/vllm/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/export/vllm/engine.py b/nemo/export/vllm/engine.py new file mode 100644 index 000000000000..0a3600e7b1eb --- /dev/null +++ b/nemo/export/vllm/engine.py @@ -0,0 +1,101 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from pathlib import Path + +from vllm import LLMEngine +from vllm.transformers_utils.tokenizer_group.tokenizer_group import TokenizerGroup + +from nemo.export.sentencepiece_tokenizer import SentencePieceTokenizer +from nemo.export.tarutils import TarPath +from nemo.export.vllm.tokenizer_group import NemoTokenizerGroup + +LOGGER = logging.getLogger("NeMo") + + +class NemoLLMEngine(LLMEngine): + """ + Overrides some functionality from vllm.LLMEngine to use our custom tokenizer + instead of one from Transformers. + """ + + def _init_tokenizer(self, **tokenizer_init_kwargs): + # Find the tokenizer file name in the Nemo checkpoint config + tokenizer_config = self.model_config.nemo_model_config.get('tokenizer', {}) + tokenizer_model = tokenizer_config.get('model', tokenizer_config.get('tokenizer_model', None)) + + # If there is no tokenizer file specified but there's a reference to an HF tokenizer, use that + if tokenizer_model is None and tokenizer_config.get('library') == 'huggingface': + tokenizer_type = tokenizer_config.get('type') + if tokenizer_type is not None: + tokenizer_group = TokenizerGroup( + tokenizer_id=tokenizer_type, + enable_lora=bool(self.lora_config), + max_num_seqs=self.scheduler_config.max_num_seqs, + max_input_length=None, + ) + + # Update the HF config fields that come from the tokenizer in NeMo + self.model_config.hf_config.vocab_size = tokenizer_group.tokenizer.vocab_size + self.model_config.hf_config.bos_token_id = tokenizer_group.tokenizer.bos_token_id + self.model_config.hf_config.eos_token_id = tokenizer_group.tokenizer.eos_token_id + self.model_config.hf_config.pad_token_id = tokenizer_group.tokenizer.pad_token_id + + return tokenizer_group + + # Open the checkpoint archive + with TarPath(self.model_config.nemo_checkpoint) as archive: + tokenizer_model_file = None + if isinstance(tokenizer_model, str) and tokenizer_model.startswith('nemo:'): + tokenizer_model = tokenizer_model[len('nemo:') :] + tokenizer_model_file = archive / tokenizer_model + if not tokenizer_model_file.exists(): + LOGGER.warn( + f'Tokenizer model file {tokenizer_model} specified in the model_config does not ' + + 'exist in the checkpoint.' 
+                    )
+                    tokenizer_model_file = None
+
+            if tokenizer_model_file is None:
+                for path in archive.glob('*tokenizer*.model'):
+                    LOGGER.info(f'Found tokenizer model file {path}.')
+                    tokenizer_model_file = path
+                    break
+
+            if tokenizer_model_file is None:
+                raise RuntimeError('No tokenizer model file found, aborting.')
+
+            # Extract the tokenizer model file into the model directory,
+            # because sentencepiece cannot load it directly from TarPath.
+            extracted_tokenizer_model = Path(self.model_config.model) / 'tokenizer.model'
+            with tokenizer_model_file.open('rb') as infile:
+                with extracted_tokenizer_model.open('wb') as outfile:
+                    outfile.write(infile.read())
+
+        # Construct the tokenizer object and wrapper
+        tokenizer = SentencePieceTokenizer(str(extracted_tokenizer_model))
+
+        # Determine if the model needs a bos token (which is not stored in Nemo checkpoints)
+        add_bos_token = self.model_config.model_converter.requires_bos_token()
+
+        tokenizer_group = NemoTokenizerGroup(tokenizer, add_bos_token=add_bos_token)
+
+        # Update the HF config fields that come from the tokenizer in NeMo
+        self.model_config.hf_config.vocab_size = tokenizer.vocab_size
+        self.model_config.hf_config.bos_token_id = tokenizer.bos_token_id
+        self.model_config.hf_config.eos_token_id = tokenizer.eos_token_id
+        self.model_config.hf_config.pad_token_id = tokenizer.pad_id
+
+        return tokenizer_group
diff --git a/nemo/export/vllm/model_config.py b/nemo/export/vllm/model_config.py
new file mode 100644
index 000000000000..0a98a9180c1d
--- /dev/null
+++ b/nemo/export/vllm/model_config.py
@@ -0,0 +1,135 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Union
+
+import torch
+import yaml
+from transformers import AutoConfig
+from vllm.config import ModelConfig, _get_and_verify_dtype, _get_and_verify_max_len
+from vllm.transformers_utils.config import get_hf_text_config
+
+from nemo.export.tarutils import TarPath
+from nemo.export.vllm.model_converters import get_model_converter
+
+
+class NemoModelConfig(ModelConfig):
+    """
+    This class pretends to be a vllm.config.ModelConfig (with extra fields) but skips
+    some of its initialization code, and initializes the configuration from a Nemo checkpoint instead.
+ """ + + def __init__( + self, + nemo_checkpoint: str, + model_dir: str, + model_type: str, + tokenizer_mode: str, + dtype: Union[str, torch.dtype], + seed: int, + revision: Optional[str] = None, + code_revision: Optional[str] = None, + rope_scaling: Optional[dict] = None, + rope_theta: Optional[float] = None, + tokenizer_revision: Optional[str] = None, + max_model_len: Optional[int] = None, + quantization: Optional[str] = None, + quantization_param_path: Optional[str] = None, + enforce_eager: bool = False, + max_seq_len_to_capture: Optional[int] = None, + max_logprobs: int = 5, + disable_sliding_window: bool = False, + ) -> None: + # Don't call ModelConfig.__init__ because we don't want it to call + # transformers.AutoConfig.from_pretrained(...) + + # TODO: Do something about vLLM's call to _load_generation_config_dict in LLMEngine.__init__ + # because it calls transformers.GenerationConfig.from_pretrained(...), which tries to download things + + self.nemo_checkpoint = nemo_checkpoint + self.model = model_dir + self.model_type = model_type + self.tokenizer = None + self.tokenizer_mode = tokenizer_mode + self.skip_tokenizer_init = False + self.trust_remote_code = False + self.seed = seed + self.revision = revision + self.code_revision = code_revision + self.rope_scaling = rope_scaling + self.rope_theta = rope_theta + self.tokenizer_revision = tokenizer_revision + self.quantization = quantization + self.quantization_param_path = quantization_param_path + self.enforce_eager = enforce_eager + self.max_seq_len_to_capture = max_seq_len_to_capture + self.max_logprobs = max_logprobs + self.disable_sliding_window = disable_sliding_window + self.served_model_name = nemo_checkpoint + + self.model_converter = get_model_converter(model_type) + if self.model_converter is None: + raise RuntimeError(f'Unknown model type "{model_type}"') + + hf_to_nemo_dict = { + 'hidden_size': 'hidden_size', + 'intermediate_size': 'ffn_hidden_size', + 'num_hidden_layers': 'num_layers', + 'num_attention_heads': 'num_attention_heads', + 'num_key_value_heads': 'num_query_groups', + # 'hidden_act': 'activation', ## <- vLLM has good defaults for the models, nemo values are wrong + 'max_position_embeddings': ['max_position_embeddings', 'encoder_seq_length'], + 'rms_norm_eps': 'layernorm_epsilon', + 'attention_dropout': 'attention_dropout', + 'initializer_range': 'init_method_std', + 'norm_epsilon': 'layernorm_epsilon', + 'rope_theta': 'rotary_base', + 'use_bias': 'bias', + } + + with TarPath(nemo_checkpoint) as archive: + with (archive / "model_config.yaml").open("r") as model_config_file: + self.nemo_model_config = yaml.load(model_config_file, Loader=yaml.SafeLoader) + + hf_args = {} + for hf_arg, nemo_arg in hf_to_nemo_dict.items(): + if not isinstance(nemo_arg, list): + nemo_arg = [nemo_arg] + + for nemo_arg_option in nemo_arg: + value = self.nemo_model_config.get(nemo_arg_option) + if value is not None: + hf_args[hf_arg] = value + break + + self.model_converter.convert_config(self.nemo_model_config, hf_args) + + self.hf_config = AutoConfig.for_model(model_type, **hf_args) + + self.hf_config.architectures = [self.model_converter.get_architecture()] + if self.rope_scaling is not None: + self.hf_config['rope_scaling'] = rope_scaling + + self.hf_text_config = get_hf_text_config(self.hf_config) + self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) + self.max_model_len = _get_and_verify_max_len( + hf_config=self.hf_text_config, + max_model_len=max_model_len, + disable_sliding_window=self.disable_sliding_window, + 
sliding_window_len=self.get_hf_config_sliding_window(), + ) + self._verify_tokenizer_mode() + self._verify_embedding_mode() + self._verify_quantization() + self._verify_cuda_graph() diff --git a/nemo/export/vllm/model_converters.py b/nemo/export/vllm/model_converters.py new file mode 100644 index 000000000000..595ceecf0b18 --- /dev/null +++ b/nemo/export/vllm/model_converters.py @@ -0,0 +1,410 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from typing import Optional, Sequence, Tuple + +import torch + + +class ModelConverter(ABC): + """ + Abstract class that defines the interface for a converter that implements model-specific conversion functions + for deploying NeMo checkpoints on vLLM. + """ + + def __init__(self, model_type: str): + self.model_type = model_type + + @abstractmethod + def get_architecture(self) -> Optional[str]: + """ + Returns the HF architecture name for the current model, such as 'LlamaForCausalLM'. + """ + pass + + def convert_config(self, nemo_model_config: dict, hf_config: dict) -> None: + """ + Implements any custom HF configuration adjustments in the 'hf_config' dict that are necessary + for this model after the common translation takes place in NemoModelConfig's constructor. + """ + pass + + @abstractmethod + def convert_weights(self, nemo_model_config: dict, state_dict: dict) -> Sequence[Tuple[str, torch.tensor]]: + """ + Returns or yields a sequence of (name, tensor) tuples that contain model weights in the HF format. + """ + pass + + def requires_bos_token(self) -> bool: + """ + Returns True if the model requires a 'bos' token to be used at the beginning of the input sequence. + NeMo checkpoints do not store this information. 
+ """ + return False + + +class LlamaConverter(ModelConverter): + + def get_architecture(self): + if self.model_type == 'llama': + return 'LlamaForCausalLM' + if self.model_type == 'mistral': + return 'MistralForCausalLM' + return None + + def convert_weights(self, nemo_model_config, state_dict): + hidden_size = nemo_model_config["hidden_size"] + head_num = nemo_model_config["num_attention_heads"] + num_query_groups = nemo_model_config["num_query_groups"] + num_layers = nemo_model_config["num_layers"] + head_size = hidden_size // head_num + heads_per_group = head_num // num_query_groups + qkv_total_dim = head_num + 2 * num_query_groups + + yield ('model.embed_tokens.weight', state_dict['model.embedding.word_embeddings.weight']) + yield ('model.norm.weight', state_dict['model.decoder.final_layernorm.weight']) + yield ('lm_head.weight', state_dict['model.output_layer.weight']) + + for layer in range(int(num_layers)): + qkv_weights = state_dict['model.decoder.layers.self_attention.linear_qkv.weight'][layer] + qkv_weights = qkv_weights.reshape([qkv_total_dim, head_size, hidden_size]) + + q_slice = torch.cat( + [ + torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group) + for i in range(num_query_groups) + ] + ) + k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2)) + v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2)) + + for name, slice in [('q_proj', q_slice), ('k_proj', k_slice), ('v_proj', v_slice)]: + weight_name = f'model.layers.{layer}.self_attn.{name}.weight' + yield (weight_name, qkv_weights[slice].reshape(-1, hidden_size)) + + linear_proj_weight = state_dict['model.decoder.layers.self_attention.linear_proj.weight'][layer] + yield (f'model.layers.{layer}.self_attn.o_proj.weight', linear_proj_weight) + + gate_proj_weight, up_proj_weight = torch.chunk( + state_dict['model.decoder.layers.mlp.linear_fc1.weight'][layer], 2, dim=0 + ) + yield (f'model.layers.{layer}.mlp.gate_proj.weight', gate_proj_weight) + yield (f'model.layers.{layer}.mlp.up_proj.weight', up_proj_weight) + + mlp_up_weight = state_dict['model.decoder.layers.mlp.linear_fc2.weight'][layer] + yield (f'model.layers.{layer}.mlp.down_proj.weight', mlp_up_weight) + + input_layernorm_weight = state_dict['model.decoder.layers.self_attention.linear_qkv.layer_norm_weight'][ + layer + ] + yield (f'model.layers.{layer}.input_layernorm.weight', input_layernorm_weight) + + post_attn_layernorm_weight = state_dict['model.decoder.layers.mlp.linear_fc1.layer_norm_weight'][layer] + yield (f'model.layers.{layer}.post_attention_layernorm.weight', post_attn_layernorm_weight) + + def requires_bos_token(self): + return True + + +class MixtralConverter(ModelConverter): + + def get_architecture(self): + if self.model_type == 'mixtral': + return 'MixtralForCausalLM' + return None + + def convert_weights(self, nemo_model_config, state_dict): + hidden_size = nemo_model_config["hidden_size"] + head_num = nemo_model_config["num_attention_heads"] + num_query_groups = nemo_model_config["num_query_groups"] + num_layers = nemo_model_config["num_layers"] + num_moe_experts = nemo_model_config["num_moe_experts"] + head_size = hidden_size // head_num + heads_per_group = head_num // num_query_groups + qkv_total_dim = head_num + 2 * num_query_groups + + yield ('model.embed_tokens.weight', state_dict['model.embedding.word_embeddings.weight']) + yield ('model.norm.weight', state_dict['model.decoder.final_layernorm.weight']) + yield ('lm_head.weight', 
state_dict['model.output_layer.weight']) + + for layer in range(int(num_layers)): + qkv_weights = state_dict['model.decoder.layers.self_attention.linear_qkv.weight'][layer] + qkv_weights = qkv_weights.reshape([qkv_total_dim, head_size, hidden_size]) + + q_slice = torch.cat( + [ + torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group) + for i in range(num_query_groups) + ] + ) + k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2)) + v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2)) + + for name, slice in [('q_proj', q_slice), ('k_proj', k_slice), ('v_proj', v_slice)]: + weight_name = f'model.layers.{layer}.self_attn.{name}.weight' + yield (weight_name, qkv_weights[slice].reshape(-1, hidden_size)) + + linear_proj_weight = state_dict['model.decoder.layers.self_attention.linear_proj.weight'][layer] + yield (f'model.layers.{layer}.self_attn.o_proj.weight', linear_proj_weight) + + mlp_router_weight = state_dict['model.decoder.layers.mlp.router.weight'][layer] + yield (f'model.layers.{layer}.block_sparse_moe.gate.weight', mlp_router_weight) + + for expert in range(num_moe_experts): + linear_fc1_weight = state_dict['model.decoder.layers.mlp.experts.experts.linear_fc1.weight'][layer][ + expert + ] + gate_proj_weight, up_proj_weight = torch.chunk(linear_fc1_weight, 2, dim=0) + yield (f'model.layers.{layer}.block_sparse_moe.experts.{expert}.w1.weight', gate_proj_weight) + yield (f'model.layers.{layer}.block_sparse_moe.experts.{expert}.w3.weight', up_proj_weight) + + linear_fc2_weight = state_dict['model.decoder.layers.mlp.experts.experts.linear_fc2.weight'][layer][ + expert + ] + yield (f'model.layers.{layer}.block_sparse_moe.experts.{expert}.w2.weight', linear_fc2_weight) + + input_layernorm_weight = state_dict['model.decoder.layers.self_attention.linear_qkv.layer_norm_weight'][ + layer + ] + yield (f'model.layers.{layer}.input_layernorm.weight', input_layernorm_weight) + + post_attn_layernorm_weight = state_dict['model.decoder.layers.pre_mlp_layernorm.weight'][layer] + yield (f'model.layers.{layer}.post_attention_layernorm.weight', post_attn_layernorm_weight) + + def requires_bos_token(self): + return True + + +class GemmaConverter(ModelConverter): + + def get_architecture(self): + if self.model_type == 'gemma': + return 'GemmaForCausalLM' + return None + + def convert_weights(self, nemo_model_config, state_dict): + num_layers = nemo_model_config["num_layers"] + num_query_groups = nemo_model_config["num_query_groups"] + head_num = nemo_model_config["num_attention_heads"] + head_size = nemo_model_config["kv_channels"] + hidden_size = nemo_model_config["hidden_size"] + heads_per_group = head_num // num_query_groups + + yield ('model.embed_tokens.weight', state_dict['model.embedding.word_embeddings.weight']) + + final_layernorm_weight = state_dict['model.decoder.final_layernorm.weight'] + final_layernorm_weight -= 1.0 + yield ('model.norm.weight', final_layernorm_weight) + + for layer in range(int(num_layers)): + input_layernorm_weight = state_dict['model.decoder.layers.self_attention.linear_qkv.layer_norm_weight'][ + layer + ] + input_layernorm_weight -= 1.0 + yield (f'model.layers.{layer}.input_layernorm.weight', input_layernorm_weight) + + post_attention_layernorm_weight = state_dict['model.decoder.layers.mlp.linear_fc1.layer_norm_weight'][ + layer + ] + post_attention_layernorm_weight -= 1.0 + yield (f'model.layers.{layer}.post_attention_layernorm.weight', post_attention_layernorm_weight) + + 
gate_up_combined_weight = state_dict['model.decoder.layers.mlp.linear_fc1.weight'][layer] + gate_size = gate_up_combined_weight.shape[0] // 2 + yield (f'model.layers.{layer}.mlp.gate_proj.weight', gate_up_combined_weight[:gate_size, :]) + yield (f'model.layers.{layer}.mlp.up_proj.weight', gate_up_combined_weight[gate_size:, :]) + + down_proj_weight = state_dict['model.decoder.layers.mlp.linear_fc2.weight'][layer] + yield (f'model.layers.{layer}.mlp.down_proj.weight', down_proj_weight) + + self_attn_o_proj_weight = state_dict['model.decoder.layers.self_attention.linear_proj.weight'][layer] + yield (f'model.layers.{layer}.self_attn.o_proj.weight', self_attn_o_proj_weight) + + qkv_weight = state_dict['model.decoder.layers.self_attention.linear_qkv.weight'][layer] + qkv_intermediate_size = head_num + 2 * num_query_groups + qkv_weight = qkv_weight.reshape(qkv_intermediate_size, head_size, hidden_size) + + q_weight = torch.empty((head_num, head_size, hidden_size), dtype=qkv_weight.dtype) + k_weight = torch.empty((num_query_groups, head_size, hidden_size), dtype=qkv_weight.dtype) + v_weight = torch.empty((num_query_groups, head_size, hidden_size), dtype=qkv_weight.dtype) + + ptr = 0 + for i in range(num_query_groups): + q_weight[i * heads_per_group : (i + 1) * heads_per_group, :, :] = qkv_weight[ + ptr : ptr + heads_per_group, :: + ] + ptr += heads_per_group + k_weight[i : i + 1, :, :] = qkv_weight[ptr : ptr + 1, :, :] + ptr += 1 + v_weight[i : i + 1, :, :] = qkv_weight[ptr : ptr + 1, :, :] + ptr += 1 + assert ptr == qkv_intermediate_size + + q_weight = q_weight.reshape(head_num * head_size, hidden_size) + k_weight = k_weight.reshape(num_query_groups * head_size, hidden_size) + v_weight = v_weight.reshape(num_query_groups * head_size, hidden_size) + + yield (f'model.layers.{layer}.self_attn.q_proj.weight', q_weight) + yield (f'model.layers.{layer}.self_attn.k_proj.weight', k_weight) + yield (f'model.layers.{layer}.self_attn.v_proj.weight', v_weight) + + def requires_bos_token(self): + return True + + +class Starcoder2Converter(ModelConverter): + + def get_architecture(self): + if self.model_type == 'starcoder2': + return 'Starcoder2ForCausalLM' + return None + + def convert_config(self, nemo_model_config, hf_config): + window_sizes = nemo_model_config.get('window_size') + if window_sizes is not None: + hf_config['sliding_window'] = window_sizes[0] + + # 'tie_word_embeddings = False' means that there is a 'lm_head.weight' tensor. + # This converter assumes that it's always there. + # If there is a version of starcoder2 where it's not there, we'll need to copy + # 'model.embed_tokens.weight' into 'lm_head.weight' and still set 'tie_word_embeddings = False' + # because at this point we don't know if the weight is there or not, and this configuration + # is not stored in NeMo checkpoints. 
+ hf_config['tie_word_embeddings'] = False + + def convert_weights(self, nemo_model_config, state_dict): + num_layers = nemo_model_config["num_layers"] + num_query_groups = nemo_model_config["num_query_groups"] + head_num = nemo_model_config["num_attention_heads"] + hidden_size = nemo_model_config["hidden_size"] + head_size = hidden_size // head_num + heads_per_group = head_num // num_query_groups + qkv_total_dim = head_num + 2 * num_query_groups + has_bias = nemo_model_config["bias"] + + yield ('model.embed_tokens.weight', state_dict['model.embedding.word_embeddings.weight']) + + yield ('model.norm.weight', state_dict['model.decoder.final_layernorm.weight']) + if has_bias: + yield ('model.norm.bias', state_dict['model.decoder.final_layernorm.bias']) + + yield ('lm_head.weight', state_dict['model.output_layer.weight']) + + for layer in range(int(num_layers)): + # q,k,v + qkv_weights = state_dict['model.decoder.layers.self_attention.linear_qkv.weight'][layer] + qkv_weights = qkv_weights.reshape([qkv_total_dim, head_size, hidden_size]) + if has_bias: + qkv_bias = state_dict['model.decoder.layers.self_attention.linear_qkv.bias'][layer] + qkv_bias = qkv_bias.reshape([qkv_total_dim, head_size]) + + q_slice = torch.cat( + [ + torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group) + for i in range(num_query_groups) + ] + ) + k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2)) + v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2)) + + for name, slice in [('q_proj', q_slice), ('k_proj', k_slice), ('v_proj', v_slice)]: + qkv_weights_slice = qkv_weights[slice].reshape(-1, hidden_size) + yield (f'model.layers.{layer}.self_attn.{name}.weight', qkv_weights_slice) + if has_bias: + qkv_bias_slice = qkv_bias[slice].reshape(-1) + yield (f'model.layers.{layer}.self_attn.{name}.bias', qkv_bias_slice) + + # Attention dense + yield ( + f'model.layers.{layer}.self_attn.o_proj.weight', + state_dict[f'model.decoder.layers.self_attention.linear_proj.weight'][layer], + ) + if has_bias: + yield ( + f'model.layers.{layer}.self_attn.o_proj.bias', + state_dict['model.decoder.layers.self_attention.linear_proj.bias'][layer], + ) + + # MLP FC1 + yield ( + f'model.layers.{layer}.mlp.c_fc.weight', + state_dict['model.decoder.layers.mlp.linear_fc1.weight'][layer], + ) + if has_bias: + yield ( + f'model.layers.{layer}.mlp.c_fc.bias', + state_dict['model.decoder.layers.mlp.linear_fc1.bias'][layer], + ) + + # MLP FC2 + yield ( + f'model.layers.{layer}.mlp.c_proj.weight', + state_dict['model.decoder.layers.mlp.linear_fc2.weight'][layer], + ) + if has_bias: + yield ( + f'model.layers.{layer}.mlp.c_proj.bias', + state_dict['model.decoder.layers.mlp.linear_fc2.bias'][layer], + ) + + # Input LayerNorm + yield ( + f'model.layers.{layer}.input_layernorm.weight', + state_dict['model.decoder.layers.self_attention.linear_qkv.layer_norm_weight'][layer], + ) + if has_bias: + yield ( + f'model.layers.{layer}.input_layernorm.bias', + state_dict['model.decoder.layers.self_attention.linear_qkv.layer_norm_bias'][layer], + ) + + # Post-attention LayerNorm + yield ( + f'model.layers.{layer}.post_attention_layernorm.weight', + state_dict['model.decoder.layers.mlp.linear_fc1.layer_norm_weight'][layer], + ) + if has_bias: + yield ( + f'model.layers.{layer}.post_attention_layernorm.bias', + state_dict['model.decoder.layers.mlp.linear_fc1.layer_norm_bias'][layer], + ) + + +_MODEL_CONVERTERS = { + 'llama': LlamaConverter, + 'mistral': LlamaConverter, + 'mixtral': 
MixtralConverter, + 'gemma': GemmaConverter, + 'starcoder2': Starcoder2Converter, +} + + +def register_model_converter(model_type, cls): + """ + Establishes a mapping from short model type to a class that converts the model from Nemo format + to a vLLM compatible format. + """ + _MODEL_CONVERTERS[model_type] = cls + + +def get_model_converter(model_type) -> ModelConverter: + """ + Returns an instance of the the model conversion class for the given model type, or None. + """ + cls = _MODEL_CONVERTERS.get(model_type, None) + if cls is None: + return None + return cls(model_type) diff --git a/nemo/export/vllm/model_loader.py b/nemo/export/vllm/model_loader.py new file mode 100644 index 000000000000..e7f3f1d1569f --- /dev/null +++ b/nemo/export/vllm/model_loader.py @@ -0,0 +1,120 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import logging +import os.path +from typing import Optional + +import numpy +import safetensors.torch +import tensorstore # needed to register 'bfloat16' dtype with numpy for zarr compatibility +import torch +import zarr +from vllm.config import CacheConfig, DeviceConfig, LoRAConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig +from vllm.model_executor.model_loader.loader import BaseModelLoader, _initialize_model +from vllm.model_executor.model_loader.utils import set_default_torch_dtype + +from nemo.export.tarutils import TarPath, ZarrPathStore +from nemo.export.vllm.model_config import NemoModelConfig + +LOGGER = logging.getLogger("NeMo") + + +class NemoModelLoader(BaseModelLoader): + """ + Implements a custom ModelLoader for vLLM that reads the weights from a Nemo checkpoint + and converts them to a vLLM compatible format at load time. + + Also supports an ahead-of-time conversion that stores new weights in a Safetensors file, + see convert_and_store_nemo_weights(...) + """ + + @staticmethod + def _load_nemo_checkpoint_state(nemo_file: str): + sharded_state_dict = {} + + LOGGER.info(f'Loading weights from {nemo_file}...') + + with TarPath(nemo_file) as archive: + for subdir in archive.iterdir(): + if not subdir.is_dir() or not (subdir / '.zarray').exists(): + continue + key = subdir.name + + zstore = ZarrPathStore(subdir) + arr = zarr.open(zstore, 'r') + + if arr.dtype.name == "bfloat16": + sharded_state_dict[key] = torch.from_numpy(arr[:].view(numpy.int16)).view(torch.bfloat16) + else: + sharded_state_dict[key] = torch.from_numpy(arr[:]) + + arr = None + gc.collect() + + LOGGER.debug(f'Loaded tensor "{key}": {sharded_state_dict[key].shape}') + + return sharded_state_dict + + def load_model( + self, + *, + model_config: NemoModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + ) -> torch.nn.Module: + """ + Overrides the load_model function from BaseModelLoader to convert Nemo weights at load time. 
+ """ + + assert isinstance(model_config, NemoModelConfig) + state_dict = NemoModelLoader._load_nemo_checkpoint_state(model_config.nemo_checkpoint) + + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model( + model_config, self.load_config, lora_config, vision_language_config, cache_config + ) + + weights_iterator = model_config.model_converter.convert_weights(model_config.nemo_model_config, state_dict) + + model.load_weights(weights_iterator) + + return model.eval() + + @staticmethod + def convert_and_store_nemo_weights(model_config: NemoModelConfig, safetensors_file: str): + """ + Converts Nemo weights and stores the converted weights in a Safetensors file. + """ + + assert isinstance(model_config, NemoModelConfig) + assert os.path.exists(model_config.model) + + state_dict = NemoModelLoader._load_nemo_checkpoint_state(model_config.nemo_checkpoint) + + tensors = { + name: tensor + for name, tensor in model_config.model_converter.convert_weights( + model_config.nemo_model_config, state_dict + ) + } + + LOGGER.info(f'Saving weights to {safetensors_file}...') + safetensors.torch.save_file(tensors, safetensors_file) diff --git a/nemo/export/vllm/tokenizer_group.py b/nemo/export/vllm/tokenizer_group.py new file mode 100644 index 000000000000..6e4aedc14acb --- /dev/null +++ b/nemo/export/vllm/tokenizer_group.py @@ -0,0 +1,55 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional + +from vllm.lora.request import LoRARequest +from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import BaseTokenizerGroup + +from nemo.export.sentencepiece_tokenizer import SentencePieceTokenizer + + +class NemoTokenizerGroup(BaseTokenizerGroup): + """ + Implements a custom tokenizer for vLLM, based on SentencePieceTokenizer. 
+ """ + + def __init__(self, tokenizer: SentencePieceTokenizer, add_bos_token: bool = False): + self.tokenizer = tokenizer + self.add_bos_token = add_bos_token + + def ping(self) -> bool: + return True + + def get_max_input_len(self, lora_request: Optional[LoRARequest] = None) -> Optional[int]: + return None + + def encode( + self, prompt: str, request_id: Optional[str] = None, lora_request: Optional[LoRARequest] = None + ) -> List[int]: + ids = self.tokenizer.encode(prompt) + if self.add_bos_token: + ids = [self.tokenizer.bos_token_id] + ids + return ids + + async def encode_async( + self, prompt: str, request_id: Optional[str] = None, lora_request: Optional[LoRARequest] = None + ) -> List[int]: + return self.tokenizer.encode(prompt) # TODO: not sure how this is supposed to work + + def get_lora_tokenizer(self, lora_request: Optional[LoRARequest] = None) -> SentencePieceTokenizer: + return self.tokenizer + + async def get_lora_tokenizer_async(self, lora_request: Optional[LoRARequest] = None) -> SentencePieceTokenizer: + return self.tokenizer diff --git a/nemo/export/vllm_exporter.py b/nemo/export/vllm_exporter.py new file mode 100644 index 000000000000..f3dd6c8a248b --- /dev/null +++ b/nemo/export/vllm_exporter.py @@ -0,0 +1,417 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os.path +from typing import Iterable, List, Optional, Union + +import numpy +import wrapt +from vllm import RequestOutput, SamplingParams +from vllm.config import CacheConfig, DeviceConfig, LoadConfig, LoadFormat, ParallelConfig, SchedulerConfig +from vllm.executor.ray_utils import initialize_ray_cluster + +from nemo.deploy import ITritonDeployable +from nemo.deploy.utils import cast_output +from nemo.export.vllm.engine import NemoLLMEngine +from nemo.export.vllm.model_config import NemoModelConfig +from nemo.export.vllm.model_loader import NemoModelLoader + +LOGGER = logging.getLogger("NeMo") + + +@wrapt.decorator +def noop_decorator(func): + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper + + +use_pytriton = True +try: + from pytriton.decorators import batch + from pytriton.model_config import Tensor +except Exception: + use_pytriton = False + + +class vLLMExporter(ITritonDeployable): + """ + The Exporter class implements conversion from a Nemo checkpoint format to something compatible with vLLM, + loading the model in vLLM, and binding that model to a Triton server. 
+
+    Example:
+        from nemo.export.vllm_exporter import vLLMExporter
+        from nemo.deploy import DeployPyTriton
+
+        exporter = vLLMExporter()
+        exporter.export(
+            nemo_checkpoint='/path/to/checkpoint.nemo',
+            model_dir='/path/to/temp_dir',
+            model_type='llama')
+
+        server = DeployPyTriton(
+            model=exporter,
+            triton_model_name='LLAMA')
+
+        server.deploy()
+        server.serve()
+        server.stop()
+    """
+
+    def __init__(self):
+        self.request_id = 0
+
+    def export(
+        self,
+        nemo_checkpoint: str,
+        model_dir: str,
+        model_type: str,
+        device: str = 'auto',
+        tensor_parallel_size: int = 1,
+        pipeline_parallel_size: int = 1,
+        max_model_len: int = None,
+        dtype: str = 'auto',
+        seed: int = 0,
+        log_stats: bool = True,
+        weight_storage: str = 'auto',
+        gpu_memory_utilization: float = 0.9,
+    ):
+        """
+        Exports the Nemo checkpoint to vLLM and initializes the engine.
+
+        Args:
+            nemo_checkpoint (str): path to the nemo checkpoint.
+            model_dir (str): path to a temporary directory to store weights and the tokenizer model.
+                The temp dir may persist between subsequent export operations, in which case
+                converted weights may be reused to speed up the export.
+            model_type (str): type of the model, such as "llama", "mistral", "mixtral".
+                Needs to be compatible with transformers.AutoConfig.
+            device (str): type of the device to use by the vLLM engine.
+                Supported values are "auto", "cuda", "cpu", "neuron".
+            tensor_parallel_size (int): tensor parallelism.
+            pipeline_parallel_size (int): pipeline parallelism.
+                Values over 1 are not currently supported by vLLM.
+            max_model_len (int): model context length.
+            dtype (str): data type for model weights and activations.
+                Possible choices: auto, half, float16, bfloat16, float, float32
+                "auto" will use FP16 precision for FP32 and FP16 models,
+                and BF16 precision for BF16 models.
+            seed (int): random seed value.
+            log_stats (bool): enables logging inference performance statistics by vLLM.
+            weight_storage (str): controls how converted weights are stored:
+                "file" - always write weights into a file inside 'model_dir',
+                "memory" - always do an in-memory conversion,
+                "cache" - reuse existing files if they are newer than the nemo checkpoint,
+                "auto" - use "cache" for multi-GPU runs and "memory" for single-GPU runs.
+            gpu_memory_utilization (float): The fraction of GPU memory to be used for the model
+                executor, which can range from 0 to 1.
+        """
+
+        # Populate the basic configuration structures
+        device_config = DeviceConfig(device)
+
+        model_config = NemoModelConfig(
+            nemo_checkpoint,
+            model_dir,
+            model_type,
+            tokenizer_mode='auto',
+            dtype=dtype,
+            seed=seed,
+            revision=None,
+            code_revision=None,
+            tokenizer_revision=None,
+            max_model_len=max_model_len,
+            quantization=None,  # TODO ???
+ quantization_param_path=None, + enforce_eager=False, + max_seq_len_to_capture=None, + ) + + parallel_config = ParallelConfig( + pipeline_parallel_size=pipeline_parallel_size, tensor_parallel_size=tensor_parallel_size + ) + + # See if we have an up-to-date safetensors file + safetensors_file = os.path.join(model_config.model, 'model.safetensors') + safetensors_file_valid = os.path.exists(safetensors_file) and os.path.getmtime( + safetensors_file + ) > os.path.getmtime(nemo_checkpoint) + + # Decide how we're going to convert the weights + if weight_storage == 'auto': + if parallel_config.distributed_executor_backend is not None: + save_weights = not safetensors_file_valid + inmemory_weight_conversion = False + else: + save_weights = False + inmemory_weight_conversion = True + + elif weight_storage == 'cache': + save_weights = not safetensors_file_valid + inmemory_weight_conversion = False + + elif weight_storage == 'file': + save_weights = True + inmemory_weight_conversion = False + + elif weight_storage == 'memory': + save_weights = False + inmemory_weight_conversion = True + + else: + raise ValueError(f'Unsupported value for weight_storage: "{weight_storage}"') + + # Convert the weights ahead-of-time, if needed + if save_weights: + NemoModelLoader.convert_and_store_nemo_weights(model_config, safetensors_file) + elif not inmemory_weight_conversion: + LOGGER.info(f'Using cached weights in {safetensors_file}') + + # TODO: these values are the defaults from vllm.EngineArgs. + cache_config = CacheConfig( + block_size=16, + gpu_memory_utilization=gpu_memory_utilization, + swap_space=4, + cache_dtype='auto', + sliding_window=model_config.get_sliding_window(), + ) + + # TODO: these values are the defaults from vllm.EngineArgs. + scheduler_config = SchedulerConfig( + max_num_batched_tokens=None, + max_num_seqs=256, + # Note: max_model_len can be derived by model_config if the input value is None + max_model_len=model_config.max_model_len, + use_v2_block_manager=False, + num_lookahead_slots=0, + delay_factor=0.0, + enable_chunked_prefill=False, + ) + + load_config = LoadConfig( + load_format=NemoModelLoader if inmemory_weight_conversion else LoadFormat.SAFETENSORS, + download_dir=None, + model_loader_extra_config=None, + ) + + # Initialize the cluster and specify the executor class. + if device_config.device_type == "neuron": + from vllm.executor.neuron_executor import NeuronExecutor + + executor_class = NeuronExecutor + elif device_config.device_type == "cpu": + from vllm.executor.cpu_executor import CPUExecutor + + executor_class = CPUExecutor + elif parallel_config.distributed_executor_backend == "ray": + initialize_ray_cluster(parallel_config) + from vllm.executor.ray_gpu_executor import RayGPUExecutor + + executor_class = RayGPUExecutor + elif parallel_config.distributed_executor_backend == "mp": + from vllm.executor.multiproc_gpu_executor import MultiprocessingGPUExecutor + + executor_class = MultiprocessingGPUExecutor + else: + assert parallel_config.world_size == 1, "Ray is required if parallel_config.world_size > 1." 
+ from vllm.executor.gpu_executor import GPUExecutor + + executor_class = GPUExecutor + + # Initialize the engine + self.engine = NemoLLMEngine( + model_config=model_config, + cache_config=cache_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + device_config=device_config, + load_config=load_config, + lora_config=None, + vision_language_config=None, + speculative_config=None, + decoding_config=None, + executor_class=executor_class, + log_stats=log_stats, + ) + + def _add_request_to_engine( + self, prompt: str, max_output_len: int, temperature: float = 1.0, top_k: int = 1, top_p: float = 0.0 + ) -> str: + if top_p <= 0.0: + top_p = 1.0 + + sampling_params = SamplingParams(max_tokens=max_output_len, temperature=temperature, top_k=top_k, top_p=top_p) + + request_id = str(self.request_id) + self.request_id += 1 + + self.engine.add_request(request_id, prompt, sampling_params) + + return request_id + + def _forward_regular(self, request_ids: List[str]): + responses = [None] * len(request_ids) + finished = [False] * len(request_ids) + + while not all(finished): + request_outputs: List[RequestOutput] = self.engine.step() + + for request_output in request_outputs: + if not request_output.finished: + continue + + try: + request_index = request_ids.index(request_output.request_id) + except ValueError: + continue + + finished[request_index] = request_output.finished + output_text = request_output.outputs[-1].text + responses[request_index] = output_text + + return [[response] for response in responses] + + def _forward_streaming(self, request_ids: List[str]): + responses = [None] * len(request_ids) + finished = [False] * len(request_ids) + + while not all(finished): + request_outputs: List[RequestOutput] = self.engine.step() + + for request_output in request_outputs: + try: + request_index = request_ids.index(request_output.request_id) + except ValueError: + continue + + finished[request_index] = request_output.finished + output_text = request_output.outputs[-1].text + responses[request_index] = output_text + + yield [[response] for response in responses] + + def _add_triton_request_to_engine(self, inputs: numpy.ndarray, index: int) -> str: + return self._add_request_to_engine( + prompt=inputs['prompts'][index][0].decode('UTF-8'), + max_output_len=inputs['max_output_len'][index][0], + temperature=inputs['temperature'][index][0], + top_k=inputs['top_k'][index][0], + top_p=inputs['top_p'][index][0], + ) + + @property + def get_triton_input(self): + inputs = ( + Tensor(name="prompts", shape=(-1,), dtype=bytes), + Tensor(name="max_output_len", shape=(-1,), dtype=numpy.int_, optional=True), + Tensor(name="top_k", shape=(-1,), dtype=numpy.int_, optional=True), + Tensor(name="top_p", shape=(-1,), dtype=numpy.single, optional=True), + Tensor(name="temperature", shape=(-1,), dtype=numpy.single, optional=True), + ) + return inputs + + @property + def get_triton_output(self): + outputs = (Tensor(name="outputs", shape=(-1,), dtype=bytes),) + return outputs + + @batch + def triton_infer_fn(self, **inputs: numpy.ndarray): + request_ids = [] + num_requests = len(inputs["prompts"]) + for index in range(num_requests): + request_id = self._add_triton_request_to_engine(inputs, index) + request_ids.append(request_id) + + responses = self._forward_regular(request_ids) + responses = [r[0] for r in responses] + + output_tensor = cast_output(responses, numpy.bytes_) + return {'outputs': output_tensor} + + @batch + def triton_infer_fn_streaming(self, **inputs: numpy.ndarray): + request_ids = [] 
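+        # Queue every prompt in the Triton batch, then poll the engine and yield the cumulative outputs as they grow.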
+ num_requests = len(inputs["prompts"]) + for index in range(num_requests): + request_id = self._add_triton_request_to_engine(inputs, index) + request_ids.append(request_id) + + for responses in self._forward_streaming(request_ids): + responses = [r[0] for r in responses] + output_tensor = cast_output(responses, numpy.bytes_) + yield {'outputs': output_tensor} + + # Mimic the TensorRTLLM exporter's forward function, even though we don't support many of its features. + def forward( + self, + input_texts: List[str], + max_output_len: int = 64, + top_k: int = 1, + top_p: float = 0.0, + temperature: float = 1.0, + stop_words_list: Optional[List[str]] = None, + bad_words_list: Optional[List[str]] = None, + no_repeat_ngram_size: Optional[int] = None, + task_ids: Optional[List[str]] = None, + lora_uids: Optional[List[str]] = None, + prompt_embeddings_table=None, + prompt_embeddings_checkpoint_path: Optional[str] = None, + streaming: bool = False, + output_log_probs: bool = False, + ) -> Union[List[List[str]], Iterable[List[List[str]]]]: + """ + The forward function performs LLM evaluation on the provided array of prompts with other parameters shared, + and returns the generated texts. If 'streaming' is True, the output texts are returned incrementally + with a generator: one token appended to each output at a time. If 'streaming' is false, the final output texts + are returned as a single list of responses. + """ + + if stop_words_list is not None and stop_words_list != []: + raise NotImplementedError("stop_words_list is not supported") + + if bad_words_list is not None and bad_words_list != []: + raise NotImplementedError("bad_words_list is not supported") + + if no_repeat_ngram_size is not None: + raise NotImplementedError("no_repeat_ngram_size is not supported") + + if task_ids is not None and task_ids != []: + raise NotImplementedError("task_ids is not supported") + + if lora_uids is not None and lora_uids != []: + raise NotImplementedError("lora_uids is not supported") + + if prompt_embeddings_table is not None: + raise NotImplementedError("prompt_embeddings_table is not supported") + + if prompt_embeddings_checkpoint_path is not None: + raise NotImplementedError("prompt_embeddings_checkpoint_path is not supported") + + if output_log_probs: + raise NotImplementedError("output_log_probs is not supported") + + request_ids = [] + for prompt in input_texts: + request_id = self._add_request_to_engine( + prompt=prompt, max_output_len=max_output_len, temperature=temperature, top_k=top_k, top_p=top_p + ) + request_ids.append(request_id) + + if streaming: + return self._forward_streaming(request_ids) + else: + return self._forward_regular(request_ids) diff --git a/requirements/requirements_vllm.txt b/requirements/requirements_vllm.txt new file mode 100644 index 000000000000..a603b3c4ec53 --- /dev/null +++ b/requirements/requirements_vllm.txt @@ -0,0 +1 @@ +vllm==0.5.0 diff --git a/scripts/deploy/nlp/deploy_triton.py b/scripts/deploy/nlp/deploy_triton.py index d0854916cd38..8916fec0b1dd 100755 --- a/scripts/deploy/nlp/deploy_triton.py +++ b/scripts/deploy/nlp/deploy_triton.py @@ -16,14 +16,34 @@ import logging import os import sys +import tempfile from pathlib import Path from nemo.deploy import DeployPyTriton -from nemo.deploy.nlp import MegatronLLMDeployable -from nemo.export import TensorRTLLM LOGGER = logging.getLogger("NeMo") +megatron_llm_supported = True +try: + from nemo.deploy.nlp import MegatronLLMDeployable +except Exception as e: + LOGGER.warning(f"Cannot import MegatronLLMDeployable, it 
will not be available. {type(e).__name__}: {e}") + megatron_llm_supported = False + +trt_llm_supported = True +try: + from nemo.export.tensorrt_llm import TensorRTLLM +except Exception as e: + LOGGER.warning(f"Cannot import the TensorRTLLM exporter, it will not be available. {type(e).__name__}: {e}") + trt_llm_supported = False + +vllm_supported = True +try: + from nemo.export.vllm_exporter import vLLMExporter +except Exception as e: + LOGGER.warning(f"Cannot import the vLLM exporter, it will not be available. {type(e).__name__}: {e}") + vllm_supported = False + def get_args(argv): parser = argparse.ArgumentParser( @@ -69,7 +89,7 @@ def get_args(argv): choices=["bfloat16", "float16", "fp8", "int8"], default="bfloat16", type=str, - help="dtype of the model on TensorRT-LLM", + help="dtype of the model on TensorRT-LLM or vLLM", ) parser.add_argument("-mil", "--max_input_len", default=256, type=int, help="Max input length of the model") parser.add_argument("-mol", "--max_output_len", default=256, type=int, help="Max output length of the model") @@ -150,7 +170,23 @@ def get_args(argv): help="Different options to deploy nemo model.", ) parser.add_argument("-dm", "--debug_mode", default=False, action='store_true', help="Enable debug mode") - + parser.add_argument( + '-ws', + '--weight_storage', + default='auto', + choices=['auto', 'cache', 'file', 'memory'], + help='Strategy for storing converted weights for vLLM: "file" - always write weights into a file, ' + '"memory" - always do an in-memory conversion, "cache" - reuse existing files if they are ' + 'newer than the nemo checkpoint, "auto" - use "cache" for multi-GPU runs and "memory" ' + 'for single-GPU runs.', + ) + parser.add_argument( + "-gmu", + '--gpu_memory_utilization', + default=0.9, + type=float, + help="GPU memory utilization percentage for vLLM.", + ) args = parser.parse_args(argv) return args @@ -160,8 +196,8 @@ def get_trtllm_deployable(args): trt_llm_path = "/tmp/trt_llm_model_dir/" LOGGER.info( "/tmp/trt_llm_model_dir/ path will be used as the TensorRT LLM folder. " - "Please set this parameter if you'd like to use a path that has already " - "included the TensorRT LLM model files." + "Please set the --triton_model_repository parameter if you'd like to use a path that already " + "includes the TensorRT LLM model files." ) Path(trt_llm_path).mkdir(parents=True, exist_ok=True) else: @@ -261,6 +297,45 @@ def get_trtllm_deployable(args): return trt_llm_exporter +def get_vllm_deployable(args): + if args.ptuning_nemo_checkpoint is not None: + raise ValueError("vLLM backend doesn't support P-tuning at this time.") + if args.lora_ckpt is not None: + raise ValueError("vLLM backend doesn't support LoRA at this time.") + + tempdir = None + model_dir = args.triton_model_repository + if model_dir is None: + tempdir = tempfile.TemporaryDirectory() + model_dir = tempdir.name + LOGGER.info( + f"{model_dir} path will be used as the vLLM intermediate folder. " + + "Please set the --triton_model_repository parameter if you'd like to use a path that already " + + "includes the vLLM model files." 
+ ) + elif not os.path.exists(model_dir): + os.makedirs(model_dir) + + try: + exporter = vLLMExporter() + exporter.export( + nemo_checkpoint=args.nemo_checkpoint, + model_dir=model_dir, + model_type=args.model_type, + tensor_parallel_size=args.num_gpus, + max_model_len=args.max_input_len + args.max_output_len, + dtype=args.dtype, + weight_storage=args.weight_storage, + gpu_memory_utilization=args.gpu_memory_utilization, + ) + return exporter + except Exception as error: + raise RuntimeError("An error has occurred during the model export. Error message: " + str(error)) + finally: + if tempdir is not None: + tempdir.cleanup() + + def get_nemo_deployable(args): if args.nemo_checkpoint is None: raise ValueError("In-Framework deployment requires a .nemo checkpoint") @@ -282,11 +357,17 @@ def nemo_deploy(argv): backend = args.backend.lower() if backend == 'tensorrt-llm': + if not trt_llm_supported: + raise ValueError("TensorRT-LLM engine is not supported in this environment.") triton_deployable = get_trtllm_deployable(args) elif backend == 'in-framework': + if not megatron_llm_supported: + raise ValueError("MegatronLLMDeployable is not supported in this environment.") triton_deployable = get_nemo_deployable(args) elif backend == 'vllm': - raise ValueError("vLLM will be supported in the next release.") + if not vllm_supported: + raise ValueError("vLLM engine is not supported in this environment.") + triton_deployable = get_vllm_deployable(args) else: raise ValueError("Backend: {0} is not supported.".format(backend)) diff --git a/scripts/export/export_to_trt_llm.py b/scripts/export/export_to_trt_llm.py index a0c70c8bbd85..49fefd40561b 100644 --- a/scripts/export/export_to_trt_llm.py +++ b/scripts/export/export_to_trt_llm.py @@ -16,7 +16,7 @@ import logging import sys -from nemo.export import TensorRTLLM +from nemo.export.tensorrt_llm import TensorRTLLM LOGGER = logging.getLogger("NeMo") diff --git a/tests/export/nemo_export.py b/tests/export/nemo_export.py index 5541cc0f8673..013a22deee3b 100644 --- a/tests/export/nemo_export.py +++ b/tests/export/nemo_export.py @@ -14,46 +14,85 @@ import argparse import json +import logging import shutil +import sys import time +from dataclasses import dataclass from pathlib import Path +from typing import Dict, List, Optional, Tuple + import torch -from tests.infer_data_path import get_infer_test_data +# Import infer_data_path from the parent folder assuming that the 'tests' package is not installed. +sys.path.append(str(Path(__file__).parent.parent)) +from infer_data_path import get_infer_test_data + +LOGGER = logging.getLogger("NeMo") -run_export_tests = True +triton_supported = True try: from nemo.deploy import DeployPyTriton from nemo.deploy.nlp import NemoQueryLLM - from nemo.export import TensorRTLLM except Exception as e: - run_export_tests = False + LOGGER.warning(f"Cannot import Triton, deployment will not be available. {type(e).__name__}: {e}") + triton_supported = False + +trt_llm_supported = True +try: + from nemo.export.tensorrt_llm import TensorRTLLM +except Exception as e: + LOGGER.warning(f"Cannot import the TensorRTLLM exporter, it will not be available. {type(e).__name__}: {e}") + trt_llm_supported = False + +vllm_supported = True +try: + from nemo.export.vllm_exporter import vLLMExporter +except Exception as e: + LOGGER.warning(f"Cannot import the vLLM exporter, it will not be available. 
{type(e).__name__}: {e}") + vllm_supported = False -def get_accuracy_with_lambada(model, nq, task_ids, lora_uids, test_data_path=None): +class UsageError(Exception): + pass + + +@dataclass +class FunctionalResult: + regular_pass: Optional[bool] = None + deployed_pass: Optional[bool] = None + + +@dataclass +class AccuracyResult: + accuracy: float + accuracy_relaxed: float + deployed_accuracy: float + deployed_accuracy_relaxed: float + evaluation_time: float + + +def get_accuracy_with_lambada(model, nq, task_ids, lora_uids, test_data_path): # lambada dataset based accuracy test, which includes more than 5000 sentences. # Use generated last token with original text's last token for accuracy comparison. # If the generated last token start with the original token, trtllm_correct make an increment. # It generates a CSV file for text comparison detail. - if test_data_path is None: - raise Exception("test_data_path cannot be None.") - - trtllm_correct = 0 - trtllm_deployed_correct = 0 - trtllm_correct_relaxed = 0 - trtllm_deployed_correct_relaxed = 0 + correct_answers = 0 + correct_answers_deployed = 0 + correct_answers_relaxed = 0 + correct_answers_deployed_relaxed = 0 all_expected_outputs = [] - all_trtllm_outputs = [] + all_actual_outputs = [] with open(test_data_path, 'r') as file: records = json.load(file) - eval_start = time.perf_counter() + eval_start = time.monotonic() for record in records: prompt = record["text_before_last_word"] expected_output = record["last_word"].strip().lower() - trtllm_output = model.forward( + model_output = model.forward( input_texts=[prompt], max_output_len=1, top_k=1, @@ -62,22 +101,22 @@ def get_accuracy_with_lambada(model, nq, task_ids, lora_uids, test_data_path=Non task_ids=task_ids, lora_uids=lora_uids, ) - trtllm_output = trtllm_output[0][0].strip().lower() + model_output = model_output[0][0].strip().lower() all_expected_outputs.append(expected_output) - all_trtllm_outputs.append(trtllm_output) + all_actual_outputs.append(model_output) - if expected_output == trtllm_output: - trtllm_correct += 1 + if expected_output == model_output: + correct_answers += 1 if ( - expected_output == trtllm_output - or trtllm_output.startswith(expected_output) - or expected_output.startswith(trtllm_output) + expected_output == model_output + or model_output.startswith(expected_output) + or expected_output.startswith(model_output) ): - if len(trtllm_output) == 1 and len(expected_output) > 1: + if len(model_output) == 1 and len(expected_output) > 1: continue - trtllm_correct_relaxed += 1 + correct_answers_relaxed += 1 if nq is not None: trtllm_deployed_output = nq.query_llm( @@ -91,7 +130,7 @@ def get_accuracy_with_lambada(model, nq, task_ids, lora_uids, test_data_path=Non trtllm_deployed_output = trtllm_deployed_output[0][0].strip().lower() if expected_output == trtllm_deployed_output: - trtllm_deployed_correct += 1 + correct_answers_deployed += 1 if ( expected_output == trtllm_deployed_output @@ -100,32 +139,47 @@ def get_accuracy_with_lambada(model, nq, task_ids, lora_uids, test_data_path=Non ): if len(trtllm_deployed_output) == 1 and len(expected_output) > 1: continue - trtllm_deployed_correct_relaxed += 1 - eval_end = time.perf_counter() + correct_answers_deployed_relaxed += 1 + eval_end = time.monotonic() + + return AccuracyResult( + accuracy=correct_answers / len(all_expected_outputs), + accuracy_relaxed=correct_answers_relaxed / len(all_expected_outputs), + deployed_accuracy=correct_answers_deployed / len(all_expected_outputs), + 
deployed_accuracy_relaxed=correct_answers_deployed_relaxed / len(all_expected_outputs), + evaluation_time=eval_end - eval_start, + ) - trtllm_accuracy = trtllm_correct / len(all_expected_outputs) - trtllm_accuracy_relaxed = trtllm_correct_relaxed / len(all_expected_outputs) - trtllm_deployed_accuracy = trtllm_deployed_correct / len(all_expected_outputs) - trtllm_deployed_accuracy_relaxed = trtllm_deployed_correct_relaxed / len(all_expected_outputs) +# Tests if the model outputs contain the expected keywords. +def check_model_outputs(streaming: bool, model_outputs, expected_outputs: List[str]) -> bool: - evaluation_time = eval_end - eval_start + # In streaming mode, we get a list of lists of lists, and we only care about the last item in that list + if streaming: + if len(model_outputs) == 0: + return False + model_outputs = model_outputs[-1] - return ( - trtllm_accuracy, - trtllm_accuracy_relaxed, - trtllm_deployed_accuracy, - trtllm_deployed_accuracy_relaxed, - evaluation_time, - ) + # See if we have the right number of final answers. + if len(model_outputs) != len(expected_outputs): + return False + + # Check the presence of keywords in the final answers. + for i in range(len(model_outputs)): + if expected_outputs[i] not in model_outputs[i][0]: + return False + return True -def run_trt_llm_inference( + +def run_inference( model_name, model_type, - prompt, + prompts, + expected_outputs, checkpoint_path, - trt_llm_model_dir, + model_dir, + use_vllm, n_gpu=1, max_batch_size=8, use_embedding_sharing=False, @@ -135,8 +189,8 @@ def run_trt_llm_inference( p_tuning_checkpoint=None, lora=False, lora_checkpoint=None, - tp_size=None, - pp_size=None, + tp_size=1, + pp_size=1, top_k=1, top_p=0.0, temperature=1.0, @@ -147,7 +201,7 @@ def run_trt_llm_inference( test_deployment=False, test_data_path=None, save_trt_engine=False, -): +) -> Tuple[Optional[FunctionalResult], Optional[AccuracyResult]]: if Path(checkpoint_path).exists(): if n_gpu > torch.cuda.device_count(): print( @@ -155,9 +209,9 @@ def run_trt_llm_inference( checkpoint_path, model_name, n_gpu, torch.cuda.device_count() ) ) - return None, None, None, None, None + return (None, None) - Path(trt_llm_model_dir).mkdir(parents=True, exist_ok=True) + Path(model_dir).mkdir(parents=True, exist_ok=True) if debug: print("") @@ -182,7 +236,7 @@ def run_trt_llm_inference( print("---- PTuning enabled.") else: print("---- PTuning could not be enabled and skipping the test.") - return None, None, None, None, None + return (None, None) lora_ckpt_list = None lora_uids = None @@ -199,36 +253,48 @@ def run_trt_llm_inference( print("---- LoRA enabled.") else: print("---- LoRA could not be enabled and skipping the test.") - return None, None, None, None, None - - trt_llm_exporter = TensorRTLLM(trt_llm_model_dir, lora_ckpt_list, load_model=False) - - trt_llm_exporter.export( - nemo_checkpoint_path=checkpoint_path, - model_type=model_type, - n_gpus=n_gpu, - tensor_parallel_size=tp_size, - pipeline_parallel_size=pp_size, - max_input_len=max_input_len, - max_output_len=max_output_len, - max_batch_size=max_batch_size, - max_prompt_embedding_table_size=max_prompt_embedding_table_size, - use_lora_plugin=use_lora_plugin, - lora_target_modules=lora_target_modules, - max_num_tokens=int(max_input_len * max_batch_size * 0.2), - opt_num_tokens=60, - use_embedding_sharing=use_embedding_sharing, - save_nemo_model_config=True, - ) + return (None, None) + + if use_vllm: + exporter = vLLMExporter() + + exporter.export( + nemo_checkpoint=checkpoint_path, + model_dir=model_dir, + 
model_type=model_type,
+                tensor_parallel_size=tp_size,
+                pipeline_parallel_size=pp_size,
+                max_model_len=max_input_len + max_output_len,
+            )
+        else:
+            exporter = TensorRTLLM(model_dir, lora_ckpt_list, load_model=False)
+
+            exporter.export(
+                nemo_checkpoint_path=checkpoint_path,
+                model_type=model_type,
+                n_gpus=n_gpu,
+                tensor_parallel_size=tp_size,
+                pipeline_parallel_size=pp_size,
+                max_input_len=max_input_len,
+                max_output_len=max_output_len,
+                max_batch_size=max_batch_size,
+                max_prompt_embedding_table_size=max_prompt_embedding_table_size,
+                use_lora_plugin=use_lora_plugin,
+                lora_target_modules=lora_target_modules,
+                max_num_tokens=int(max_input_len * max_batch_size * 0.2),
+                opt_num_tokens=60,
+                use_embedding_sharing=use_embedding_sharing,
+                save_nemo_model_config=True,
+            )

         if ptuning:
-            trt_llm_exporter.add_prompt_table(
+            exporter.add_prompt_table(
                 task_name="0",
                 prompt_embeddings_checkpoint_path=prompt_embeddings_checkpoint_path,
             )

-        output = trt_llm_exporter.forward(
-            input_texts=prompt,
+        output = exporter.forward(
+            input_texts=prompts,
             max_output_len=max_output_len,
             top_k=top_k,
             top_p=top_p,
@@ -239,10 +305,21 @@
             stop_words_list=stop_words_list,
         )

-        if not use_lora_plugin and not ptuning:
+        # Unwrap the generator if needed
+        output = list(output)
+
+        functional_result = FunctionalResult()
+
+        # Check non-deployed functional correctness
+        functional_result.regular_pass = True
+        if not check_model_outputs(streaming, output, expected_outputs):
+            LOGGER.warning("Model outputs don't match the expected result.")
+            functional_result.regular_pass = False
+
+        if not use_lora_plugin and not ptuning and not use_vllm:
             test_cpp_runtime(
-                engine_path=trt_llm_model_dir,
-                prompt=prompt,
+                engine_path=model_dir,
+                prompt=prompts,
                 max_output_len=max_output_len,
                 debug=True,
             )
@@ -252,7 +329,7 @@
         output_deployed = ""
         if test_deployment:
             nm = DeployPyTriton(
-                model=trt_llm_exporter,
+                model=exporter,
                 triton_model_name=model_name,
                 port=8000,
             )
@@ -261,7 +338,7 @@

             nq = NemoQueryLLM(url="localhost:8000", model_name=model_name)
             output_deployed = nq.query_llm(
-                prompts=prompt,
+                prompts=prompts,
                 max_output_len=max_output_len,
                 top_k=1,
                 top_p=0.0,
@@ -269,33 +346,38 @@
                 lora_uids=lora_uids,
             )

-        if debug:
+            # Unwrap the generator if needed
+            output_deployed = list(output_deployed)
+
+            # Check deployed functional correctness
+            functional_result.deployed_pass = True
+            if not check_model_outputs(streaming, output_deployed, expected_outputs):
+                LOGGER.warning("Deployed model outputs don't match the expected result.")
+                functional_result.deployed_pass = False
+
+        if debug or functional_result.regular_pass == False or functional_result.deployed_pass == False:
             print("")
-            print("--- Prompt: ", prompt)
+            print("--- Prompt: ", prompts)
             print("")
-            print("--- Output: ", output)
+            print("--- Expected keywords: ", expected_outputs)
             print("")
+            print("--- Output: ", output)
             print("")
             print("--- Output deployed: ", output_deployed)
             print("")

+        accuracy_result = None
         if run_accuracy:
             print("Start model accuracy testing ...")
-            result = get_accuracy_with_lambada(trt_llm_exporter, nq, task_ids, lora_uids, test_data_path)
-            if test_deployment:
-                nm.stop()
-
-            if not save_trt_engine:
-                shutil.rmtree(trt_llm_model_dir)
-            return result
+            accuracy_result = get_accuracy_with_lambada(exporter, nq, task_ids, lora_uids, test_data_path)

         if test_deployment:
             nm.stop()

         if not save_trt_engine:
-            
shutil.rmtree(trt_llm_model_dir) + shutil.rmtree(model_dir) - return None, None, None, None, None + return (functional_result, accuracy_result) else: raise Exception("Checkpoint {0} could not be found.".format(checkpoint_path)) @@ -323,6 +405,7 @@ def test_cpp_runtime( def run_existing_checkpoints( model_name, + use_vllm, n_gpus, tp_size=None, pp_size=None, @@ -334,10 +417,10 @@ def run_existing_checkpoints( stop_words_list=None, test_data_path=None, save_trt_engine=False, -): +) -> Tuple[Optional[FunctionalResult], Optional[AccuracyResult]]: if n_gpus > torch.cuda.device_count(): print("Skipping the test due to not enough number of GPUs") - return None, None, None, None, None + return (None, None) test_data = get_infer_test_data() if not (model_name in test_data.keys()): @@ -347,7 +430,7 @@ def run_existing_checkpoints( if n_gpus < model_info["min_gpus"]: print("Min n_gpus for this model is {0}".format(n_gpus)) - return None, None, None, None, None + return (None, None) p_tuning_checkpoint = None if ptuning: @@ -369,12 +452,13 @@ def run_existing_checkpoints( else: use_embedding_sharing = False - return run_trt_llm_inference( + return run_inference( model_name=model_name, model_type=model_info["model_type"], - prompt=model_info["prompt_template"], + prompts=model_info["prompt_template"], checkpoint_path=model_info["checkpoint"], - trt_llm_model_dir=model_info["trt_llm_model_dir"], + model_dir=model_info["model_dir"], + use_vllm=use_vllm, n_gpu=n_gpus, max_batch_size=model_info["max_batch_size"], use_embedding_sharing=use_embedding_sharing, @@ -437,7 +521,7 @@ def get_args(): required=False, ) parser.add_argument( - "--trt_llm_model_dir", + "--model_dir", type=str, ) parser.add_argument( @@ -475,10 +559,12 @@ def get_args(): ) parser.add_argument( "--tp_size", + default=1, type=int, ) parser.add_argument( "--pp_size", + default=1, type=int, ) parser.add_argument( @@ -527,31 +613,48 @@ def get_args(): type=str, default="False", ) + parser.add_argument( + "--use_vllm", + type=str, + default="False", + ) + + args = parser.parse_args() + + def str_to_bool(name: str, s: str) -> bool: + true_strings = ["true", "1"] + false_strings = ["false", "0"] + if s.lower() in true_strings: + return True + if s.lower() in false_strings: + return False + raise UsageError(f"Invalid boolean value for argument --{name}: '{s}'") + + args.test_deployment = str_to_bool("test_deployment", args.test_deployment) + args.save_trt_engine = str_to_bool("save_trt_engin", args.save_trt_engine) + args.run_accuracy = str_to_bool("run_accuracy", args.run_accuracy) + args.use_vllm = str_to_bool("use_vllm", args.use_vllm) - return parser.parse_args() + return args def run_inference_tests(args): - if args.test_deployment == "True": - args.test_deployment = True - else: - args.test_deployment = False + if not args.use_vllm and not trt_llm_supported: + raise UsageError("TensorRT-LLM engine is not supported in this environment.") - if args.save_trt_engine == "True": - args.save_trt_engine = True - else: - args.save_trt_engine = False + if args.use_vllm and not vllm_supported: + raise UsageError("vLLM engine is not supported in this environment.") - if args.run_accuracy == "True": - args.run_accuracy = True - else: - args.run_accuracy = False + if args.use_vllm and (args.ptuning or args.lora): + raise UsageError("The vLLM integration currently does not support P-tuning or LoRA.") - if args.run_accuracy: - if args.test_data_path is None: - raise Exception("test_data_path param cannot be None.") + if args.test_deployment and not 
triton_supported: + raise UsageError("Deployment tests are not available because Triton is not supported in this environment.") - result_dic = {} + if args.run_accuracy and args.test_data_path is None: + raise UsageError("Accuracy testing requires the --test_data_path argument.") + + result_dic: Dict[int, Tuple[FunctionalResult, Optional[AccuracyResult]]] = {} if args.existing_test_models: n_gpus = args.min_gpus @@ -561,6 +664,7 @@ def run_inference_tests(args): while n_gpus <= args.max_gpus: result_dic[n_gpus] = run_existing_checkpoints( model_name=args.model_name, + use_vllm=args.use_vllm, n_gpus=n_gpus, ptuning=args.ptuning, lora=args.lora, @@ -575,18 +679,24 @@ def run_inference_tests(args): n_gpus = n_gpus * 2 else: - prompt_template = ["The capital of France is", "Largest animal in the sea is"] + if args.model_dir is None: + raise Exception("When using custom checkpoints, --model_dir is required.") + + prompts = ["The capital of France is", "Largest animal in the sea is"] + expected_outputs = ["Paris", "blue whale"] n_gpus = args.min_gpus if args.max_gpus is None: args.max_gpus = args.min_gpus while n_gpus <= args.max_gpus: - result_dic[n_gpus] = run_trt_llm_inference( + result_dic[n_gpus] = run_inference( model_name=args.model_name, model_type=args.model_type, - prompt=prompt_template, + prompts=prompts, + expected_outputs=expected_outputs, checkpoint_path=args.checkpoint_dir, - trt_llm_model_dir=args.trt_llm_model_dir, + model_dir=args.model_dir, + use_vllm=args.use_vllm, n_gpu=n_gpus, max_batch_size=args.max_batch_size, max_input_len=args.max_input_len, @@ -610,31 +720,59 @@ def run_inference_tests(args): n_gpus = n_gpus * 2 - test_result = "PASS" + functional_test_result = "PASS" + accuracy_test_result = "PASS" print_separator = False print("============= Test Summary ============") - for i, results in result_dic.items(): - if not results[0] is None and not results[1] is None: - if print_separator: - print("---------------------------------------") - print( - "Number of GPUS: {}\n" - "Model Accuracy: {:.4f}\n" - "Relaxed Model Accuracy: {:.4f}\n" - "Deployed Model Accuracy: {:.4f}\n" - "Deployed Relaxed Model Accuracy: {:.4f}\n" - "Evaluation Time [s]: {:.2f}".format(i, *results) - ) - print_separator = True - if results[1] < 0.5: - test_result = "FAIL" + for num_gpus, results in result_dic.items(): + functional_result, accuracy_result = results + + if print_separator: + print("---------------------------------------") + print_separator = True + + def optional_bool_to_pass_fail(b: Optional[bool]): + if b is None: + return "N/A" + return "PASS" if b else "FAIL" + + print(f"Number of GPUS: {num_gpus}") + + if functional_result is not None: + print(f"Functional Test: {optional_bool_to_pass_fail(functional_result.regular_pass)}") + print(f"Deployed Functional Test: {optional_bool_to_pass_fail(functional_result.deployed_pass)}") + + if functional_result.regular_pass == False: + functional_test_result = "FAIL" + if functional_result.deployed_pass == False: + functional_test_result = "FAIL" + + if accuracy_result is not None: + print(f"Model Accuracy: {accuracy_result.accuracy:.4f}") + print(f"Relaxed Model Accuracy: {accuracy_result.accuracy_relaxed:.4f}") + print(f"Deployed Model Accuracy: {accuracy_result.deployed_accuracy:.4f}") + print(f"Deployed Relaxed Model Accuracy: {accuracy_result.deployed_accuracy_relaxed:.4f}") + print(f"Evaluation Time [s]: {accuracy_result.evaluation_time:.2f}") + if accuracy_result.accuracy_relaxed < 0.5: + accuracy_test_result = "FAIL" 
print("=======================================") - print("TEST: " + test_result) - if test_result == "FAIL": + print(f"Functional: {functional_test_result}") + if args.run_accuracy: + print(f"Acccuracy: {accuracy_test_result}") + + if functional_test_result == "FAIL": + raise Exception("Functional test failed") + + if accuracy_test_result == "FAIL": raise Exception("Model accuracy is below 0.5") if __name__ == '__main__': - args = get_args() - run_inference_tests(args) + try: + args = get_args() + run_inference_tests(args) + except UsageError as e: + LOGGER.error(f"{e}") + except argparse.ArgumentError as e: + LOGGER.error(f"{e}")