forked from NVIDIA/NeMo

Commit
* Export implementation for vLLM 0.4.3. Supports LLAMA2, Mistral, Mixtral (unverified), Gemma and StarCoder2 models. The nemo.export.tensorrt_llm alias was removed to avoid initializing TRT-LLM when importing anything from nemo.export.
  Signed-off-by: Alexey Panteleev <[email protected]>
* Fixed some CodeQL warnings.
  Signed-off-by: Alexey Panteleev <[email protected]>
* Apply isort and black reformatting.
  Signed-off-by: apanteleev <[email protected]>
* Removed empty files.
  Signed-off-by: Alexey Panteleev <[email protected]>
* Apply isort and black reformatting.
  Signed-off-by: apanteleev <[email protected]>
* Updated the integration for vLLM 0.5.0.
  Signed-off-by: Alexey Panteleev <[email protected]>
* Updated the vLLM deployment interface to use max_output_len instead of max_output_token.
  Signed-off-by: Alexey Panteleev <[email protected]>
* Apply isort and black reformatting.
  Signed-off-by: apanteleev <[email protected]>
* Moved the Exporter class to nemo/export and renamed its file to vllm_exporter.py, to be more similar to TRT-LLM.
  Signed-off-by: Alexey Panteleev <[email protected]>
* Apply isort and black reformatting.
  Signed-off-by: apanteleev <[email protected]>
* Implemented vLLM support in the export tests, added functional testing, implemented forward evaluation on vLLM without Triton.
  Signed-off-by: Alexey Panteleev <[email protected]>
* Apply isort and black reformatting.
  Signed-off-by: apanteleev <[email protected]>
* Moved the vLLM deployment functionality to the common deploy_triton.py script.
  Signed-off-by: Alexey Panteleev <[email protected]>
* Apply isort and black reformatting.
  Signed-off-by: apanteleev <[email protected]>
* Fixed the CodeQL discovered issues.
  Signed-off-by: Alexey Panteleev <[email protected]>
* Apply isort and black reformatting.
  Signed-off-by: apanteleev <[email protected]>
* Fixed one more return of a wrong dimensionality...
  Signed-off-by: Alexey Panteleev <[email protected]>
* More wrong dimensionality returns.
  Signed-off-by: Alexey Panteleev <[email protected]>

---------

Signed-off-by: Alexey Panteleev <[email protected]>
Signed-off-by: apanteleev <[email protected]>
Co-authored-by: apanteleev <[email protected]>
Co-authored-by: Onur Yilmaz <[email protected]>

1 parent: 35fb010, commit: 9e979d4. Showing 20 changed files with 1,645 additions and 167 deletions.
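
Before the per-file diffs, a rough usage sketch of the new exporter may help orient the reader. Only the module path (nemo/export/vllm_exporter.py), the supported model types, and the max_output_len rename are stated in the commit message; the class name, method names, and paths below are assumptions modeled on the TRT-LLM exporter, not verified API.

# Hypothetical usage sketch -- class and method names, arguments, and paths are
# assumptions inferred from the commit message and the existing TRT-LLM exporter.
from nemo.export.vllm_exporter import vLLMExporter  # assumed class name

exporter = vLLMExporter()
exporter.export(
    nemo_checkpoint_path="/models/llama2-7b.nemo",  # placeholder checkpoint path
    model_dir="/tmp/vllm_model_dir",                # placeholder working directory
    model_type="llama",
)

# Forward evaluation directly through vLLM, without deploying to a Triton server.
output = exporter.forward(
    input_texts=["Write a haiku about GPUs."],
    max_output_len=64,  # this commit renames max_output_token to max_output_len
)
print(output)
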

@@ -0,0 +1,13 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@@ -0,0 +1,101 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from pathlib import Path

from vllm import LLMEngine
from vllm.transformers_utils.tokenizer_group.tokenizer_group import TokenizerGroup

from nemo.export.sentencepiece_tokenizer import SentencePieceTokenizer
from nemo.export.tarutils import TarPath
from nemo.export.vllm.tokenizer_group import NemoTokenizerGroup

LOGGER = logging.getLogger("NeMo")


class NemoLLMEngine(LLMEngine):
    """
    Overrides some functionality from vllm.LLMEngine to use our custom tokenizer
    instead of one from Transformers.
    """

    def _init_tokenizer(self, **tokenizer_init_kwargs):
        # Find the tokenizer file name in the Nemo checkpoint config
        tokenizer_config = self.model_config.nemo_model_config.get('tokenizer', {})
        tokenizer_model = tokenizer_config.get('model', tokenizer_config.get('tokenizer_model', None))

        # If there is no tokenizer file specified but there's a reference to an HF tokenizer, use that
        if tokenizer_model is None and tokenizer_config.get('library') == 'huggingface':
            tokenizer_type = tokenizer_config.get('type')
            if tokenizer_type is not None:
                tokenizer_group = TokenizerGroup(
                    tokenizer_id=tokenizer_type,
                    enable_lora=bool(self.lora_config),
                    max_num_seqs=self.scheduler_config.max_num_seqs,
                    max_input_length=None,
                )

                # Update the HF config fields that come from the tokenizer in NeMo
                self.model_config.hf_config.vocab_size = tokenizer_group.tokenizer.vocab_size
                self.model_config.hf_config.bos_token_id = tokenizer_group.tokenizer.bos_token_id
                self.model_config.hf_config.eos_token_id = tokenizer_group.tokenizer.eos_token_id
                self.model_config.hf_config.pad_token_id = tokenizer_group.tokenizer.pad_token_id

                return tokenizer_group

        # Open the checkpoint archive
        with TarPath(self.model_config.nemo_checkpoint) as archive:
            tokenizer_model_file = None
            if isinstance(tokenizer_model, str) and tokenizer_model.startswith('nemo:'):
                tokenizer_model = tokenizer_model[len('nemo:') :]
                tokenizer_model_file = archive / tokenizer_model
                if not tokenizer_model_file.exists():
                    LOGGER.warning(
                        f'Tokenizer model file {tokenizer_model} specified in the model_config does not '
                        + 'exist in the checkpoint.'
                    )
                    tokenizer_model_file = None

            if tokenizer_model_file is None:
                for path in archive.glob('*tokenizer*.model'):
                    LOGGER.info(f'Found tokenizer model file {path}.')
                    tokenizer_model_file = path
                    break

            if tokenizer_model_file is None:
                raise RuntimeError('No tokenizer model file found, aborting.')

            # Extract the tokenizer model file into the model directory,
            # because sentencepiece cannot load it directly from TarPath.
            extracted_tokenizer_model = Path(self.model_config.model) / 'tokenizer.model'
            with tokenizer_model_file.open('rb') as infile:
                with extracted_tokenizer_model.open('wb') as outfile:
                    outfile.write(infile.read())

            # Construct the tokenizer object and wrapper
            tokenizer = SentencePieceTokenizer(str(extracted_tokenizer_model))

            # Determine if the model needs a bos token (which is not stored in Nemo checkpoints)
            add_bos_token = self.model_config.model_converter.requires_bos_token()

            tokenizer_group = NemoTokenizerGroup(tokenizer, add_bos_token=add_bos_token)

            # Update the HF config fields that come from the tokenizer in NeMo
            self.model_config.hf_config.vocab_size = tokenizer.vocab_size
            self.model_config.hf_config.bos_token_id = tokenizer.bos_token_id
            self.model_config.hf_config.eos_token_id = tokenizer.eos_token_id
            self.model_config.hf_config.pad_token_id = tokenizer.pad_id

            return tokenizer_group
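
The extraction step above exists because sentencepiece needs a real filesystem path and cannot read straight out of the TarPath wrapper. As a standalone illustration of the same idea (not part of this commit), here is a minimal sketch that pulls a tokenizer model out of a .nemo archive with the standard tarfile module and loads it with sentencepiece; the function name and all file names are placeholders:

# Standalone illustration (not part of this commit): extract the tokenizer model
# from a .nemo tar archive with the standard tarfile module, then load it from disk.
import tarfile
from pathlib import Path

import sentencepiece


def extract_and_load_tokenizer(nemo_checkpoint: str, work_dir: str) -> sentencepiece.SentencePieceProcessor:
    """Copy the tokenizer model out of a .nemo tar archive and load it from disk."""
    out_path = Path(work_dir) / "tokenizer.model"
    with tarfile.open(nemo_checkpoint, "r") as archive:
        # Mirror the fallback glob above: any member matching '*tokenizer*.model'.
        member = next(
            (m for m in archive.getmembers() if "tokenizer" in m.name and m.name.endswith(".model")),
            None,
        )
        if member is None:
            raise RuntimeError("No tokenizer model file found in the checkpoint.")
        with archive.extractfile(member) as infile, out_path.open("wb") as outfile:
            outfile.write(infile.read())
    # sentencepiece requires a real filesystem path, hence the extraction step.
    return sentencepiece.SentencePieceProcessor(model_file=str(out_path))
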

@@ -0,0 +1,135 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Union

import torch
import yaml
from transformers import AutoConfig
from vllm.config import ModelConfig, _get_and_verify_dtype, _get_and_verify_max_len
from vllm.transformers_utils.config import get_hf_text_config

from nemo.export.tarutils import TarPath
from nemo.export.vllm.model_converters import get_model_converter


class NemoModelConfig(ModelConfig):
    """
    This class pretends to be a vllm.config.ModelConfig (with extra fields) but skips
    some of its initialization code, and initializes the configuration from a Nemo checkpoint instead.
    """

    def __init__(
        self,
        nemo_checkpoint: str,
        model_dir: str,
        model_type: str,
        tokenizer_mode: str,
        dtype: Union[str, torch.dtype],
        seed: int,
        revision: Optional[str] = None,
        code_revision: Optional[str] = None,
        rope_scaling: Optional[dict] = None,
        rope_theta: Optional[float] = None,
        tokenizer_revision: Optional[str] = None,
        max_model_len: Optional[int] = None,
        quantization: Optional[str] = None,
        quantization_param_path: Optional[str] = None,
        enforce_eager: bool = False,
        max_seq_len_to_capture: Optional[int] = None,
        max_logprobs: int = 5,
        disable_sliding_window: bool = False,
    ) -> None:
        # Don't call ModelConfig.__init__ because we don't want it to call
        # transformers.AutoConfig.from_pretrained(...)

        # TODO: Do something about vLLM's call to _load_generation_config_dict in LLMEngine.__init__
        # because it calls transformers.GenerationConfig.from_pretrained(...), which tries to download things

        self.nemo_checkpoint = nemo_checkpoint
        self.model = model_dir
        self.model_type = model_type
        self.tokenizer = None
        self.tokenizer_mode = tokenizer_mode
        self.skip_tokenizer_init = False
        self.trust_remote_code = False
        self.seed = seed
        self.revision = revision
        self.code_revision = code_revision
        self.rope_scaling = rope_scaling
        self.rope_theta = rope_theta
        self.tokenizer_revision = tokenizer_revision
        self.quantization = quantization
        self.quantization_param_path = quantization_param_path
        self.enforce_eager = enforce_eager
        self.max_seq_len_to_capture = max_seq_len_to_capture
        self.max_logprobs = max_logprobs
        self.disable_sliding_window = disable_sliding_window
        self.served_model_name = nemo_checkpoint

        self.model_converter = get_model_converter(model_type)
        if self.model_converter is None:
            raise RuntimeError(f'Unknown model type "{model_type}"')

        hf_to_nemo_dict = {
            'hidden_size': 'hidden_size',
            'intermediate_size': 'ffn_hidden_size',
            'num_hidden_layers': 'num_layers',
            'num_attention_heads': 'num_attention_heads',
            'num_key_value_heads': 'num_query_groups',
            # 'hidden_act': 'activation', ## <- vLLM has good defaults for the models, nemo values are wrong
            'max_position_embeddings': ['max_position_embeddings', 'encoder_seq_length'],
            'rms_norm_eps': 'layernorm_epsilon',
            'attention_dropout': 'attention_dropout',
            'initializer_range': 'init_method_std',
            'norm_epsilon': 'layernorm_epsilon',
            'rope_theta': 'rotary_base',
            'use_bias': 'bias',
        }

        with TarPath(nemo_checkpoint) as archive:
            with (archive / "model_config.yaml").open("r") as model_config_file:
                self.nemo_model_config = yaml.load(model_config_file, Loader=yaml.SafeLoader)

                hf_args = {}
                for hf_arg, nemo_arg in hf_to_nemo_dict.items():
                    if not isinstance(nemo_arg, list):
                        nemo_arg = [nemo_arg]

                    for nemo_arg_option in nemo_arg:
                        value = self.nemo_model_config.get(nemo_arg_option)
                        if value is not None:
                            hf_args[hf_arg] = value
                            break

                self.model_converter.convert_config(self.nemo_model_config, hf_args)

                self.hf_config = AutoConfig.for_model(model_type, **hf_args)

        self.hf_config.architectures = [self.model_converter.get_architecture()]
        if self.rope_scaling is not None:
            self.hf_config.rope_scaling = rope_scaling

        self.hf_text_config = get_hf_text_config(self.hf_config)
        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
        self.max_model_len = _get_and_verify_max_len(
            hf_config=self.hf_text_config,
            max_model_len=max_model_len,
            disable_sliding_window=self.disable_sliding_window,
            sliding_window_len=self.get_hf_config_sliding_window(),
        )
        self._verify_tokenizer_mode()
        self._verify_embedding_mode()
        self._verify_quantization()
        self._verify_cuda_graph()
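
Since the constructor signature above appears in full in this diff, a minimal construction sketch would look like the following; the checkpoint path, model directory, and model type are placeholders rather than values from the commit, and the "auto" settings simply defer to vLLM's own verification helpers:

# Illustrative construction only -- paths and model type below are placeholders.
config = NemoModelConfig(
    nemo_checkpoint="/models/llama2-7b.nemo",  # .nemo archive containing model_config.yaml
    model_dir="/tmp/vllm_model_dir",           # directory for the extracted tokenizer and weights
    model_type="llama",                        # must be recognized by get_model_converter()
    tokenizer_mode="auto",
    dtype="auto",                              # resolved by vllm.config._get_and_verify_dtype
    seed=0,
)
print(config.max_model_len, config.hf_config.architectures)
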