Upgrade to latest vllm
wongjingping committed Nov 1, 2024
1 parent 2b25e9c commit 941ddaf
Showing 4 changed files with 25 additions and 9 deletions.
9 changes: 5 additions & 4 deletions .github/workflows/main.yml
@@ -6,20 +6,21 @@ jobs:
   lint:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
       - uses: psf/black@stable
   test:
     runs-on: ubuntu-latest
     needs: lint
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
      - name: Set up Python
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v5
         with:
-          python-version: '3.11'
+          python-version: '3.10'
           cache: 'pip'
       - name: Install pip dependencies
         run: |
           pip install --upgrade pip setuptools
           pip install -r requirements.txt
           pip install pytest
       - name: Download spaCy model
3 changes: 2 additions & 1 deletion requirements.txt
@@ -3,6 +3,7 @@ argparse
 func_timeout
 mistralai
 mysql-connector-python
+numpy
 openai>=1.1.0
 pandas
 pandas-gbq
@@ -15,7 +16,7 @@ sentence-transformers
 snowflake-connector-python
 spacy
 sqlalchemy
-tiktoken==0.7.0
+tiktoken
 together
 torch
 tqdm
1 change: 0 additions & 1 deletion run_model_cot.sh
@@ -49,7 +49,6 @@ for model_name in "${model_names[@]}"; do
     --api_url "http://localhost:${PORT}/generate" \
     --api_type "vllm" \
     -p 10 \
-    --cot_table_alias "prealias" \
     --logprobs
   # finally, kill the api server
   pkill -9 -f "python3 utils/api_server.py.*--port ${PORT}"
21 changes: 18 additions & 3 deletions utils/api_server.py
@@ -55,17 +55,32 @@ async def generate(request: Request) -> Response:
     sql_lora_path = request_dict.pop("sql_lora_path", None)
     request_dict.pop("sql_lora_name", None)
     lora_request = (
-        LoRARequest("sql_adapter", 1, sql_lora_path) if sql_lora_path else None
+        LoRARequest(lora_name="sql_adapter", lora_int_id=1, lora_path=sql_lora_path)
+        if sql_lora_path
+        else None
     )
+    if vllm_version >= "0.6.2":
+        # remove use_beam_search if present as it's no longer supported
+        # see https://github.com/vllm-project/vllm/releases/tag/v0.6.2
+        if "use_beam_search" in request_dict:
+            request_dict.pop("use_beam_search")
     sampling_params = SamplingParams(**request_dict)
     request_id = random_uuid()
     tokenizer = await engine.get_tokenizer()
     prompt_token_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
     # print(f"prompt_token_ids: {prompt_token_ids}")
     if prompt_token_ids[0] != tokenizer.bos_token_id:
         prompt_token_ids = [tokenizer.bos_token_id] + prompt_token_ids

-    if vllm_version >= "0.4.2":
+    if vllm_version >= "0.6.3":
+        from vllm import TokensPrompt
+
+        results_generator = engine.generate(
+            prompt=TokensPrompt(prompt_token_ids=prompt_token_ids),
+            sampling_params=sampling_params,
+            request_id=request_id,
+            lora_request=lora_request,
+        )
+    elif vllm_version >= "0.4.2":
         results_generator = engine.generate(
             inputs={"prompt_token_ids": prompt_token_ids},
             sampling_params=sampling_params,
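For reference, a minimal sketch of how a client might call the /generate endpoint that utils/api_server.py exposes (run_model_cot.sh points --api_url at it). The port, prompt, and sampling fields below are illustrative assumptions; keys other than the prompt and LoRA fields are forwarded to vllm's SamplingParams, so the exact set accepted depends on the installed vllm version.

# Hypothetical client for the /generate endpoint in utils/api_server.py.
# The port and field values are assumptions for illustration only.
import requests

payload = {
    "prompt": "-- Return the number of users\nSELECT",  # text for the model to complete
    "temperature": 0.0,  # forwarded to vllm's SamplingParams
    "max_tokens": 600,   # forwarded to vllm's SamplingParams
    # "sql_lora_path": "/path/to/adapter",  # optional: server builds a LoRARequest from this
}
resp = requests.post("http://localhost:8000/generate", json=payload)
print(resp.json())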
