Speculative Decoding Interface #1241

Open

joostinyi wants to merge 31 commits into main from jyi/spec-dec-interface
31 commits
3b6f0ce  spec dec config (joostinyi, Nov 5, 2024)
272d561  add optional dict of trt llm configs (joostinyi, Nov 5, 2024)
9bf2f08  Merge branch 'main' into jyi/spec-dec-interface (joostinyi, Nov 5, 2024)
d8376dc  fix bad merge (joostinyi, Nov 6, 2024)
edca16e  add extensions support (joostinyi, Nov 6, 2024)
d093ded  fix fixture (joostinyi, Nov 6, 2024)
f489421  cli push fixes (joostinyi, Nov 6, 2024)
057967b  constants (joostinyi, Nov 6, 2024)
f73d67a  fix ordering (joostinyi, Nov 6, 2024)
0854bcc  Merge branch 'main' into jyi/spec-dec-interface (joostinyi, Nov 7, 2024)
d208245  fix merge (joostinyi, Nov 7, 2024)
41e9e0a  refactor interface (joostinyi, Nov 12, 2024)
adcac8c  add tp validation error (joostinyi, Nov 12, 2024)
50138e9  self review (joostinyi, Nov 12, 2024)
ac9e99f  use constant (joostinyi, Nov 12, 2024)
bb1e93c  fix tests (joostinyi, Nov 13, 2024)
d1907e1  fix tests (joostinyi, Nov 19, 2024)
76fe148  Merge branch 'main' into jyi/spec-dec-interface (joostinyi, Nov 19, 2024)
2354d52  Merge branch 'main' into jyi/spec-dec-interface (joostinyi, Nov 19, 2024)
69f37e7  add request_default_max_tokens (joostinyi, Nov 19, 2024)
6c49529  fix default on trtllm runtime (joostinyi, Nov 20, 2024)
e316609  update copy (joostinyi, Nov 20, 2024)
d4c36a3  Merge branch 'main' into jyi/spec-dec-interface (joostinyi, Nov 22, 2024)
8597eaf  bump to 54rc0 (joostinyi, Nov 22, 2024)
ab0f684  add total token limit to toplevel config (joostinyi, Nov 22, 2024)
3a149d8  bump briton to 0.3.10 (joostinyi, Nov 22, 2024)
8b69330  fix import (joostinyi, Nov 25, 2024)
cd6bf08  54rc2 (joostinyi, Nov 25, 2024)
11c38ab  fix rc3 (joostinyi, Nov 25, 2024)
3641e68  rc4 (joostinyi, Nov 25, 2024)
567a644  bump briton server image (joostinyi, Nov 26, 2024)
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "truss"
version = "0.9.53"
version = "0.9.54rc5"
description = "A seamless bridge from model development to model delivery"
license = "MIT"
readme = "README.md"
6 changes: 4 additions & 2 deletions truss/base/constants.py
@@ -105,9 +105,11 @@

REGISTRY_BUILD_SECRET_PREFIX = "DOCKER_REGISTRY_"

TRTLLM_BASE_IMAGE = "baseten/briton-server:v0.13.0_v0.0.17"
TRTLLM_SPEC_DEC_TARGET_MODEL_NAME = "target"
TRTLLM_SPEC_DEC_DRAFT_MODEL_NAME = "draft"
TRTLLM_BASE_IMAGE = "baseten/briton-server:v0.13.0-4fd8a10-5e5c3d7"
TRTLLM_PYTHON_EXECUTABLE = "/usr/bin/python3"
BASE_TRTLLM_REQUIREMENTS = ["briton==0.3.9"]
BASE_TRTLLM_REQUIREMENTS = ["briton==0.3.10"]
AUDIO_MODEL_TRTLLM_REQUIREMENTS = [
"--extra-index-url https://pypi.nvidia.com",
"tensorrt_cu12_bindings==10.2.0.post1",
84 changes: 54 additions & 30 deletions truss/base/trt_llm_config.py
@@ -55,6 +55,10 @@ class TrussTRTLLMBatchSchedulerPolicy(str, Enum):
GUARANTEED_NO_EVICT = "guaranteed_no_evict"


class TrussSpecDecMode(str, Enum):
DRAFT_EXTERNAL: str = "DRAFT_TOKENS_EXTERNAL"


class TrussTRTLLMBuildConfiguration(BaseModel):
base_model: TrussTRTLLMModel
max_seq_len: int
@@ -73,13 +77,9 @@ class TrussTRTLLMBuildConfiguration(BaseModel):
plugin_configuration: TrussTRTLLMPluginConfiguration = (
TrussTRTLLMPluginConfiguration()
)
kv_cache_free_gpu_mem_fraction: float = 0.9
num_builder_gpus: Optional[int] = None
enable_chunked_context: bool = False
batch_scheduler_policy: TrussTRTLLMBatchSchedulerPolicy = (
TrussTRTLLMBatchSchedulerPolicy.GUARANTEED_NO_EVICT
)
default_max_tokens: Optional[int] = None
speculative_decoding_mode: Optional[TrussSpecDecMode] = None
max_draft_len: Optional[int] = None

@validator("max_beam_width")
def check_max_beam_width(cls, v: int):
@@ -91,40 +91,26 @@ def check_max_beam_width(cls, v: int):
return v


class TrussTRTLLMServingConfiguration(BaseModel):
engine_repository: str
tokenizer_repository: str
tensor_parallel_count: int = 1
pipeline_parallel_count: int = 1
class TrussTRTLLMRuntimeConfiguration(BaseModel):
kv_cache_free_gpu_mem_fraction: float = 0.9
enable_chunked_context: bool = False
num_draft_tokens: Optional[int] = None
batch_scheduler_policy: TrussTRTLLMBatchSchedulerPolicy = (
TrussTRTLLMBatchSchedulerPolicy.GUARANTEED_NO_EVICT
)
request_default_max_tokens: Optional[int] = None


class TRTLLMConfiguration(BaseModel):
serve: Optional[TrussTRTLLMServingConfiguration] = None
build: Optional[TrussTRTLLMBuildConfiguration] = None
runtime: TrussTRTLLMRuntimeConfiguration = TrussTRTLLMRuntimeConfiguration()
build: TrussTRTLLMBuildConfiguration

def __init__(self, **data):
super().__init__(**data)
self._validate_minimum_required_configuration()
self._validate_kv_cache_flags()
if self.build.checkpoint_repository.source == CheckpointSource.HF:
self._validate_hf_repo_id()

# In pydantic v2 this would be `@model_validator(mode="after")` and
# the __init__ override can be removed.
def _validate_minimum_required_configuration(self):
if not self.serve and not self.build:
raise ValueError("Either serve or build configurations must be provided")
if self.serve and self.build:
raise ValueError("Both serve and build configurations cannot be provided")
if self.serve is not None:
if (self.serve.engine_repository is None) ^ (
self.serve.tokenizer_repository is None
):
raise ValueError(
"Both engine_repository and tokenizer_repository must be provided"
)
return self

def _validate_kv_cache_flags(self):
if self.build is None:
return self
@@ -160,3 +146,41 @@ def requires_build(self):
# when pydantic v2 is used here
def to_json_dict(self, verbose=True):
return json.loads(self.json(exclude_unset=not verbose))


class TRTLLMSpeculativeDecodingConfiguration(BaseModel):
target: TRTLLMConfiguration
draft: TRTLLMConfiguration
total_token_limit: int = 500000

def __init__(self, **data):
super().__init__(**data)
self._spec_dec_configs = [
self.target.build.speculative_decoding_mode,
self.target.build.max_draft_len,
] + (
[self.draft.runtime.num_draft_tokens]
if self.draft.runtime and self.draft.runtime.num_draft_tokens
else [False]
)
self._validate_spec_dec()

def _validate_spec_dec(self):
if any(self._spec_dec_configs):
if not all(self._spec_dec_configs):
raise ValueError(
"Speculative decoding requires all of `target.build.speculative_decoding_mode`, `target.build.max_draft_len`, and `draft.runtime.num_draft_tokens` to be configured."
)
for trt_llm_config in [self.target, self.draft]:
if trt_llm_config.build.base_model is TrussTRTLLMModel.WHISPER:
raise ValueError("Speculative decoding for Whisper is not supported.")
if (
self.target.build.tensor_parallel_count
!= self.draft.build.tensor_parallel_count
):
raise ValueError(
"Speculative decoding requires the same tensor parallelism for target and draft models."
)

def to_json_dict(self, verbose=True):
return json.loads(self.json(exclude_unset=not verbose))
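To make the new interface concrete, below is a minimal, hypothetical sketch of constructing the class added above. Only fields visible in this diff are relied on; the `checkpoint_repository` keys (`source`, `repo`), the `"llama"` base-model value, and the repository names are illustrative assumptions rather than something this PR prescribes.

```python
# Hypothetical usage sketch; spellings outside this diff are assumptions.
from truss.base.trt_llm_config import TRTLLMSpeculativeDecodingConfiguration

spec_dec = TRTLLMSpeculativeDecodingConfiguration(
    target={
        "build": {
            "base_model": "llama",  # assumed TrussTRTLLMModel value
            "max_seq_len": 8192,
            "checkpoint_repository": {"source": "HF", "repo": "org/target-model"},
            # Must be set together with draft.runtime.num_draft_tokens,
            # otherwise _validate_spec_dec raises a ValueError.
            "speculative_decoding_mode": "DRAFT_TOKENS_EXTERNAL",
            "max_draft_len": 10,
        }
    },
    draft={
        "build": {
            "base_model": "llama",
            "max_seq_len": 8192,
            "checkpoint_repository": {"source": "HF", "repo": "org/draft-model"},
        },
        "runtime": {"num_draft_tokens": 4},
    },
    total_token_limit=500_000,
)
```

Omitting any one of the three speculative-decoding fields, using a Whisper base model, or giving target and draft different `tensor_parallel_count` values trips the corresponding `ValueError` in `_validate_spec_dec`.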
46 changes: 35 additions & 11 deletions truss/base/truss_config.py
@@ -3,14 +3,21 @@
from dataclasses import _MISSING_TYPE, dataclass, field, fields
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, TypeVar
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union

import yaml

from truss.base.constants import HTTP_PUBLIC_BLOB_BACKEND
from truss.base.constants import (
HTTP_PUBLIC_BLOB_BACKEND,
TRTLLM_SPEC_DEC_TARGET_MODEL_NAME,
)
from truss.base.custom_types import ModelFrameworkType
from truss.base.errors import ValidationError
from truss.base.trt_llm_config import TRTLLMConfiguration, TrussTRTLLMQuantizationType
from truss.base.trt_llm_config import (
TRTLLMConfiguration,
TRTLLMSpeculativeDecodingConfiguration,
TrussTRTLLMQuantizationType,
)
from truss.base.validation import (
validate_cpu_spec,
validate_memory_spec,
@@ -558,7 +565,9 @@ class TrussConfig:
base_image: Optional[BaseImage] = None
docker_server: Optional[DockerServer] = None
model_cache: ModelCache = field(default_factory=ModelCache)
trt_llm: Optional[TRTLLMConfiguration] = None
trt_llm: Optional[
Union[TRTLLMConfiguration, TRTLLMSpeculativeDecodingConfiguration]
] = None
build_commands: List[str] = field(default_factory=list)
use_local_chains_src: bool = False

@@ -571,6 +580,14 @@ def canonical_python_version(self) -> str:
"py38": "3.8",
}[self.python_version]

@property
def parsed_trt_llm_configs(self) -> List[TRTLLMConfiguration]:
if self.trt_llm:
if isinstance(self.trt_llm, TRTLLMSpeculativeDecodingConfiguration):
return [self.trt_llm.target, self.trt_llm.draft]
return [self.trt_llm]
return []

@staticmethod
def from_dict(d):
config = TrussConfig(
@@ -617,7 +634,10 @@ def from_dict(d):
ModelCache.from_list,
),
trt_llm=transform_optional(
d.get("trt_llm"), lambda x: TRTLLMConfiguration(**x)
d.get("trt_llm"),
lambda x: (TRTLLMConfiguration(**x))
if TRTLLM_SPEC_DEC_TARGET_MODEL_NAME not in d.get("trt_llm")
else (TRTLLMSpeculativeDecodingConfiguration(**x)),
),
build_commands=d.get("build_commands", []),
use_local_chains_src=d.get("use_local_chains_src", False),
@@ -670,17 +690,17 @@ def to_dict(self, verbose: bool = True):
def clone(self):
return TrussConfig.from_dict(self.to_dict())

def _validate_accelerator_for_trt_llm_builder(self) -> None:
if self.trt_llm and self.trt_llm.build:
def _validate_trt_llm_config(self) -> None:
for trt_llm_config in self.parsed_trt_llm_configs:
if (
self.trt_llm.build.quantization_type
trt_llm_config.build.quantization_type
is TrussTRTLLMQuantizationType.WEIGHTS_ONLY_INT8
and self.resources.accelerator.accelerator is Accelerator.A100
):
raise ValueError(
"Weight only int8 quantization on A100 accelerators is not currently supported"
)
elif self.trt_llm.build.quantization_type in [
elif trt_llm_config.build.quantization_type in [
TrussTRTLLMQuantizationType.FP8,
TrussTRTLLMQuantizationType.FP8_KV,
] and self.resources.accelerator.accelerator not in [
@@ -691,7 +711,7 @@ def _validate_accelerator_for_trt_llm_builder(self) -> None:
raise ValueError(
"FP8 quantization is only supported on L4 and H100 accelerators"
)
tensor_parallel_count = self.trt_llm.build.tensor_parallel_count
tensor_parallel_count = trt_llm_config.build.tensor_parallel_count

if tensor_parallel_count != self.resources.accelerator.count:
raise ValueError(
@@ -720,7 +740,7 @@ def validate(self):
raise ValueError(
"Please ensure that only one of `requirements` and `requirements_file` is specified"
)
self._validate_accelerator_for_trt_llm_builder()
self._validate_trt_llm_config()


def _handle_env_vars(env_vars: Dict[str, Any]) -> Dict[str, str]:
@@ -796,6 +816,10 @@ def obj_to_dict(obj, verbose: bool = False):
d["trt_llm"] = transform_optional(
field_curr_value, lambda data: data.to_json_dict(verbose=verbose)
)
elif isinstance(field_curr_value, TRTLLMSpeculativeDecodingConfiguration):
d["trt_llm"] = transform_optional(
field_curr_value, lambda data: data.to_json_dict(verbose=verbose)
)
elif isinstance(field_curr_value, BaseImage):
d["base_image"] = transform_optional(
field_curr_value, lambda data: data.to_dict()
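As a usage sketch for the dispatch added here: `from_dict` treats a `trt_llm` mapping as speculative decoding when it contains the `"target"` key (the `TRTLLM_SPEC_DEC_TARGET_MODEL_NAME` constant) and otherwise falls back to the single-model `TRTLLMConfiguration`. The helper name below is ours, not part of the PR.

```python
# Sketch of the key-based dispatch mirrored from TrussConfig.from_dict.
from truss.base.constants import TRTLLM_SPEC_DEC_TARGET_MODEL_NAME


def is_spec_dec_trt_llm(trt_llm: dict) -> bool:
    # Presence of "target" selects TRTLLMSpeculativeDecodingConfiguration,
    # otherwise the mapping is parsed as a plain TRTLLMConfiguration.
    return TRTLLM_SPEC_DEC_TARGET_MODEL_NAME in trt_llm


assert is_spec_dec_trt_llm({"target": {}, "draft": {}})
assert not is_spec_dec_trt_llm({"build": {}})
```

Once parsed, `parsed_trt_llm_configs` returns `[target, draft]` for the speculative variant and a single-element list otherwise, which is what lets `_validate_trt_llm_config` here and the FP8 warning loop in the CLI treat both shapes uniformly.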
38 changes: 19 additions & 19 deletions truss/cli/cli.py
@@ -44,8 +44,8 @@
from truss.remote.baseten.utils.status import get_displayable_status
from truss.remote.remote_factory import USER_TRUSSRC_PATH, RemoteFactory
from truss.trt_llm.config_checks import (
check_and_update_memory_for_trt_llm_builder,
check_secrets_for_trt_llm_builder,
is_missing_secrets_for_trt_llm_builder,
memory_updated_for_trt_llm_builder,
uses_trt_llm_builder,
)
from truss.truss_handle.build import cleanup as _cleanup
Expand Down Expand Up @@ -1150,32 +1150,32 @@ def push(
live_reload_disabled_text = "Development mode is currently not supported for trusses using TRT-LLM build flow, push as a published model using --publish"
console.print(live_reload_disabled_text, style="red")
sys.exit(1)
if not check_secrets_for_trt_llm_builder(tr):
if is_missing_secrets_for_trt_llm_builder(tr):
missing_token_text = (
"`hf_access_token` must be provided in secrets to build a gated model. "
"Please see https://docs.baseten.co/deploy/guides/private-model for configuration instructions."
)
console.print(missing_token_text, style="red")
sys.exit(1)
if not check_and_update_memory_for_trt_llm_builder(tr):
if memory_updated_for_trt_llm_builder(tr):
console.print(
f"Automatically increasing memory for trt-llm builder to {TRTLLM_MIN_MEMORY_REQUEST_GI}Gi."
)
config = tr.spec.config
if (
config.trt_llm.build.quantization_type
in [TrussTRTLLMQuantizationType.FP8, TrussTRTLLMQuantizationType.FP8_KV]
and not config.trt_llm.build.num_builder_gpus
):
fp8_and_num_builder_gpus_text = (
"Warning: build specifies FP8 quantization but does not explicitly specify number of build GPUs. "
"GPU memory required at build time may be significantly more than that required at inference time due to FP8 quantization, which can result in OOM failures during the engine build phase."
"`num_builder_gpus` can be used to specify the number of GPUs to use at build time."
)
console.print(
fp8_and_num_builder_gpus_text,
style="yellow",
)
for trt_llm_config in tr.spec.config.parsed_trt_llm_configs:
if (
trt_llm_config.build.quantization_type
in [TrussTRTLLMQuantizationType.FP8, TrussTRTLLMQuantizationType.FP8_KV]
and not trt_llm_config.build.num_builder_gpus
):
fp8_and_num_builder_gpus_text = (
"Warning: build specifies FP8 quantization but does not explicitly specify number of build GPUs. "
"GPU memory required at build time may be significantly more than that required at inference time due to FP8 quantization, which can result in OOM failures during the engine build phase."
"`num_builder_gpus` can be used to specify the number of GPUs to use at build time."
)
console.print(
fp8_and_num_builder_gpus_text,
style="yellow",
)

# TODO(Abu): This needs to be refactored to be more generic
service = remote_provider.push(