Commit

Merge branch 'main' into ci-shortfin-matrix

marbre authored Nov 25, 2024
2 parents 6a3afc6 + e906b66 commit 5be4210

Showing 9 changed files with 175 additions and 33 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci-sglang-benchmark.yml
@@ -28,7 +28,7 @@ jobs:
matrix:
version: [3.11]
fail-fast: false
runs-on: llama-mi300x-3
runs-on: mi300x-4
defaults:
run:
shell: bash
@@ -78,7 +78,7 @@ jobs:
run: pip install "git+https://github.com/nod-ai/sglang.git#subdirectory=python"

- name: Launch Shortfin Server
run: pytest -v app_tests/benchmark_tests/llm/sglang_benchmark_test.py --log-cli-level=INFO --html=out/llm/sglang/index.html
run: pytest -v app_tests/benchmark_tests/llm/sglang_benchmarks/sglang_benchmark_test.py --log-cli-level=INFO --html=out/llm/sglang/index.html

- name: Deploy to GitHub Pages
uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0
2 changes: 1 addition & 1 deletion .github/workflows/ci-sglang-integration-tests.yml
@@ -29,7 +29,7 @@ jobs:
matrix:
version: [3.11]
fail-fast: false
runs-on: llama-mi300x-3
runs-on: mi300x-4
defaults:
run:
shell: bash
2 changes: 1 addition & 1 deletion .github/workflows/ci-shark-ai.yml
@@ -49,7 +49,7 @@ jobs:
id: cache-pip
with:
path: ${{ env.PIP_CACHE_DIR }}
key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements.txt') }}
key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements*.txt','shortfin/requirements*.txt','sharktank/requirements*.txt') }}

- name: Install pip deps
run: |
5 changes: 5 additions & 0 deletions app_tests/benchmark_tests/llm/sglang_benchmarks/__init__.py
@@ -0,0 +1,5 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
@@ -9,22 +9,32 @@
import pytest
import sys

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
from integration_tests.llm.utils import compile_model, export_paged_llm_v1
sys.path.append(
os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
)
from integration_tests.llm.utils import (
compile_model,
export_paged_llm_v1,
download_with_hf_datasets,
)


@pytest.fixture(scope="module")
def pre_process_model(request, tmp_path_factory):
tmp_dir = tmp_path_factory.mktemp("sglang_benchmark_test")

model_path = request.param["model_path"]
model_name = request.param["model_name"]
model_param_file_name = request.param["model_param_file_name"]
settings = request.param["settings"]
batch_sizes = request.param["batch_sizes"]

mlir_path = tmp_dir / "model.mlir"
config_path = tmp_dir / "config.json"
vmfb_path = tmp_dir / "model.vmfb"

model_path = tmp_dir / model_param_file_name
download_with_hf_datasets(tmp_dir, model_name)

export_paged_llm_v1(mlir_path, config_path, model_path, batch_sizes)

config = {
@@ -4,7 +4,6 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import json
import logging
import multiprocessing
import os
@@ -16,14 +15,14 @@
pytest.importorskip("sglang")
from sglang import bench_serving

from utils import SGLangBenchmarkArgs
from .utils import SGLangBenchmarkArgs, log_jsonl_result

from integration_tests.llm.utils import (
find_available_port,
start_llm_server,
)

logger = logging.getLogger("__name__")
logger = logging.getLogger(__name__)

device_settings = {
"device_flags": [
@@ -33,46 +32,40 @@
"device": "hip",
}

# TODO: Download on demand instead of assuming files exist at this path
MODEL_PATH = Path("/data/llama3.1/8b/llama8b_f16.irpa")
TOKENIZER_DIR = Path("/data/llama3.1/8b/")


def log_jsonl_result(file_path):
with open(file_path, "r") as file:
json_string = file.readline().strip()

json_data = json.loads(json_string)
for key, val in json_data.items():
logger.info(f"{key.upper()}: {val}")


@pytest.mark.parametrize(
"request_rate",
[1, 2, 4, 8, 16, 32],
"request_rate,model_param_file_name",
[
(req_rate, "meta-llama-3.1-8b-instruct.f16.gguf")
for req_rate in [1, 2, 4, 8, 16, 32]
],
)
@pytest.mark.parametrize(
"pre_process_model",
[
(
{
"model_path": MODEL_PATH,
"model_name": "llama3_8B_fp16",
"model_param_file_name": "meta-llama-3.1-8b-instruct.f16.gguf",
"settings": device_settings,
"batch_sizes": [1, 4],
}
)
],
indirect=True,
)
def test_sglang_benchmark_server(request_rate, pre_process_model):
def test_sglang_benchmark_server(
request_rate, model_param_file_name, pre_process_model
):
# TODO: Remove when multi-device is fixed
os.environ["ROCR_VISIBLE_DEVICES"] = "1"

tmp_dir = pre_process_model

config_path = tmp_dir / "config.json"
vmfb_path = tmp_dir / "model.vmfb"
tokenizer_path = TOKENIZER_DIR / "tokenizer.json"
tokenizer_path = tmp_dir / "tokenizer.json"
model_path = tmp_dir / model_param_file_name

# Start shortfin llm server
port = find_available_port()
@@ -81,7 +74,7 @@ def test_sglang_benchmark_server(request_rate, pre_process_model):
tokenizer_path,
config_path,
vmfb_path,
MODEL_PATH,
model_path,
device_settings,
timeout=30,
)
@@ -91,7 +84,7 @@ def test_sglang_benchmark_server(request_rate, pre_process_model):
backend="shortfin",
num_prompt=10,
base_url=f"http://localhost:{port}",
tokenizer=TOKENIZER_DIR,
tokenizer=tmp_dir,
request_rate=request_rate,
)
output_file = (
@@ -116,7 +109,7 @@ def test_sglang_benchmark_server(request_rate, pre_process_model):
logger.info("======== RESULTS ========")
log_jsonl_result(benchmark_args.output_file)
except Exception as e:
logger.info(e)
logger.error(e)

server_process.terminate()
server_process.wait()
@@ -6,8 +6,12 @@

from argparse import Namespace
from dataclasses import dataclass
import json
import logging
from pathlib import Path

logger = logging.getLogger(__name__)


@dataclass
class SGLangBenchmarkArgs:
@@ -54,3 +58,12 @@ def __repr__(self):
f"Tokenizer: {self.tokenizer}\n"
f"Request Rate: {self.request_rate}"
)


def log_jsonl_result(file_path):
with open(file_path, "r") as file:
json_string = file.readline().strip()

json_data = json.loads(json_string)
for key, val in json_data.items():
logger.info(f"{key.upper()}: {val}")
2 changes: 1 addition & 1 deletion app_tests/integration_tests/llm/utils.py
@@ -15,7 +15,7 @@
import requests
from transformers import AutoTokenizer

logger = logging.getLogger("__name__")
logger = logging.getLogger(__name__)


class AccuracyValidationException(RuntimeError):
123 changes: 122 additions & 1 deletion docs/amdgpu_kernel_optimization_guide.md
@@ -4,7 +4,7 @@ Author: Jakub Kuderski @kuhar

Date: 2024-06-24

Last Update: 2024-08-22
Last Update: 2024-11-22

## Introduction

@@ -293,3 +293,124 @@ forms a *clause* that translates to a single data fabric transaction.
> [!TIP]
> For allocations of 4 GB or less, you can implement predicated loads using the
> `buffer` instructions.

## Data-Parallel Primitives and Warp-level Reduction

For cross-lane data sharing, the most straightforward way is LDS: some lanes
write data to locations in LDS and other lanes read it back. Besides that,
there are several instructions that can be used to share data across lanes
within a wavefront/warp.

Here's a brief introduction to these instructions. Please check out [this
blog](https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/) for
details.
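
For reference, here is a minimal sketch (not part of the original guide; the
register assignments are illustrative) of the LDS round trip that the
cross-lane instructions below avoid:

```nasm
; Illustrative only: exchange one dword between lanes through LDS.
; v0 = this lane's LDS byte address, v1 = the value it publishes,
; v2 = the LDS address of the lane it wants to read from.
ds_write_b32 v0, v1        ; each lane stores its value to LDS
s_waitcnt lgkmcnt(0)       ; wait for the LDS write to complete
ds_read_b32 v3, v2         ; read the value another lane stored
s_waitcnt lgkmcnt(0)       ; wait until the data lands in v3
```
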
### ds_permute/ds_bpermute

The `ds_permute`/`ds_bpermute` instructions use the LDS hardware for data
sharing but don't actually write to an LDS location. However, they still need
an `s_waitcnt` instruction to determine when the data is returned to the
`dest` VGPR.

Example:
```nasm
ds_bpermute_b32 dest, addr, src [offset:addr_offset]
```
### ds_swizzle

Compared to `ds_bpermute`, the `ds_swizzle` instruction doesn't require an
additional VGPR for offset since it's encoded in the instruction.

`ds_swizzle` is likely to require fewer address-generation instructions than
`ds_bpermute`.

The cons are:
1. It only supports limited patterns.
2. Similar to `ds_bpermute`, `s_waitcnt` is required to wait for the `dest` VGPR.

Example:
```nasm
ds_swizzle_b32 dest, src offset:ds_pattern
```
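
As a hypothetical, concrete instance (not from the guide; the offset encoding
follows the bit-mask mode described in the linked blog post), swapping values
between adjacent lanes could look like:

```nasm
; Hypothetical example: and_mask=0x1f, or_mask=0, xor_mask=1, so each lane
; reads from the lane whose id differs in the lowest bit (0<->1, 2<->3, ...).
ds_swizzle_b32 v1, v0 offset:0x041f
s_waitcnt lgkmcnt(0)       ; ds_swizzle still needs a wait on the dest VGPR
```
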

### Data-Parallel Primitives, DPP

DPP is a 32-bit instruction modifier appended to the normal VALU instructions.
It allows VALU instructions to access data in neighboring lanes directly
without going through the LDS hardware, so `s_waitcnt` instructions are
**not required**.

Unfortunately, like `ds_swizzle`, it only supports a limited set of patterns,
and some instructions can't be modified by DPP.

Example:
```nasm
; Normal VALU instruction.
v_add_f32
; Instruction modified by DPP.
v_add_f32_dpp
```
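
As a hypothetical illustration of the modifier syntax (not taken from the
guide), one butterfly step that makes each lane add its quad neighbor's value
could look like:

```nasm
; Hypothetical example: quad_perm:[1,0,3,2] swaps lanes 0<->1 and 2<->3
; within each quad, so every lane adds its neighbor's value to its own.
v_add_f32_dpp v0, v0, v0 quad_perm:[1,0,3,2] row_mask:0xf bank_mask:0xf
```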

It's worth mentioning that DPP has different names and syntaxes on different
architectures:
* CDNA: DPP
* RDNA: DPP8/DPP16

For details, please check the [MI300 ISA Reference
Guide](https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/amd-instinct-mi300-cdna3-instruction-set-architecture.pdf)
and the [RDNA3 ISA Reference
Guide](https://www.amd.com/content/dam/amd/en/documents/radeon-tech-docs/instruction-set-architectures/rdna3-shader-instruction-set-architecture-feb-2023_0.pdf).

### How to use them in MLIR

Each instruction has a corresponding Op in MLIR (except for `ds_permute`, which
is not implemented at the time of writing):
* `ds_bpermute`: `rocdl.ds_bpermute`
* `ds_swizzle`: `rocdl.ds_swizzle`
* DPP: `rocdl.update.dpp`, `amdgpu.dpp` (a thin wrapper around
`rocdl.update.dpp` with a friendlier interface, e.g., enums instead of magic
numbers)

The first two are straightforward, while DPP works differently.

Since DPP is an instruction modifier rather than an instruction itself, there is
a tremendous number of possible combinations of VALU instructions and DPP. To
handle that, `rocdl.update.dpp` and `amdgpu.dpp` are designed as wrappers around
the `v_mov_b32_dpp` instruction, and it is up to the LLVM compiler to fuse it
with the subsequent VALU instruction **on a best-effort basis**.

For example, `v_mov_b32_dpp` + `v_add_f32_e32` might be fused into `v_add_f32_dpp`.
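
A rough sketch of that fusion at the ISA level (hypothetical registers and DPP
control, not taken from the guide):

```nasm
; Before fusion: a DPP move followed by a plain VALU add.
v_mov_b32_dpp v1, v0 row_shr:1 row_mask:0xf bank_mask:0xf
v_add_f32_e32 v2, v1, v2
; After GCNDPPCombine, when the masks and other constraints allow it:
v_add_f32_dpp v2, v0, v2 row_shr:1 row_mask:0xf bank_mask:0xf
```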

There are plenty of constraints stopping an instruction from being merged. For
example, if either the `bank_mask` or the `row_mask` is not `0xf`, it can't be
fused. You can check the
[GCNDPPCombine::combineDPPMov](https://github.com/llvm/llvm-project/blob/ab51eccf88f5321e7c60591c5546b254b6afab99/llvm/lib/Target/AMDGPU/GCNDPPCombine.cpp#L522)
function to see how it works.

### Comparison

To summarize, there's no free lunch: an instruction's expressivity comes at the
expense of performance.

The relative performance of cross-lane instructions is as follows:

DPP > `ds_swizzle` >= `ds_permute` > `ds_bpermute`

while the generality ranking is the reverse:

DPP < `ds_swizzle` < `ds_permute` < `ds_bpermute`

This table presents the approximate instruction latencies, collected
experimentally on a fused softmax kernel with
[rocprofv2](https://github.com/ROCm/rocprofiler?tab=readme-ov-file#plugin-support)
on the MI300 GPU:

| Instructions | MLIR Op | Hardware | Latency (cycles) |
| ---------------------- | ---------------------------- | ------------ | --------------- |
| ds_permute/ds_bpermute | rocdl.ds_bpermute | LDS hardware | ~50* |
| ds_swizzle | rocdl.ds_swizzle | LDS hardware | ~50* |
| DPP | rocdl.update.dpp, amdgpu.dpp | VALU | 4~12 |

*: For `ds_permute`/`ds_bpermute` and `ds_swizzle`, the latency includes the
instruction itself and its corresponding `s_waitcnt` instruction.
