From c782a786231d8144c1a2f594d2d46ad1ee63780e Mon Sep 17 00:00:00 2001 From: drbh Date: Thu, 7 Nov 2024 18:24:58 -0400 Subject: [PATCH 01/10] feat: support continue_final_message param in chat request --- router/src/infer/chat_template.rs | 96 +++++++++++++++++++++++++++++-- router/src/infer/mod.rs | 9 ++- 2 files changed, 99 insertions(+), 6 deletions(-) diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs index ceba14a624f..e680d600733 100644 --- a/router/src/infer/chat_template.rs +++ b/router/src/infer/chat_template.rs @@ -2,6 +2,7 @@ use crate::infer::InferError; use crate::{ChatTemplateInputs, Message, MessageChunk, TextMessage, TokenizerConfigToken, Tool}; use minijinja::{Environment, ErrorKind, Template}; use minijinja_contrib::pycompat; +use std::collections::HashSet; /// Raise a exception (custom function) used in the chat templates pub(crate) fn raise_exception(err_text: String) -> Result { @@ -14,6 +15,7 @@ pub(crate) struct ChatTemplate { bos_token: Option, eos_token: Option, use_default_tool_template: bool, + variables: HashSet, } impl ChatTemplate { @@ -45,14 +47,22 @@ impl ChatTemplate { bos_token: bos_token.map(|token| token.as_str().to_string()), eos_token: eos_token.map(|token| token.as_str().to_string()), use_default_tool_template, + variables, } } pub(crate) fn apply( &self, + guideline: Option<&str>, + continue_final_message: bool, mut messages: Vec, tools_and_prompt: Option<(Vec, String)>, ) -> Result { + // check if guideline is expected but not provided + if self.variables.contains("guideline") && guideline.is_none() { + return Err(InferError::MissingTemplateVariable("guideline".to_string())); + } + let tools = match tools_and_prompt { Some((tools, tool_prompt)) => { // check if the `tools` variable is used in the template @@ -75,16 +85,35 @@ impl ChatTemplate { }; let messages: Vec = messages.into_iter().map(|c| c.into()).collect(); - - self.template + let final_message_content = messages.last().map(|m| m.content.clone()); + let mut rendered_template = self + .template .render(ChatTemplateInputs { + guideline, messages, bos_token: self.bos_token.as_deref(), eos_token: self.eos_token.as_deref(), add_generation_prompt: true, tools, }) - .map_err(InferError::TemplateError) + .map_err(InferError::TemplateError)?; + + if continue_final_message { + // find the last occurrence of the final message in the rendered chat + if let Some(final_message) = final_message_content { + rendered_template = if let Some(index) = rendered_template.rfind(&final_message) { + // implementation based on feature in transformers pipeline + // https://github.com/huggingface/transformers/blob/1cf17077bf2d4affed31387c0943251a4ba8fab7/src/transformers/pipelines/text_generation.py#L418 + rendered_template[..index + final_message.len()] + .trim_end() + .to_string() + } else { + rendered_template + }; + } + } + + Ok(rendered_template) } } @@ -746,6 +775,19 @@ mod tests { }, target: "You are a friendly chatbot who always responds in the style of a pirateYou are a friendly chatbot who always responds in the style of a pirate### Instruction: Hello, how are you?### Response: I'm doing great. 
How can I help you today?### Instruction: I'd like to show off how chat templating works!", }, + ChatTemplateTestItem { + name: "google/shieldgemma-9b", + chat_template: "{{- bos_token }}\n{%- if messages[-1].role == 'user' %}\n {{- \"You are a policy expert trying to help determine whether a user prompt is in violation of the defined safety policies.\\n\\n\" }}\n {{- \"\\n\" }}\n {{- \"Human Question: \" + messages[-1].content }}\n {{- \"\\n\\n\\n\" }}\n {{- \"Our safety principle is defined in the below:\\n\\n\" }}\n {{- \"* \" + guideline + \"\\n\" }}\n {{- \"\\n===\\n\\n\" }}\n {{- \"Does the human question violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\\n\\n\" }}\n{%- else %}\n {{- \"You are a policy expert trying to help determine whether an AI response to a prompt is in violation of the defined safety policies.\\n\\n\" }}\n {{- \"\\n\" }}\n {{- \"Human Question: \" + messages[-2].content }}\n {{- \"\\n\\n\" }}\n {{- \"\\n\" }}\n {{- \"Chatbot Response: \" + messages[-1].content }}\n {{- \"\\n\\n\\n\" }}\n {{- \"Our safety principle is defined in the below:\\n\\n\" }}\n {{- \"* \" + guideline + \"\\n\" }}\n {{- \"\\n===\\n\\n\" }}\n {{- \"Does the Chatbot Response violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\\n\\n\" }}\n{%- endif %}\n\n", + input: ChatTemplateInputs { + messages: example_chat_with_system.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some(""), + guideline: Some("Do not use offensive language."), + ..Default::default() + }, + target: "You are a policy expert trying to help determine whether a user prompt is in violation of the defined safety policies.\n\n\nHuman Question: I'd like to show off how chat templating works!\n\n\nOur safety principle is defined in the below:\n\n* Do not use offensive language.\n\n===\n\nDoes the human question violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\n\n", + }, ]; #[allow(unused_variables)] // name is unused @@ -771,6 +813,48 @@ mod tests { } } + #[test] + fn test_chat_template_invalid_with_guideline() { + let ct = ChatTemplate::new( + "{{- bos_token }}\n{%- if messages[-1].role == 'user' %}\n {{- \"You are a policy expert trying to help determine whether a user prompt is in violation of the defined safety policies.\\n\\n\" }}\n {{- \"\\n\" }}\n {{- \"Human Question: \" + messages[-1].content }}\n {{- \"\\n\\n\\n\" }}\n {{- \"Our safety principle is defined in the below:\\n\\n\" }}\n {{- \"* \" + guideline + \"\\n\" }}\n {{- \"\\n===\\n\\n\" }}\n {{- \"Does the human question violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\\n\\n\" }}\n{%- else %}\n {{- \"You are a policy expert trying to help determine whether an AI response to a prompt is in violation of the defined safety policies.\\n\\n\" }}\n {{- \"\\n\" }}\n {{- \"Human Question: \" + messages[-2].content }}\n {{- \"\\n\\n\" }}\n {{- \"\\n\" }}\n {{- \"Chatbot Response: \" + messages[-1].content }}\n {{- \"\\n\\n\\n\" }}\n {{- \"Our safety principle is defined in the below:\\n\\n\" }}\n {{- \"* \" + guideline + \"\\n\" }}\n {{- \"\\n===\\n\\n\" }}\n {{- \"Does the Chatbot Response violate the above principle? Your answer must start with 'Yes' or 'No'. 
And then walk through step by step to be sure we answer correctly.\\n\\n\" }}\n{%- endif %}\n\n".to_string(), + Some(TokenizerConfigToken::String("".to_string())), + Some(TokenizerConfigToken::String("".to_string())), + ); + + // convert TextMessage to Message + let msgs: Vec = vec![ + Message { + name: None, + role: "user".to_string(), + content: MessageContent::SingleText( + "I'd like to show off how chat templating works!".to_string(), + ), + }, + Message { + name: None, + role: "assistant".to_string(), + content: MessageContent::SingleText( + "I'm doing great. How can I help you today?".to_string(), + ), + }, + Message { + name: None, + role: "user".to_string(), + content: MessageContent::SingleText("Hello, how are you?".to_string()), + }, + ]; + let continue_final_message = false; + + let result = ct.apply(None, continue_final_message, msgs, None); + + match result { + Ok(_) => panic!("Should have failed since no guideline is provided"), + Err(e) => { + assert_eq!(e.to_string(), "Missing template vatiable: guideline") + } + } + } + #[test] fn test_chat_template_with_default_tool_template() { let ct = ChatTemplate::new( @@ -801,9 +885,10 @@ mod tests { ]; let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string(); let tools: Vec = serde_json::from_str(&tools_string).unwrap(); + let continue_final_message = false; let tool_prompt = "This default prompt will be used".to_string(); let tools_and_prompt = Some((tools, tool_prompt)); - let result = ct.apply(msgs, tools_and_prompt); + let result = ct.apply(None, continue_final_message, msgs, tools_and_prompt); let expected = "[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today? [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":{\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}}}]\nThis default prompt will be used [/INST]".to_string(); assert_eq!(result.unwrap(), expected); } @@ -835,9 +920,10 @@ mod tests { ]; let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. 
Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string(); let tools: Vec = serde_json::from_str(&tools_string).unwrap(); + let continue_final_message = false; let tool_prompt = "This default prompt will be used".to_string(); let tools_and_prompt = Some((tools, tool_prompt)); - let result = ct.apply(msgs, tools_and_prompt); + let result = ct.apply(None, continue_final_message, msgs, tools_and_prompt); let expected = "<|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYoure a helpful assistant! Answer the users question best you can.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n\n{\n \"function\": {\n \"arguments\": {\n \"properties\": {\n \"format\": {\n \"description\": \"The temperature unit to use. Infer this from the users location.\",\n \"enum\": [\n \"celsius\",\n \"fahrenheit\"\n ],\n \"type\": \"string\"\n },\n \"location\": {\n \"description\": \"The city and state, e.g. San Francisco, CA\",\n \"type\": \"string\"\n }\n },\n \"required\": [\n \"location\",\n \"format\"\n ],\n \"type\": \"object\"\n },\n \"description\": \"Get the current weather\",\n \"name\": \"get_current_weather\"\n },\n \"type\": \"function\"\n}\n\nWhat is the weather like in Brooklyn, New York?\n---\nThis default prompt will be used<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".to_string(); assert_eq!(result.unwrap(), expected); } diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index 1351b87e291..41c2ffe8800 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -159,13 +159,20 @@ impl Infer { #[instrument(skip_all)] pub(crate) fn apply_chat_template( &self, + guideline: Option, + continue_final_message: bool, messages: Vec, tools_and_prompt: Option<(Vec, String)>, ) -> Result { self.chat_template .as_ref() .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))? 
- .apply(messages, tools_and_prompt) + .apply( + guideline.as_deref(), + continue_final_message, + messages, + tools_and_prompt, + ) .map_err(|e| { metrics::counter!("tgi_request_failure", "err" => "template").increment(1); tracing::error!("{e}"); From b2ae92e4707cd79fbca16e1c52fcb3eea6f9ff88 Mon Sep 17 00:00:00 2001 From: David Holtz Date: Fri, 8 Nov 2024 19:00:05 +0000 Subject: [PATCH 02/10] feat: add test for continue final message --- .../test_llama_completion_single_prompt.json | 23 ++++++ ...ama_completion_single_prompt_continue.json | 23 ++++++ .../models/test_continue_final_message.py | 81 +++++++++++++++++++ 3 files changed, 127 insertions(+) create mode 100644 integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt.json create mode 100644 integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt_continue.json create mode 100644 integration-tests/models/test_continue_final_message.py diff --git a/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt.json b/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt.json new file mode 100644 index 00000000000..a452399e064 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt.json @@ -0,0 +1,23 @@ +{ + "choices": [ + { + "finish_reason": "length", + "index": 0, + "logprobs": null, + "message": { + "content": "Hi, I hope this is the right place for your written question. Please provide the maximum possible length to help me complete the message for you! Based", + "role": "assistant" + } + } + ], + "created": 1731082056, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "chat.completion", + "system_fingerprint": "2.4.1-dev0-native", + "usage": { + "completion_tokens": 30, + "prompt_tokens": 57, + "total_tokens": 87 + } +} diff --git a/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt_continue.json b/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt_continue.json new file mode 100644 index 00000000000..3e48bc37b86 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt_continue.json @@ -0,0 +1,23 @@ +{ + "choices": [ + { + "finish_reason": "length", + "index": 0, + "logprobs": null, + "message": { + "content": ": Thanks for the awesome slides, they were just what we needed to produce the presentation we needed to deliver for our company's budgeting system", + "role": "assistant" + } + } + ], + "created": 1731082129, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "chat.completion", + "system_fingerprint": "2.4.1-dev0-native", + "usage": { + "completion_tokens": 30, + "prompt_tokens": 44, + "total_tokens": 74 + } +} diff --git a/integration-tests/models/test_continue_final_message.py b/integration-tests/models/test_continue_final_message.py new file mode 100644 index 00000000000..9a8b07ad7a4 --- /dev/null +++ b/integration-tests/models/test_continue_final_message.py @@ -0,0 +1,81 @@ +import pytest +import requests + + +@pytest.fixture(scope="module") +def llama_continue_final_message_handle(launcher): + with launcher( + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + num_shard=1, + disable_grammar_support=False, + use_flash_attention=False, + ) as handle: + yield handle + + 
+@pytest.fixture(scope="module") +async def llama_continue_final_message(llama_continue_final_message_handle): + await llama_continue_final_message_handle.health(300) + return llama_continue_final_message_handle.client + + +def test_llama_completion_single_prompt( + llama_continue_final_message, response_snapshot +): + response = requests.post( + f"{llama_continue_final_message.base_url}/v1/chat/completions", + json={ + "model": "tgi", + "messages": [ + {"role": "system", "content": "system message"}, + {"role": "user", "content": "user message"}, + {"role": "assistant", "content": "assistant message"}, + ], + "max_tokens": 30, + "stream": False, + "seed": 1337, + "continue_final_message": False, + }, + headers=llama_continue_final_message.headers, + stream=False, + ) + response = response.json() + print(response) + assert len(response["choices"]) == 1 + content = response["choices"][0]["message"]["content"] + assert ( + content + == "Hi, I hope this is the right place for your written question. Please provide the maximum possible length to help me complete the message for you! Based" + ) + assert response == response_snapshot + + +def test_llama_completion_single_prompt_continue( + llama_continue_final_message, response_snapshot +): + response = requests.post( + f"{llama_continue_final_message.base_url}/v1/chat/completions", + json={ + "model": "tgi", + "messages": [ + {"role": "system", "content": "system message"}, + {"role": "user", "content": "user message"}, + {"role": "assistant", "content": "assistant message"}, + ], + "max_tokens": 30, + "stream": False, + "seed": 1337, + "continue_final_message": True, + }, + headers=llama_continue_final_message.headers, + stream=False, + ) + response = response.json() + print(response) + assert len(response["choices"]) == 1 + content = response["choices"][0]["message"]["content"] + assert ( + content + == ": Thanks for the awesome slides, they were just what we needed to produce the presentation we needed to deliver for our company's budgeting system" + ) + assert response == response_snapshot From d6280141deabe1d9595e4cb65f299ce19a25b938 Mon Sep 17 00:00:00 2001 From: David Holtz Date: Fri, 8 Nov 2024 20:46:58 +0000 Subject: [PATCH 03/10] fix: bump openapi docs --- docs/openapi.json | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/openapi.json b/docs/openapi.json index 44691e4bba7..02350a569fe 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -987,6 +987,12 @@ "messages" ], "properties": { + "continue_final_message": { + "type": "boolean", + "description": "Whether to continue the final message in the next request.", + "default": "false", + "example": true + }, "frequency_penalty": { "type": "number", "format": "float", From 70066e6d8c5ea4739f79c731f7c46e4b04b5464d Mon Sep 17 00:00:00 2001 From: David Holtz Date: Tue, 19 Nov 2024 21:24:18 +0000 Subject: [PATCH 04/10] fix: remove continue_final_message chat request param --- docs/openapi.json | 6 ---- .../test_llama_completion_single_prompt.json | 8 ++--- ...ama_completion_single_prompt_continue.json | 8 ++--- .../models/test_continue_final_message.py | 16 +++++----- router/src/infer/chat_template.rs | 32 ++++++++----------- router/src/infer/mod.rs | 8 +---- 6 files changed, 31 insertions(+), 47 deletions(-) diff --git a/docs/openapi.json b/docs/openapi.json index 02350a569fe..44691e4bba7 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -987,12 +987,6 @@ "messages" ], "properties": { - "continue_final_message": { - "type": "boolean", - "description": "Whether to 
continue the final message in the next request.", - "default": "false", - "example": true - }, "frequency_penalty": { "type": "number", "format": "float", diff --git a/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt.json b/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt.json index a452399e064..caa00f9994b 100644 --- a/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt.json +++ b/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt.json @@ -5,19 +5,19 @@ "index": 0, "logprobs": null, "message": { - "content": "Hi, I hope this is the right place for your written question. Please provide the maximum possible length to help me complete the message for you! Based", + "content": "\nGenerate according to: It is an elephant's one year old baby or a mouse's one year old baby. It is", "role": "assistant" } } ], - "created": 1731082056, + "created": 1732050325, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "chat.completion", "system_fingerprint": "2.4.1-dev0-native", "usage": { "completion_tokens": 30, - "prompt_tokens": 57, - "total_tokens": 87 + "prompt_tokens": 37, + "total_tokens": 67 } } diff --git a/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt_continue.json b/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt_continue.json index 3e48bc37b86..f880dd74ce2 100644 --- a/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt_continue.json +++ b/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt_continue.json @@ -5,19 +5,19 @@ "index": 0, "logprobs": null, "message": { - "content": ": Thanks for the awesome slides, they were just what we needed to produce the presentation we needed to deliver for our company's budgeting system", + "content": " Shere Kohan's fantastic exploits? written by David Shimomura & illustrated by Sarah Stern\n\nTitle: Elephant", "role": "assistant" } } ], - "created": 1731082129, + "created": 1732050326, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "chat.completion", "system_fingerprint": "2.4.1-dev0-native", "usage": { "completion_tokens": 30, - "prompt_tokens": 44, - "total_tokens": 74 + "prompt_tokens": 61, + "total_tokens": 91 } } diff --git a/integration-tests/models/test_continue_final_message.py b/integration-tests/models/test_continue_final_message.py index 9a8b07ad7a4..ea3d83c9cbb 100644 --- a/integration-tests/models/test_continue_final_message.py +++ b/integration-tests/models/test_continue_final_message.py @@ -28,13 +28,11 @@ def test_llama_completion_single_prompt( "model": "tgi", "messages": [ {"role": "system", "content": "system message"}, - {"role": "user", "content": "user message"}, - {"role": "assistant", "content": "assistant message"}, + {"role": "user", "content": "Which is bigger an elephant or a mouse?"}, ], "max_tokens": 30, "stream": False, "seed": 1337, - "continue_final_message": False, }, headers=llama_continue_final_message.headers, stream=False, @@ -45,7 +43,7 @@ def test_llama_completion_single_prompt( content = response["choices"][0]["message"]["content"] assert ( content - == "Hi, I hope this is the right place for your written question. 
Please provide the maximum possible length to help me complete the message for you! Based" + == "\nGenerate according to: It is an elephant's one year old baby or a mouse's one year old baby. It is" ) assert response == response_snapshot @@ -59,13 +57,15 @@ def test_llama_completion_single_prompt_continue( "model": "tgi", "messages": [ {"role": "system", "content": "system message"}, - {"role": "user", "content": "user message"}, - {"role": "assistant", "content": "assistant message"}, + {"role": "user", "content": "Which is bigger an elephant or a mouse?"}, + { + "role": "assistant", + "content": "the elephant, but have you heard about", + }, ], "max_tokens": 30, "stream": False, "seed": 1337, - "continue_final_message": True, }, headers=llama_continue_final_message.headers, stream=False, @@ -76,6 +76,6 @@ def test_llama_completion_single_prompt_continue( content = response["choices"][0]["message"]["content"] assert ( content - == ": Thanks for the awesome slides, they were just what we needed to produce the presentation we needed to deliver for our company's budgeting system" + == " Shere Kohan's fantastic exploits? written by David Shimomura & illustrated by Sarah Stern\n\nTitle: Elephant" ) assert response == response_snapshot diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs index e680d600733..74f38ddaaf7 100644 --- a/router/src/infer/chat_template.rs +++ b/router/src/infer/chat_template.rs @@ -54,7 +54,6 @@ impl ChatTemplate { pub(crate) fn apply( &self, guideline: Option<&str>, - continue_final_message: bool, mut messages: Vec, tools_and_prompt: Option<(Vec, String)>, ) -> Result { @@ -85,7 +84,7 @@ impl ChatTemplate { }; let messages: Vec = messages.into_iter().map(|c| c.into()).collect(); - let final_message_content = messages.last().map(|m| m.content.clone()); + let final_message = messages.last().cloned(); let mut rendered_template = self .template .render(ChatTemplateInputs { @@ -98,20 +97,20 @@ impl ChatTemplate { }) .map_err(InferError::TemplateError)?; - if continue_final_message { - // find the last occurrence of the final message in the rendered chat - if let Some(final_message) = final_message_content { - rendered_template = if let Some(index) = rendered_template.rfind(&final_message) { + // if the last message is from the assistant, continue the generation prompt + rendered_template = match final_message { + Some(msg) if msg.role == "assistant" => { + match rendered_template.rfind(msg.content.as_str()) { // implementation based on feature in transformers pipeline // https://github.com/huggingface/transformers/blob/1cf17077bf2d4affed31387c0943251a4ba8fab7/src/transformers/pipelines/text_generation.py#L418 - rendered_template[..index + final_message.len()] + Some(index) => rendered_template[..index + msg.content.len()] .trim_end() - .to_string() - } else { - rendered_template - }; + .to_string(), + None => rendered_template, + } } - } + _ => rendered_template, + }; Ok(rendered_template) } @@ -843,9 +842,8 @@ mod tests { content: MessageContent::SingleText("Hello, how are you?".to_string()), }, ]; - let continue_final_message = false; - let result = ct.apply(None, continue_final_message, msgs, None); + let result = ct.apply(None, msgs, None); match result { Ok(_) => panic!("Should have failed since no guideline is provided"), @@ -885,10 +883,9 @@ mod tests { ]; let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": 
{"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string(); let tools: Vec = serde_json::from_str(&tools_string).unwrap(); - let continue_final_message = false; let tool_prompt = "This default prompt will be used".to_string(); let tools_and_prompt = Some((tools, tool_prompt)); - let result = ct.apply(None, continue_final_message, msgs, tools_and_prompt); + let result = ct.apply(None, msgs, tools_and_prompt); let expected = "[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today? [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":{\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}}}]\nThis default prompt will be used [/INST]".to_string(); assert_eq!(result.unwrap(), expected); } @@ -920,10 +917,9 @@ mod tests { ]; let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string(); let tools: Vec = serde_json::from_str(&tools_string).unwrap(); - let continue_final_message = false; let tool_prompt = "This default prompt will be used".to_string(); let tools_and_prompt = Some((tools, tool_prompt)); - let result = ct.apply(None, continue_final_message, msgs, tools_and_prompt); + let result = ct.apply(None, msgs, tools_and_prompt); let expected = "<|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYoure a helpful assistant! Answer the users question best you can.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n\n{\n \"function\": {\n \"arguments\": {\n \"properties\": {\n \"format\": {\n \"description\": \"The temperature unit to use. Infer this from the users location.\",\n \"enum\": [\n \"celsius\",\n \"fahrenheit\"\n ],\n \"type\": \"string\"\n },\n \"location\": {\n \"description\": \"The city and state, e.g. 
San Francisco, CA\",\n \"type\": \"string\"\n }\n },\n \"required\": [\n \"location\",\n \"format\"\n ],\n \"type\": \"object\"\n },\n \"description\": \"Get the current weather\",\n \"name\": \"get_current_weather\"\n },\n \"type\": \"function\"\n}\n\nWhat is the weather like in Brooklyn, New York?\n---\nThis default prompt will be used<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".to_string(); assert_eq!(result.unwrap(), expected); } diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index 41c2ffe8800..d3d6bc597ba 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -160,19 +160,13 @@ impl Infer { pub(crate) fn apply_chat_template( &self, guideline: Option, - continue_final_message: bool, messages: Vec, tools_and_prompt: Option<(Vec, String)>, ) -> Result { self.chat_template .as_ref() .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))? - .apply( - guideline.as_deref(), - continue_final_message, - messages, - tools_and_prompt, - ) + .apply(guideline.as_deref(), messages, tools_and_prompt) .map_err(|e| { metrics::counter!("tgi_request_failure", "err" => "template").increment(1); tracing::error!("{e}"); From 7486d930f8bbba44bece9fac1fbd927d230f84f5 Mon Sep 17 00:00:00 2001 From: drbh Date: Thu, 21 Nov 2024 09:09:39 -0500 Subject: [PATCH 05/10] fix: remove unneeded launcher args in continue test --- integration-tests/models/test_continue_final_message.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/integration-tests/models/test_continue_final_message.py b/integration-tests/models/test_continue_final_message.py index ea3d83c9cbb..88d944337a3 100644 --- a/integration-tests/models/test_continue_final_message.py +++ b/integration-tests/models/test_continue_final_message.py @@ -4,12 +4,7 @@ @pytest.fixture(scope="module") def llama_continue_final_message_handle(launcher): - with launcher( - "TinyLlama/TinyLlama-1.1B-Chat-v1.0", - num_shard=1, - disable_grammar_support=False, - use_flash_attention=False, - ) as handle: + with launcher("TinyLlama/TinyLlama-1.1B-Chat-v1.0") as handle: yield handle From 4069955e44aa75acb370e8c53cbc238109b1ac3e Mon Sep 17 00:00:00 2001 From: drbh Date: Thu, 21 Nov 2024 12:04:24 -0500 Subject: [PATCH 06/10] fix: bump test output --- .../test_llama_completion_single_prompt_continue.json | 2 +- integration-tests/models/test_continue_final_message.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt_continue.json b/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt_continue.json index f880dd74ce2..8f782694be0 100644 --- a/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt_continue.json +++ b/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt_continue.json @@ -5,7 +5,7 @@ "index": 0, "logprobs": null, "message": { - "content": " Shere Kohan's fantastic exploits? written by David Shimomura & illustrated by Sarah Stern\n\nTitle: Elephant", + "content": " the royal mouse? 
It is a little more slender and only weighs around 1.5 pounds for males and 1.3 pounds", "role": "assistant" } } diff --git a/integration-tests/models/test_continue_final_message.py b/integration-tests/models/test_continue_final_message.py index 88d944337a3..2fb99273222 100644 --- a/integration-tests/models/test_continue_final_message.py +++ b/integration-tests/models/test_continue_final_message.py @@ -71,6 +71,6 @@ def test_llama_completion_single_prompt_continue( content = response["choices"][0]["message"]["content"] assert ( content - == " Shere Kohan's fantastic exploits? written by David Shimomura & illustrated by Sarah Stern\n\nTitle: Elephant" + == " the royal mouse? It is a little more slender and only weighs around 1.5 pounds for males and 1.3 pounds" ) assert response == response_snapshot From 8770b39c202070f568c9cd0ca18531cb3016643a Mon Sep 17 00:00:00 2001 From: drbh Date: Fri, 22 Nov 2024 13:50:30 -0500 Subject: [PATCH 07/10] fix: remove accidentally included guideline from rebase --- router/src/infer/chat_template.rs | 17 +++-------------- router/src/infer/mod.rs | 3 +-- 2 files changed, 4 insertions(+), 16 deletions(-) diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs index 74f38ddaaf7..9caf15b77c8 100644 --- a/router/src/infer/chat_template.rs +++ b/router/src/infer/chat_template.rs @@ -2,7 +2,6 @@ use crate::infer::InferError; use crate::{ChatTemplateInputs, Message, MessageChunk, TextMessage, TokenizerConfigToken, Tool}; use minijinja::{Environment, ErrorKind, Template}; use minijinja_contrib::pycompat; -use std::collections::HashSet; /// Raise a exception (custom function) used in the chat templates pub(crate) fn raise_exception(err_text: String) -> Result { @@ -15,7 +14,6 @@ pub(crate) struct ChatTemplate { bos_token: Option, eos_token: Option, use_default_tool_template: bool, - variables: HashSet, } impl ChatTemplate { @@ -47,21 +45,14 @@ impl ChatTemplate { bos_token: bos_token.map(|token| token.as_str().to_string()), eos_token: eos_token.map(|token| token.as_str().to_string()), use_default_tool_template, - variables, } } pub(crate) fn apply( &self, - guideline: Option<&str>, mut messages: Vec, tools_and_prompt: Option<(Vec, String)>, ) -> Result { - // check if guideline is expected but not provided - if self.variables.contains("guideline") && guideline.is_none() { - return Err(InferError::MissingTemplateVariable("guideline".to_string())); - } - let tools = match tools_and_prompt { Some((tools, tool_prompt)) => { // check if the `tools` variable is used in the template @@ -88,7 +79,6 @@ impl ChatTemplate { let mut rendered_template = self .template .render(ChatTemplateInputs { - guideline, messages, bos_token: self.bos_token.as_deref(), eos_token: self.eos_token.as_deref(), @@ -782,7 +772,6 @@ mod tests { add_generation_prompt: false, bos_token: Some(""), eos_token: Some(""), - guideline: Some("Do not use offensive language."), ..Default::default() }, target: "You are a policy expert trying to help determine whether a user prompt is in violation of the defined safety policies.\n\n\nHuman Question: I'd like to show off how chat templating works!\n\n\nOur safety principle is defined in the below:\n\n* Do not use offensive language.\n\n===\n\nDoes the human question violate the above principle? Your answer must start with 'Yes' or 'No'. 
And then walk through step by step to be sure we answer correctly.\n\n", @@ -843,7 +832,7 @@ mod tests { }, ]; - let result = ct.apply(None, msgs, None); + let result = ct.apply(msgs, None); match result { Ok(_) => panic!("Should have failed since no guideline is provided"), @@ -885,7 +874,7 @@ mod tests { let tools: Vec = serde_json::from_str(&tools_string).unwrap(); let tool_prompt = "This default prompt will be used".to_string(); let tools_and_prompt = Some((tools, tool_prompt)); - let result = ct.apply(None, msgs, tools_and_prompt); + let result = ct.apply(msgs, tools_and_prompt); let expected = "[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today? [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":{\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}}}]\nThis default prompt will be used [/INST]".to_string(); assert_eq!(result.unwrap(), expected); } @@ -919,7 +908,7 @@ mod tests { let tools: Vec = serde_json::from_str(&tools_string).unwrap(); let tool_prompt = "This default prompt will be used".to_string(); let tools_and_prompt = Some((tools, tool_prompt)); - let result = ct.apply(None, msgs, tools_and_prompt); + let result = ct.apply(msgs, tools_and_prompt); let expected = "<|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYoure a helpful assistant! Answer the users question best you can.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n\n{\n \"function\": {\n \"arguments\": {\n \"properties\": {\n \"format\": {\n \"description\": \"The temperature unit to use. Infer this from the users location.\",\n \"enum\": [\n \"celsius\",\n \"fahrenheit\"\n ],\n \"type\": \"string\"\n },\n \"location\": {\n \"description\": \"The city and state, e.g. San Francisco, CA\",\n \"type\": \"string\"\n }\n },\n \"required\": [\n \"location\",\n \"format\"\n ],\n \"type\": \"object\"\n },\n \"description\": \"Get the current weather\",\n \"name\": \"get_current_weather\"\n },\n \"type\": \"function\"\n}\n\nWhat is the weather like in Brooklyn, New York?\n---\nThis default prompt will be used<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".to_string(); assert_eq!(result.unwrap(), expected); } diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index d3d6bc597ba..1351b87e291 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -159,14 +159,13 @@ impl Infer { #[instrument(skip_all)] pub(crate) fn apply_chat_template( &self, - guideline: Option, messages: Vec, tools_and_prompt: Option<(Vec, String)>, ) -> Result { self.chat_template .as_ref() .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))? 
- .apply(guideline.as_deref(), messages, tools_and_prompt) + .apply(messages, tools_and_prompt) .map_err(|e| { metrics::counter!("tgi_request_failure", "err" => "template").increment(1); tracing::error!("{e}"); From 13a75acd761e978238ec3fc386e57c29b5144591 Mon Sep 17 00:00:00 2001 From: drbh Date: Fri, 22 Nov 2024 14:12:18 -0500 Subject: [PATCH 08/10] fix: remove guideline tests --- router/src/infer/chat_template.rs | 53 ------------------------------- 1 file changed, 53 deletions(-) diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs index 9caf15b77c8..69a09046ad9 100644 --- a/router/src/infer/chat_template.rs +++ b/router/src/infer/chat_template.rs @@ -764,18 +764,6 @@ mod tests { }, target: "You are a friendly chatbot who always responds in the style of a pirateYou are a friendly chatbot who always responds in the style of a pirate### Instruction: Hello, how are you?### Response: I'm doing great. How can I help you today?### Instruction: I'd like to show off how chat templating works!", }, - ChatTemplateTestItem { - name: "google/shieldgemma-9b", - chat_template: "{{- bos_token }}\n{%- if messages[-1].role == 'user' %}\n {{- \"You are a policy expert trying to help determine whether a user prompt is in violation of the defined safety policies.\\n\\n\" }}\n {{- \"\\n\" }}\n {{- \"Human Question: \" + messages[-1].content }}\n {{- \"\\n\\n\\n\" }}\n {{- \"Our safety principle is defined in the below:\\n\\n\" }}\n {{- \"* \" + guideline + \"\\n\" }}\n {{- \"\\n===\\n\\n\" }}\n {{- \"Does the human question violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\\n\\n\" }}\n{%- else %}\n {{- \"You are a policy expert trying to help determine whether an AI response to a prompt is in violation of the defined safety policies.\\n\\n\" }}\n {{- \"\\n\" }}\n {{- \"Human Question: \" + messages[-2].content }}\n {{- \"\\n\\n\" }}\n {{- \"\\n\" }}\n {{- \"Chatbot Response: \" + messages[-1].content }}\n {{- \"\\n\\n\\n\" }}\n {{- \"Our safety principle is defined in the below:\\n\\n\" }}\n {{- \"* \" + guideline + \"\\n\" }}\n {{- \"\\n===\\n\\n\" }}\n {{- \"Does the Chatbot Response violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\\n\\n\" }}\n{%- endif %}\n\n", - input: ChatTemplateInputs { - messages: example_chat_with_system.clone(), - add_generation_prompt: false, - bos_token: Some(""), - eos_token: Some(""), - ..Default::default() - }, - target: "You are a policy expert trying to help determine whether a user prompt is in violation of the defined safety policies.\n\n\nHuman Question: I'd like to show off how chat templating works!\n\n\nOur safety principle is defined in the below:\n\n* Do not use offensive language.\n\n===\n\nDoes the human question violate the above principle? Your answer must start with 'Yes' or 'No'. 
And then walk through step by step to be sure we answer correctly.\n\n", - }, ]; #[allow(unused_variables)] // name is unused @@ -801,47 +789,6 @@ mod tests { } } - #[test] - fn test_chat_template_invalid_with_guideline() { - let ct = ChatTemplate::new( - "{{- bos_token }}\n{%- if messages[-1].role == 'user' %}\n {{- \"You are a policy expert trying to help determine whether a user prompt is in violation of the defined safety policies.\\n\\n\" }}\n {{- \"\\n\" }}\n {{- \"Human Question: \" + messages[-1].content }}\n {{- \"\\n\\n\\n\" }}\n {{- \"Our safety principle is defined in the below:\\n\\n\" }}\n {{- \"* \" + guideline + \"\\n\" }}\n {{- \"\\n===\\n\\n\" }}\n {{- \"Does the human question violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\\n\\n\" }}\n{%- else %}\n {{- \"You are a policy expert trying to help determine whether an AI response to a prompt is in violation of the defined safety policies.\\n\\n\" }}\n {{- \"\\n\" }}\n {{- \"Human Question: \" + messages[-2].content }}\n {{- \"\\n\\n\" }}\n {{- \"\\n\" }}\n {{- \"Chatbot Response: \" + messages[-1].content }}\n {{- \"\\n\\n\\n\" }}\n {{- \"Our safety principle is defined in the below:\\n\\n\" }}\n {{- \"* \" + guideline + \"\\n\" }}\n {{- \"\\n===\\n\\n\" }}\n {{- \"Does the Chatbot Response violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\\n\\n\" }}\n{%- endif %}\n\n".to_string(), - Some(TokenizerConfigToken::String("".to_string())), - Some(TokenizerConfigToken::String("".to_string())), - ); - - // convert TextMessage to Message - let msgs: Vec = vec![ - Message { - name: None, - role: "user".to_string(), - content: MessageContent::SingleText( - "I'd like to show off how chat templating works!".to_string(), - ), - }, - Message { - name: None, - role: "assistant".to_string(), - content: MessageContent::SingleText( - "I'm doing great. 
How can I help you today?".to_string(), - ), - }, - Message { - name: None, - role: "user".to_string(), - content: MessageContent::SingleText("Hello, how are you?".to_string()), - }, - ]; - - let result = ct.apply(msgs, None); - - match result { - Ok(_) => panic!("Should have failed since no guideline is provided"), - Err(e) => { - assert_eq!(e.to_string(), "Missing template vatiable: guideline") - } - } - } - #[test] fn test_chat_template_with_default_tool_template() { let ct = ChatTemplate::new( From 594a6a7c225e5b0c8a7346d7f2dae2fa8673cd04 Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 25 Nov 2024 08:33:58 -0500 Subject: [PATCH 09/10] fix: adjust continuation tests expected text --- .../test_llama_completion_single_prompt.json | 8 ++++---- .../test_llama_completion_single_prompt_continue.json | 8 ++++---- integration-tests/models/test_continue_final_message.py | 4 ++-- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt.json b/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt.json index caa00f9994b..0bea487c350 100644 --- a/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt.json +++ b/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt.json @@ -5,19 +5,19 @@ "index": 0, "logprobs": null, "message": { - "content": "\nGenerate according to: It is an elephant's one year old baby or a mouse's one year old baby. It is", + "content": "Both an elephant and a mouse are mammals. However, the differences between elephants and mice are:\n\n1", "role": "assistant" } } ], - "created": 1732050325, + "created": 1732541189, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "chat.completion", "system_fingerprint": "2.4.1-dev0-native", "usage": { "completion_tokens": 30, - "prompt_tokens": 37, - "total_tokens": 67 + "prompt_tokens": 49, + "total_tokens": 79 } } diff --git a/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt_continue.json b/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt_continue.json index 8f782694be0..bd5cdd7ce87 100644 --- a/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt_continue.json +++ b/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt_continue.json @@ -5,19 +5,19 @@ "index": 0, "logprobs": null, "message": { - "content": " the royal mouse? It is a little more slender and only weighs around 1.5 pounds for males and 1.3 pounds", + "content": "the mouse is much smaller than an elephant. 
The average elephant weight is around 6,500 lbs (3,", "role": "assistant" } } ], - "created": 1732050326, + "created": 1732541190, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "chat.completion", "system_fingerprint": "2.4.1-dev0-native", "usage": { "completion_tokens": 30, - "prompt_tokens": 61, - "total_tokens": 91 + "prompt_tokens": 73, + "total_tokens": 103 } } diff --git a/integration-tests/models/test_continue_final_message.py b/integration-tests/models/test_continue_final_message.py index 2fb99273222..24ae720fdbf 100644 --- a/integration-tests/models/test_continue_final_message.py +++ b/integration-tests/models/test_continue_final_message.py @@ -38,7 +38,7 @@ def test_llama_completion_single_prompt( content = response["choices"][0]["message"]["content"] assert ( content - == "\nGenerate according to: It is an elephant's one year old baby or a mouse's one year old baby. It is" + == "Both an elephant and a mouse are mammals. However, the differences between elephants and mice are:\n\n1" ) assert response == response_snapshot @@ -71,6 +71,6 @@ def test_llama_completion_single_prompt_continue( content = response["choices"][0]["message"]["content"] assert ( content - == " the royal mouse? It is a little more slender and only weighs around 1.5 pounds for males and 1.3 pounds" + == "the mouse is much smaller than an elephant. The average elephant weight is around 6,500 lbs (3," ) assert response == response_snapshot From 8505341931670185ff950a21137436ff23c78d08 Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 25 Nov 2024 10:10:33 -0500 Subject: [PATCH 10/10] fix: replace expected output for continue test --- .../test_llama_completion_single_prompt_continue.json | 2 +- integration-tests/models/test_continue_final_message.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt_continue.json b/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt_continue.json index bd5cdd7ce87..100fb3385e4 100644 --- a/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt_continue.json +++ b/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt_continue.json @@ -5,7 +5,7 @@ "index": 0, "logprobs": null, "message": { - "content": "the mouse is much smaller than an elephant. The average elephant weight is around 6,500 lbs (3,", + "content": " the royal mouse? It is a little more slender and only weighs around 1.5 pounds for males and 1.3 pounds", "role": "assistant" } } diff --git a/integration-tests/models/test_continue_final_message.py b/integration-tests/models/test_continue_final_message.py index 24ae720fdbf..01c86dcd104 100644 --- a/integration-tests/models/test_continue_final_message.py +++ b/integration-tests/models/test_continue_final_message.py @@ -71,6 +71,6 @@ def test_llama_completion_single_prompt_continue( content = response["choices"][0]["message"]["content"] assert ( content - == "the mouse is much smaller than an elephant. The average elephant weight is around 6,500 lbs (3," + == " the royal mouse? It is a little more slender and only weighs around 1.5 pounds for males and 1.3 pounds" ) assert response == response_snapshot
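
Taken together, the series converges on an implicit continuation rule: the explicit continue_final_message request parameter introduced in patch 01 is dropped again, and ChatTemplate::apply instead trims the rendered prompt whenever the final message has the assistant role, so generation resumes inside that turn. The Python sketch below mirrors that trimming step and the transformers text_generation pipeline logic referenced in the Rust comment; it is illustrative only, is not part of the patches, and the chat-template markers used in the demo are invented for the example.

    def trim_for_continuation(rendered: str, final_role: str, final_content: str) -> str:
        # Only continue when the conversation ends with an assistant turn,
        # matching the `Some(msg) if msg.role == "assistant"` branch in Rust.
        if final_role != "assistant":
            return rendered
        # Find the last occurrence of the final message in the rendered prompt
        # (the rfind-based logic in ChatTemplate::apply).
        index = rendered.rfind(final_content)
        if index == -1:
            # Content not found verbatim (e.g. the template rewrote it): leave as is.
            return rendered
        # Keep everything up to the end of that message and strip trailing
        # whitespace, so no end-of-turn token or new role header follows it.
        return rendered[: index + len(final_content)].rstrip()

    if __name__ == "__main__":
        # Invented Zephyr-style markers, purely for illustration.
        rendered = (
            "<|user|>\nWhich is bigger an elephant or a mouse?</s>\n"
            "<|assistant|>\nthe elephant, but have you heard about</s>\n"
        )
        print(trim_for_continuation(rendered, "assistant", "the elephant, but have you heard about"))
        # The printed prompt ends right after "heard about", so the model keeps
        # writing that sentence instead of opening a fresh assistant turn.

This is also what the updated integration test exercises: the second request ends with an assistant message ("the elephant, but have you heard about") and the returned completion picks up mid-sentence, while the first request ends with a user message and receives a normal new assistant turn.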