Support continue final message #2733

Merged (10 commits) on Nov 28, 2024
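
This PR lets a chat request end with an assistant message: when the final message has role "assistant", the rendered prompt is cut off right after that message's content, so the model continues the partial reply instead of starting a new turn (mirroring the continue_final_message behaviour of the transformers text-generation pipeline linked in the Rust change below). As a minimal client-side sketch, modelled on the integration test added in this PR; the server URL and port are illustrative assumptions, everything else follows the test:

import requests

# Sketch only: assumes a text-generation-inference server is already
# serving TinyLlama/TinyLlama-1.1B-Chat-v1.0 on localhost:8080.
response = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        "model": "tgi",
        "messages": [
            {"role": "system", "content": "system message"},
            {"role": "user", "content": "Which is bigger an elephant or a mouse?"},
            # final assistant message: the model continues this text
            {"role": "assistant", "content": "the elephant, but have you heard about"},
        ],
        "max_tokens": 30,
        "stream": False,
        "seed": 1337,
    },
)
print(response.json()["choices"][0]["message"]["content"])

The two JSON snapshots and the integration test below pin down the expected completions for exactly these requests.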
@@ -0,0 +1,23 @@
{
  "choices": [
    {
      "finish_reason": "length",
      "index": 0,
      "logprobs": null,
      "message": {
        "content": "Both an elephant and a mouse are mammals. However, the differences between elephants and mice are:\n\n1",
        "role": "assistant"
      }
    }
  ],
  "created": 1732541189,
  "id": "",
  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
  "object": "chat.completion",
  "system_fingerprint": "2.4.1-dev0-native",
  "usage": {
    "completion_tokens": 30,
    "prompt_tokens": 49,
    "total_tokens": 79
  }
}
@@ -0,0 +1,23 @@
{
  "choices": [
    {
      "finish_reason": "length",
      "index": 0,
      "logprobs": null,
      "message": {
        "content": " the royal mouse? It is a little more slender and only weighs around 1.5 pounds for males and 1.3 pounds",
        "role": "assistant"
      }
    }
  ],
  "created": 1732541190,
  "id": "",
  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
  "object": "chat.completion",
  "system_fingerprint": "2.4.1-dev0-native",
  "usage": {
    "completion_tokens": 30,
    "prompt_tokens": 73,
    "total_tokens": 103
  }
}
76 changes: 76 additions & 0 deletions integration-tests/models/test_continue_final_message.py
@@ -0,0 +1,76 @@
import pytest
import requests


@pytest.fixture(scope="module")
def llama_continue_final_message_handle(launcher):
    with launcher("TinyLlama/TinyLlama-1.1B-Chat-v1.0") as handle:
        yield handle


@pytest.fixture(scope="module")
async def llama_continue_final_message(llama_continue_final_message_handle):
    await llama_continue_final_message_handle.health(300)
    return llama_continue_final_message_handle.client


def test_llama_completion_single_prompt(
    llama_continue_final_message, response_snapshot
):
    response = requests.post(
        f"{llama_continue_final_message.base_url}/v1/chat/completions",
        json={
            "model": "tgi",
            "messages": [
                {"role": "system", "content": "system message"},
                {"role": "user", "content": "Which is bigger an elephant or a mouse?"},
            ],
            "max_tokens": 30,
            "stream": False,
            "seed": 1337,
        },
        headers=llama_continue_final_message.headers,
        stream=False,
    )
    response = response.json()
    print(response)
    assert len(response["choices"]) == 1
    content = response["choices"][0]["message"]["content"]
    assert (
        content
        == "Both an elephant and a mouse are mammals. However, the differences between elephants and mice are:\n\n1"
    )
    assert response == response_snapshot


def test_llama_completion_single_prompt_continue(
    llama_continue_final_message, response_snapshot
):
    response = requests.post(
        f"{llama_continue_final_message.base_url}/v1/chat/completions",
        json={
            "model": "tgi",
            "messages": [
                {"role": "system", "content": "system message"},
                {"role": "user", "content": "Which is bigger an elephant or a mouse?"},
                {
                    "role": "assistant",
                    "content": "the elephant, but have you heard about",
                },
            ],
            "max_tokens": 30,
            "stream": False,
            "seed": 1337,
        },
        headers=llama_continue_final_message.headers,
        stream=False,
    )
    response = response.json()
    print(response)
    assert len(response["choices"]) == 1
    content = response["choices"][0]["message"]["content"]
    assert (
        content
        == " the royal mouse? It is a little more slender and only weighs around 1.5 pounds for males and 1.3 pounds"
    )
    assert response == response_snapshot
24 changes: 21 additions & 3 deletions router/src/infer/chat_template.rs
@@ -75,16 +75,34 @@ impl ChatTemplate {
         };

         let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();

-        self.template
+        let final_message = messages.last().cloned();
+        let mut rendered_template = self
+            .template
             .render(ChatTemplateInputs {
                 messages,
                 bos_token: self.bos_token.as_deref(),
                 eos_token: self.eos_token.as_deref(),
                 add_generation_prompt: true,
                 tools,
             })
-            .map_err(InferError::TemplateError)
+            .map_err(InferError::TemplateError)?;
+
+        // if the last message is from the assistant, continue the generation prompt
+        rendered_template = match final_message {
+            Some(msg) if msg.role == "assistant" => {
+                match rendered_template.rfind(msg.content.as_str()) {
+                    // implementation based on feature in transformers pipeline
+                    // https://github.com/huggingface/transformers/blob/1cf17077bf2d4affed31387c0943251a4ba8fab7/src/transformers/pipelines/text_generation.py#L418
+                    Some(index) => rendered_template[..index + msg.content.len()]
+                        .trim_end()
+                        .to_string(),
+                    None => rendered_template,
+                }
+            }
+            _ => rendered_template,
+        };
+
+        Ok(rendered_template)
     }
 }

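To make the trimming step concrete, here is a rough Python mirror of the match block above, applied to the second integration test. The rendered string is an assumption (a Zephyr-style template such as the one TinyLlama-1.1B-Chat ships with); it is not taken from this diff.

# Illustrative only: what the chat-template rendering might produce with
# add_generation_prompt=True for the "continue" test case. The exact
# template text is assumed, not taken from this PR.
rendered = (
    "<|system|>\nsystem message</s>\n"
    "<|user|>\nWhich is bigger an elephant or a mouse?</s>\n"
    "<|assistant|>\nthe elephant, but have you heard about</s>\n"
    "<|assistant|>\n"
)
final_content = "the elephant, but have you heard about"

# Same idea as the Rust code: keep everything up to and including the last
# occurrence of the final assistant content, drop trailing whitespace, and
# let generation resume from there.
index = rendered.rfind(final_content)
if index != -1:
    rendered = rendered[: index + len(final_content)].rstrip()

# rendered now ends with "<|assistant|>\nthe elephant, but have you heard about",
# so the model's next tokens extend that sentence rather than open a new turn.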