From e97035da3c7deecd2cedf25d98c215d7f2e2c5f1 Mon Sep 17 00:00:00 2001
From: hamelsmu
Date: Mon, 27 Nov 2023 18:22:40 +0000
Subject: [PATCH] save bench

---
 trt-bench/README.md              | 195 +++++++++++++++++++++++++++++++
 trt-bench/requests_bench.py      |  70 +++++++++++
 trt-bench/setup.sh               |  20 ++++
 trt-bench/throughput-bench.ipynb | 159 +++++++++++++++++++++++++
 4 files changed, 444 insertions(+)
 create mode 100644 trt-bench/README.md
 create mode 100644 trt-bench/requests_bench.py
 create mode 100755 trt-bench/setup.sh
 create mode 100644 trt-bench/throughput-bench.ipynb

diff --git a/trt-bench/README.md b/trt-bench/README.md
new file mode 100644
index 0000000..b4e5f2a
--- /dev/null
+++ b/trt-bench/README.md
@@ -0,0 +1,195 @@
+# Nvidia Triton w/ TensorRT-LLM Backend
+
+Use the [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/tree/main) backend with the [Nvidia Triton Inference Server](https://github.com/triton-inference-server/server).
+
+The clearest end-to-end instructions I found were in [this official blog post](https://developer.nvidia.com/blog/optimizing-inference-on-llms-with-tensorrt-llm-now-publicly-available/).
+
+## Build TensorRT-LLM container
+
+Follow [these instructions](https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.5.0/docs/source/installation.md) to build the Docker container used to compile the model.
+
+When you are done, this will have created a Docker image called `tensorrt_llm/release:latest` locally.
+
+> Note: I had to fight nvidia-docker to get this to work. I ended up uninstalling Docker and everything related to the NVIDIA Container Toolkit and re-installing from scratch.
+
+## Pull the model from HuggingFace
+
+Make a directory called `model_input` and clone the Hugging Face model into it.
+
+```bash
+mkdir model_input
+# Make sure you have git-lfs installed (https://git-lfs.com)
+cd model_input
+git lfs install
+git clone https://huggingface.co/meta-llama/Llama-2-7b-hf
+```
+
+## Compile the model
+
+To compile the model, mount the model you just pulled from HuggingFace and the `model_output` directory into the container, then run the compile script. First, shell into the container like this:
+
+```bash
+# Make an output directory to store the compiled model assets
+mkdir model_output
+
+sudo docker run --gpus all --ulimit memlock=-1 --ipc=host --ulimit stack=67108864 -it -v ${PWD}/model_input:/model_input -v ${PWD}/model_output:/model_output tensorrt_llm/release:latest bash
+```
+
+Install the quantization toolkit per [these instructions](https://github.com/NVIDIA/TensorRT-LLM/tree/release/0.5.0/examples/quantization#tensorrt-llm-quantization-toolkit-installation-guide):
+
+```bash
+cd /app/tensorrt_llm/examples/quantization
+python -m pip install --upgrade pip
+# Obtain the cuda version from the system. Assuming nvcc is available in path.
+cuda_version=$(nvcc --version | grep 'release' | awk '{print $6}' | awk -F'[V.]' '{print $2$3}')
+# Obtain the python version from the system.
+python_version=$(python3 --version 2>&1 | awk '{print $2}' | awk -F. '{print $1$2}')
+# Download and install the AMMO package from the DevZone.
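+# (For reference, and assuming nvcc and python3 are on your PATH: with CUDA 12.1 and
+# Python 3.10 the two variables above expand to cuda_version=121 and python_version=310,
+# and they must match the AMMO wheel filename used in the pip install below.)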
+wget https://developer.nvidia.com/downloads/assets/cuda/files/nvidia-ammo/nvidia_ammo-0.3.0.tar.gz
+tar -xzf nvidia_ammo-0.3.0.tar.gz
+pip install nvidia_ammo-0.3.0/nvidia_ammo-0.3.0+cu$cuda_version-cp$python_version-cp$python_version-linux_x86_64.whl
+# Install the additional requirements
+pip install -r requirements.txt
+```
+
+Then quantize the model. This took < 10 minutes on my RTX 6000 Ada (so be patient):
+
+```bash
+# Quantize the HF LLaMA checkpoints into INT4 AWQ format
+cd /app/tensorrt_llm/examples/llama
+for sz in 7 13 70; do
+  python quantize.py --model_dir /model_input/Llama-2-${sz}b-chat-hf/ \
+    --dtype float16 \
+    --qformat int4_awq \
+    --export_path ./llama-${sz}b-4bit-gs128-awq.pt \
+    --calib_size 32
+done
+```
+
+Then, run the compile script. Make sure your GPU memory is free when you do this:
+
+```bash
+cd /app/tensorrt_llm/examples/llama
+# Compile the LLaMA models to TensorRT format
+for sz in 7 13 70; do
+  python build.py --model_dir /model_input/Llama-2-${sz}b-chat-hf/ \
+    --quant_ckpt_path ./llama-${sz}b-4bit-gs128-awq.pt \
+    --dtype float16 \
+    --use_gpt_attention_plugin float16 \
+    --use_gemm_plugin float16 \
+    --remove_input_padding \
+    --use_inflight_batching \
+    --paged_kv_cache \
+    --use_weight_only \
+    --weight_only_precision int4_awq \
+    --max_batch_size 256 \
+    --per_group \
+    --output_dir /model_output/${sz}b \
+    --world_size 4 \
+    --tp_size 4
+done
+```
+
+When you are done, exit the Docker container. The compiled assets will be located under `model_output/` (one subdirectory per model size). You will see three files:
+
+- `llama_float16_tp1_rank0.engine`: The main output of the build script, containing the executable graph of operations with the model weights embedded.
+- `config.json`: Includes detailed information about the model, like its general structure and precision, as well as information about which plug-ins were incorporated into the engine.
+- `model.cache`: Caches some of the timing and optimization information from model compilation, making successive builds quicker.
+
+## Prepare the model repository
+
+The Triton Inference Server works with model repositories: specific directory structures with config files and other assets. You can read about model repositories [here](https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/user_guide/model_repository.html). The model repository for this example is quite complicated and involves setting up an ensemble of preprocessing, model, and postprocessing components, along with lots of boilerplate code.
+
+The easiest way to get started is to clone the example repo and modify it to suit your needs. First, clone the repo:
+
+```bash
+git clone -b release/0.5.0 https://github.com/triton-inference-server/tensorrtllm_backend.git
+```
+
+Copy the compiled model assets (here, the 7B engine) from `./model_output` into the example model repository:
+
+```bash
+cp model_output/7b/* tensorrtllm_backend/all_models/inflight_batcher_llm/tensorrt_llm/1/
+```
+
+Then use their tools to modify the configuration files of all three components of the ensemble.
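+
+For orientation, the relevant part of that repository looks roughly like this (a sketch based on the paths used in the commands below; exact contents vary by release):
+
+```
+all_models/inflight_batcher_llm/
+├── ensemble/         # wires the three steps together; "ensemble" is the model name used in the curl example below
+│   └── config.pbtxt
+├── preprocessing/    # tokenizes incoming text
+│   └── config.pbtxt
+├── tensorrt_llm/     # config.pbtxt plus 1/, where the compiled engine files go
+│   ├── 1/
+│   └── config.pbtxt
+└── postprocessing/   # detokenizes generated ids
+    └── config.pbtxt
+```
+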
+Make sure you run these commands in the `tensorrtllm_backend` directory:
+
+```bash
+cd tensorrtllm_backend
+# modify config for the model
+python3 tools/fill_template.py --in_place \
+    all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt \
+    decoupled_mode:true,engine_dir:/all_models/inflight_batcher_llm/tensorrt_llm/1,\
+max_tokens_in_paged_kv_cache:,batch_scheduler_policy:guaranteed_completion,kv_cache_free_gpu_mem_fraction:0.2,\
+max_num_sequences:4
+```
+
+Next, modify the config for the preprocessing component: set `tokenizer_dir` to the Hugging Face Hub model you used. I am using `NousResearch/Llama-2-7b-hf`, which is a replica of `meta-llama/Llama-2-7b-hf`, so we don't have to worry about the fiddly permissions on the original model.
+
+```bash
+# modify config for the preprocessing component
+python tools/fill_template.py --in_place \
+    all_models/inflight_batcher_llm/preprocessing/config.pbtxt \
+    tokenizer_type:llama,tokenizer_dir:NousResearch/Llama-2-7b-hf
+
+# modify config for the postprocessing component
+python tools/fill_template.py --in_place \
+    all_models/inflight_batcher_llm/postprocessing/config.pbtxt \
+    tokenizer_type:llama,tokenizer_dir:NousResearch/Llama-2-7b-hf
+```
+
+## Prepare the Triton Server
+
+Next, we have to mount the model repository we just created into the Triton server and do some additional work interactively before it is ready. Make sure you are in the `tensorrtllm_backend` directory when running the following commands, because we also need to mount the `scripts` directory into the container.
+
+```bash
+sudo docker run -it --rm --gpus all --network host --shm-size=1g \
+-v $(pwd)/all_models:/all_models \
+-v $(pwd)/scripts:/opt/scripts \
+nvcr.io/nvidia/tritonserver:23.10-trtllm-python-py3 bash
+```
+
+Next, in the Docker container, log in to the Hugging Face Hub with your access token:
+
+```bash
+huggingface-cli login --token <your-hf-token>
+```
+
+Then, install the Python dependencies:
+
+```bash
+# Install python dependencies
+pip install sentencepiece protobuf
+```
+
+Finally, start the Triton server:
+
+```bash
+# Launch Server
+python /opt/scripts/launch_triton_server.py --world_size 1 --model_repo /all_models/inflight_batcher_llm
+```
+
+> Note: if you get the error `Unexpected tokenizer type: ${tokenizer_type}`, it means you didn't run the `fill_template.py` script on the preprocessing and postprocessing config files correctly.
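+>
+> A quick way to check for this, assuming the model repository is mounted at `/all_models` as in the `docker run` command above:
+>
+> ```bash
+> # No output here is good news -- it means the ${tokenizer_type}/${tokenizer_dir}
+> # placeholders were replaced by fill_template.py.
+> grep -Fn '${tokenizer' \
+>     /all_models/inflight_batcher_llm/preprocessing/config.pbtxt \
+>     /all_models/inflight_batcher_llm/postprocessing/config.pbtxt
+> ```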
+ +You will get output that looks like this: + +```bash +I1101 14:59:56.742506 113 grpc_server.cc:2513] Started GRPCInferenceService at 0.0.0.0:8001 +I1101 14:59:56.742703 113 http_server.cc:4497] Started HTTPService at 0.0.0.0:8000 +I1101 14:59:56.828990 113 http_server.cc:270] Started Metrics Service at 0.0.0.0:8002 +``` + +### Test the server + +You can make a request with `curl` like this: + +```bash +curl -X POST localhost:8000/v2/models/ensemble/generate -d \ +'{"text_input": "How do I count to nine in French?", +"parameters": {"max_tokens": 100, "bad_words":[""],"stop_words":[""]}}' +``` diff --git a/trt-bench/requests_bench.py b/trt-bench/requests_bench.py new file mode 100644 index 0000000..9291d9f --- /dev/null +++ b/trt-bench/requests_bench.py @@ -0,0 +1,70 @@ +import asyncio +import time +import aiohttp +import statistics + +# Shared concurrency counter +current_concurrency = 0 + +async def send_request(session, url, data, request_number, response_record): + global current_concurrency + print(f"Starting request #{request_number}") + current_concurrency += 1 # Increment concurrency when request starts + start_time = time.perf_counter() + + async with session.post(url, json=data) as response: + await response.read() + + end_time = time.perf_counter() + latency = end_time - start_time + response_record.append((current_concurrency, latency)) + print(f"Finished request #{request_number}") + current_concurrency -= 1 # Decrement concurrency when request ends + +async def main(duration, requests_per_second, output_seq_len): + url = 'http://localhost:8000/v2/models/ensemble/generate' + data = { + "text_input": "How do I count to ten in French?", + "parameters": { + "max_tokens": output_seq_len, + "min_length": output_seq_len, + "bad_words": [""], + "stop_words": [""], + # "stream": True + } + } + + tasks = [] + response_record = [] + request_counter = 0 + + async with aiohttp.ClientSession() as session: + start_time = time.perf_counter() + while time.perf_counter() - start_time < duration: + request_counter += 1 + task = asyncio.create_task(send_request(session, url, data, request_counter, response_record)) + tasks.append(task) + await asyncio.sleep(1 / requests_per_second) + print(f"Current concurrency: {current_concurrency}") + + await asyncio.gather(*tasks) + + # Statistics + latencies = [item[1] for item in response_record] + average_latency = statistics.mean(latencies) + max_latency = max(latencies) + min_latency = min(latencies) + std_dev_latency = statistics.stdev(latencies) + + print(f"Average Latency: {average_latency:.4f} seconds") + print(f"Max Latency: {max_latency:.4f} seconds") + print(f"Min Latency: {min_latency:.4f} seconds") + print(f"Standard Deviation of Latency: {std_dev_latency:.4f} seconds") + + + +if __name__ == "__main__": + duration = 60 # Duration in seconds + requests_per_second = .3 # Requests per second + output_seq_len = 300 + asyncio.run(main(duration, requests_per_second, output_seq_len)) \ No newline at end of file diff --git a/trt-bench/setup.sh b/trt-bench/setup.sh new file mode 100755 index 0000000..206ed02 --- /dev/null +++ b/trt-bench/setup.sh @@ -0,0 +1,20 @@ +#!/bin/bash +# TensorRT-LLM uses git-lfs, which needs to be installed in advance. 
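+# What follows: clone TensorRT-LLM (with submodules and LFS assets), build its release
+# Docker image, then pull the Llama-2-70B chat weights from Hugging Face into model_input/.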
+sudo apt-get update && sudo apt-get -y install git git-lfs + +git clone https://github.com/NVIDIA/TensorRT-LLM.git +cd TensorRT-LLM +git submodule update --init --recursive +git lfs install +git lfs pull + + +# See https://developer.nvidia.com/cuda-gpus#compute to find out which version +# I'm using a A100 for this particular setup so that is `80-real` +make -C docker release_build CUDA_ARCHS="80-real" + +cd .. +mkdir model_input +# Make sure you have git-lfs installed (https://git-lfs.com) +cd model_input +git clone https://huggingface.co/NousResearch/Llama-2-70b-chat-hf diff --git a/trt-bench/throughput-bench.ipynb b/trt-bench/throughput-bench.ipynb new file mode 100644 index 0000000..05e0d21 --- /dev/null +++ b/trt-bench/throughput-bench.ipynb @@ -0,0 +1,159 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "5a3bc225-c642-4c8d-b244-1a23bf6aa39d", + "metadata": {}, + "outputs": [], + "source": [ + "import requests, time\n", + "import threading\n", + "from transformers import AutoTokenizer\n", + "tokenizer = AutoTokenizer.from_pretrained(\"NousResearch/Llama-2-7b-hf\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "1cc9829d-ced8-46e0-8ece-9e738625d928", + "metadata": {}, + "outputs": [], + "source": [ + "def send_request(i):\n", + " global out\n", + " url = 'http://localhost:8000/v2/models/ensemble/generate'\n", + " data = {\n", + " \"text_input\": \"How do I count to nine in French?\",\n", + " \"parameters\": {\n", + " \"max_tokens\": 500,\n", + " \"bad_words\": [\"\"],\n", + " \"stop_words\": [\"\"],\n", + " \"temperature\": 0,\n", + " }\n", + " }\n", + " response = requests.post(url, json=data)\n", + " out[i] = response.json()['text_output']\n", + "\n", + "def concurrent_test(n_threads):\n", + " global out\n", + " out = [None] * n_threads # pre allocate a list\n", + " threads = []\n", + " for index in range(n_threads):\n", + " x = threading.Thread(target=send_request, args=(index,))\n", + " threads.append(x)\n", + " \n", + " start = time.perf_counter()\n", + " for t in threads: t.start()\n", + " for t in threads: t.join()\n", + " request_time = time.perf_counter() - start\n", + " toks = sum([len(tokenizer.encode(o)) for o in out])\n", + " return toks / request_time" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "a1d31835-a86f-467b-a353-9fd031907b4a", + "metadata": {}, + "outputs": [], + "source": [ + "def measure(bs, n_times=3):\n", + " import numpy as np\n", + " m = [concurrent_test(bs) for _ in range(n_times)]\n", + " avg_toksec = np.mean(m)\n", + " avg_toksec_per_thread = avg_toksec / bs\n", + " print(f'\\n\\nConcurrent Requests={bs} (averaged over {n_times} separate experiments)\\n==============================\\ntok/sec total: {avg_toksec:.1f}\\ntok/sec per thread: {avg_toksec_per_thread:.1f}')" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "048d94ea-de7d-4f0c-9101-1b89bdb7118c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "Concurrent Requests=1 (averaged over 3 separate experiments)\n", + "==============================\n", + "tok/sec total: 185.3\n", + "tok/sec per thread: 185.3\n", + "\n", + "\n", + "Concurrent Requests=2 (averaged over 3 separate experiments)\n", + "==============================\n", + "tok/sec total: 355.7\n", + "tok/sec per thread: 177.9\n", + "\n", + "\n", + "Concurrent Requests=4 (averaged over 3 separate experiments)\n", + "==============================\n", + "tok/sec total: 638.5\n", + 
"tok/sec per thread: 159.6\n", + "\n", + "\n", + "Concurrent Requests=8 (averaged over 3 separate experiments)\n", + "==============================\n", + "tok/sec total: 958.1\n", + "tok/sec per thread: 119.8\n", + "\n", + "\n", + "Concurrent Requests=16 (averaged over 3 separate experiments)\n", + "==============================\n", + "tok/sec total: 961.5\n", + "tok/sec per thread: 60.1\n", + "\n", + "\n", + "Concurrent Requests=32 (averaged over 3 separate experiments)\n", + "==============================\n", + "tok/sec total: 962.1\n", + "tok/sec per thread: 30.1\n", + "\n", + "\n", + "Concurrent Requests=64 (averaged over 3 separate experiments)\n", + "==============================\n", + "tok/sec total: 963.0\n", + "tok/sec per thread: 15.0\n" + ] + } + ], + "source": [ + "for bs in [1,2,4,8,16,32,64]:\n", + " measure(bs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "425dd446-d0e8-4ef6-a21b-07aedfe0a8da", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}