Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add --backend TRT_LLM_BUILDER option to truss init #1067

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 33 additions & 2 deletions truss/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,17 @@

import yaml

from truss.config.trt_llm import (
CheckpointRepository,
CheckpointSource,
TRTLLMConfiguration,
TrussTRTLLMBuildConfiguration,
)
from truss.constants import CONFIG_FILE, TEMPLATES_DIR, TRUSS
from truss.docker import kill_containers
from truss.model_inference import infer_python_version, map_to_supported_python_version
from truss.notebook import is_notebook_or_ipython
from truss.truss_config import Build, TrussConfig
from truss.truss_config import Accelerator, AcceleratorSpec, Build, TrussConfig
from truss.truss_handle import TrussHandle
from truss.util.path import build_truss_target_directory, copy_tree_path

Expand Down Expand Up @@ -54,6 +60,24 @@ def populate_target_directory(
return target_directory_path_typed


def set_trtllm_engine_builder_config(config):
    """Populate *config* with sensible defaults for the TRT-LLM engine builder.

    Mutates the given TrussConfig in place:
    * pins the resources to a single A10G GPU, and
    * attaches a default TRT-LLM build configuration (llama base model,
      1024-token input/output windows, batch size 1, beam width 1) pointing
      at an empty Hugging Face checkpoint repository for the user to fill in.
    """
    # Engine building requires a GPU; default to one A10G.
    gpu_spec = AcceleratorSpec(accelerator=Accelerator("A10G"), count=1)
    config.resources.accelerator = gpu_spec
    config.resources.use_gpu = True

    # Default build settings; the empty `repo` is a placeholder the user
    # is expected to replace with their own HF checkpoint repository.
    default_build = TrussTRTLLMBuildConfiguration(
        base_model="llama",
        max_input_len=1024,
        max_output_len=1024,
        max_batch_size=1,
        max_beam_width=1,
        checkpoint_repository=CheckpointRepository(
            source=CheckpointSource("HF"), repo=""
        ),
    )
    config.trt_llm = TRTLLMConfiguration(build=default_build)


def init(
target_directory: str,
data_files: Optional[List[str]] = None,
Expand All @@ -77,12 +101,19 @@ def init(
python_version=map_to_supported_python_version(infer_python_version()),
)

if build_config:
if build_config and build_config.model_server.value != "TRT_LLM_BUILDER":
config.build = build_config

if build_config.model_server.value == "TRT_LLM_BUILDER":
template = "trtllm-engine-builder"
set_trtllm_engine_builder_config(config)
else:
template = "custom"

target_directory_path = populate_target_directory(
config=config,
target_directory_path=target_directory,
template=template,
populate_dirs=True,
)

Expand Down
4 changes: 3 additions & 1 deletion truss/config/trt_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,9 @@ def _validate_minimum_required_configuration(self):
if not self.serve and not self.build:
raise ValueError("Either serve or build configurations must be provided")
if self.serve and self.build:
raise ValueError("Both serve and build configurations cannot be provided")
raise ValueError(
"One of serve XOR build configurations must be provided, not both"
)
if self.serve is not None:
if (self.serve.engine_repository is None) ^ (
self.serve.tokenizer_repository is None
Expand Down
Empty file.
30 changes: 30 additions & 0 deletions truss/templates/trtllm-engine-builder/model/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""
The `Model` class allows you to customize the behavior of your TensorRT-LLM engine.

The main methods to implement here are:
* `load`: runs exactly once when the model server is spun up or patched and loads the
model onto the model server. Include any logic for initializing your model server.
* `predict`: runs every time the model server is called. Include any logic for model
inference and return the model output.

See https://docs.baseten.co/performance/engine-builder-customization for more.
"""


class Model:
    """Thin wrapper exposing the TensorRT-LLM engine to the model server."""

    def __init__(self, trt_llm, **kwargs):
        # The runtime hands us a dict whose "engine" entry is the built
        # TensorRT-LLM engine; keep a handle to it for inference.
        #
        # Other parts of the Truss context are available via kwargs if
        # needed, e.g.:
        #   kwargs["data_dir"], kwargs["config"], kwargs["secrets"]
        self._engine = trt_llm["engine"]

    def load(self):
        # Nothing extra to initialize — the engine is ready at construction.
        pass

    async def predict(self, model_input):
        # Delegate inference directly to the underlying engine.
        return await self._engine.predict(model_input)
1 change: 1 addition & 0 deletions truss/truss_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ class ModelServer(Enum):

TrussServer = "TrussServer"
TRT_LLM = "TRT_LLM"
TRT_LLM_BUILDER = "TRT_LLM_BUILDER"


@dataclass
Expand Down
Loading