From 1fb7ec307bfd29f99b725250cf715b6c63542429 Mon Sep 17 00:00:00 2001
From: Silas Marvin <19626586+SilasMarvin@users.noreply.github.com>
Date: Sat, 2 Mar 2024 17:03:50 -0800
Subject: [PATCH 1/4] Added apply_chat_template to model

---
 llama-cpp-2/src/lib.rs   | 22 +++++++++++
 llama-cpp-2/src/model.rs | 80 +++++++++++++++++++++++++++++++++++++---
 2 files changed, 96 insertions(+), 6 deletions(-)

diff --git a/llama-cpp-2/src/lib.rs b/llama-cpp-2/src/lib.rs
index 49e333e0..e90d9189 100644
--- a/llama-cpp-2/src/lib.rs
+++ b/llama-cpp-2/src/lib.rs
@@ -207,6 +207,28 @@ pub enum StringToTokenError {
     CIntConversionError(#[from] std::num::TryFromIntError),
 }
 
+/// Failed to create a new `LlamaChatMessage`.
+#[derive(Debug, thiserror::Error)]
+pub enum NewLlamaChatMessageError {
+    /// the string contained a null byte and thus could not be converted to a c string.
+    #[error("{0}")]
+    NulError(#[from] NulError),
+}
+
+/// Failed to apply model chat template.
+#[derive(Debug, thiserror::Error)]
+pub enum ApplyChatTemplateError {
+    /// the buffer was too small.
+    #[error("The buffer was too small. Please contact a maintainer")]
+    BuffSizeError,
+    /// the string contained a null byte and thus could not be converted to a c string.
+    #[error("{0}")]
+    NulError(#[from] NulError),
+    /// the string could not be converted to utf8.
+    #[error("{0}")]
+    FromUtf8Error(#[from] FromUtf8Error),
+}
+
 /// Get the time in microseconds according to ggml
 ///
 /// ```
diff --git a/llama-cpp-2/src/model.rs b/llama-cpp-2/src/model.rs
index 9f01ac24..112d228d 100644
--- a/llama-cpp-2/src/model.rs
+++ b/llama-cpp-2/src/model.rs
@@ -11,8 +11,8 @@ use crate::model::params::LlamaModelParams;
 use crate::token::LlamaToken;
 use crate::token_type::LlamaTokenType;
 use crate::{
-    ChatTemplateError, LlamaContextLoadError, LlamaModelLoadError, StringToTokenError,
-    TokenToStringError,
+    ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError, LlamaModelLoadError,
+    NewLlamaChatMessageError, StringToTokenError, TokenToStringError,
 };
 
 pub mod params;
@@ -25,6 +25,23 @@ pub struct LlamaModel {
     pub(crate) model: NonNull<llama_cpp_sys_2::llama_model>,
 }
 
+/// A safe wrapper around `llama_chat_message`
+#[derive(Debug)]
+pub struct LlamaChatMessage {
+    role: CString,
+    content: CString,
+}
+
+impl LlamaChatMessage {
+    /// Create a new `LlamaChatMessage`
+    pub fn new(role: String, content: String) -> Result<Self, NewLlamaChatMessageError> {
+        Ok(Self {
+            role: CString::new(role)?,
+            content: CString::new(content)?,
+        })
+    }
+}
+
 /// How to determine if we should prepend a bos token to tokens
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
 pub enum AddBos {
@@ -312,17 +329,16 @@ impl LlamaModel {
     /// Get chat template from model.
     ///
     /// # Errors
-    /// 
+    ///
     /// * If the model has no chat template
     /// * If the chat template is not a valid [`CString`].
+    #[allow(clippy::missing_panics_doc)] // we statically know this will not panic as
     pub fn get_chat_template(&self, buf_size: usize) -> Result<String, ChatTemplateError> {
-        // longest known template is about 1200 bytes from llama.cpp
         let chat_temp = CString::new(vec![b'*'; buf_size]).expect("no null");
         let chat_ptr = chat_temp.into_raw();
         let chat_name = CString::new("tokenizer.chat_template").expect("no null bytes");
-        
+
         let chat_template: String = unsafe {
             let ret = llama_cpp_sys_2::llama_model_meta_val_str(
                 self.model.as_ptr(),
@@ -337,7 +353,7 @@ impl LlamaModel {
             debug_assert_eq!(usize::try_from(ret).unwrap(), template.len(), "llama.cpp guarantees that the returned int {ret} is the length of the string {} but that was not the case", template.len());
             template
         };
-        
+
         Ok(chat_template)
     }
 
@@ -388,6 +404,58 @@ impl LlamaModel {
 
         Ok(LlamaContext::new(self, context, params.embeddings()))
     }
+
+    /// Apply the model's chat template to some messages.
+    /// See <https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template>
+    ///
+    /// # Errors
+    /// There are many ways this can fail. See [`ApplyChatTemplateError`] for more information.
+    #[tracing::instrument(skip_all)]
+    pub fn apply_chat_template(
+        &self,
+        tmpl: Option<String>,
+        chat: Vec<LlamaChatMessage>,
+        add_ass: bool,
+    ) -> Result<String, ApplyChatTemplateError> {
+        // Buffer is twice the length of the messages, per llama.cpp's recommendation
+        let message_length = chat.iter().fold(0, |acc, c| {
+            acc + c.role.to_bytes().len() + c.content.to_bytes().len()
+        });
+        let mut buff: Vec<i8> = vec![0_i8; message_length * 2];
+        // Build our llama_cpp_sys_2 chat messages
+        let chat: Vec<llama_cpp_sys_2::llama_chat_message> = chat
+            .iter()
+            .map(|c| llama_cpp_sys_2::llama_chat_message {
+                role: c.role.as_ptr(),
+                content: c.content.as_ptr(),
+            })
+            .collect();
+        // Set the tmpl pointer
+        let tmpl = tmpl.map(|v| CString::new(v));
+        eprintln!("TEMPLATE AGAIN: {:?}", tmpl);
+        let tmpl_ptr = match tmpl {
+            Some(str) => str?.as_ptr(),
+            None => std::ptr::null(),
+        };
+        let formatted_chat = unsafe {
+            let res = llama_cpp_sys_2::llama_chat_apply_template(
+                self.model.as_ptr(),
+                tmpl_ptr,
+                chat.as_ptr(),
+                chat.len(),
+                add_ass,
+                buff.as_mut_ptr(),
+                buff.len() as i32,
+            );
+            // This should never happen
+            if res > buff.len() as i32 {
+                return Err(ApplyChatTemplateError::BuffSizeError);
+            }
+            println!("BUFF: {:?}", buff);
+            String::from_utf8(buff.iter().filter(|c| **c > 0).map(|&c| c as u8).collect())
+        };
+        Ok(formatted_chat?)
+    }
 }
 
 impl Drop for LlamaModel {

From cf360fcdf591e48e60a6cfacbb151bdfc5363a3f Mon Sep 17 00:00:00 2001
From: Silas Marvin <19626586+SilasMarvin@users.noreply.github.com>
Date: Sat, 2 Mar 2024 17:05:28 -0800
Subject: [PATCH 2/4] Cleaned up

---
 llama-cpp-2/src/model.rs | 2 --
 1 file changed, 2 deletions(-)

diff --git a/llama-cpp-2/src/model.rs b/llama-cpp-2/src/model.rs
index 112d228d..b1f0ea47 100644
--- a/llama-cpp-2/src/model.rs
+++ b/llama-cpp-2/src/model.rs
@@ -432,7 +432,6 @@ impl LlamaModel {
             .collect();
         // Set the tmpl pointer
         let tmpl = tmpl.map(|v| CString::new(v));
-        eprintln!("TEMPLATE AGAIN: {:?}", tmpl);
         let tmpl_ptr = match tmpl {
             Some(str) => str?.as_ptr(),
             None => std::ptr::null(),
@@ -451,7 +450,6 @@ impl LlamaModel {
             if res > buff.len() as i32 {
                 return Err(ApplyChatTemplateError::BuffSizeError);
             }
-            println!("BUFF: {:?}", buff);
             String::from_utf8(buff.iter().filter(|c| **c > 0).map(|&c| c as u8).collect())
         };
         Ok(formatted_chat?)
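A minimal standalone sketch (not part of the patch series) of the buffer conversion `apply_chat_template` performs at the end: llama.cpp writes the rendered template into the `Vec<i8>` buffer, and only the strictly positive bytes are kept before the UTF-8 conversion. The function name `buff_to_string` is invented here for illustration.

fn buff_to_string(buff: Vec<i8>) -> Result<String, std::string::FromUtf8Error> {
    // `c > 0` drops the unused NUL padding at the tail of the buffer. It also
    // drops any byte >= 0x80 (negative when stored as i8), so this step assumes
    // the rendered template is plain ASCII.
    String::from_utf8(buff.into_iter().filter(|&c| c > 0).map(|c| c as u8).collect())
}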
From dcbcdd6478e09ca10e808be41306be8e2fceab5d Mon Sep 17 00:00:00 2001
From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com>
Date: Thu, 4 Apr 2024 18:42:51 -0700
Subject: [PATCH 3/4] Requested changes

---
 llama-cpp-2/src/lib.rs   |  2 +-
 llama-cpp-2/src/model.rs | 13 ++++++++-----
 2 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/llama-cpp-2/src/lib.rs b/llama-cpp-2/src/lib.rs
index e90d9189..95384a93 100644
--- a/llama-cpp-2/src/lib.rs
+++ b/llama-cpp-2/src/lib.rs
@@ -219,7 +219,7 @@ pub enum NewLlamaChatMessageError {
 #[derive(Debug, thiserror::Error)]
 pub enum ApplyChatTemplateError {
     /// the buffer was too small.
-    #[error("The buffer was too small. Please contact a maintainer")]
+    #[error("The buffer was too small. Please contact a maintainer and we will update it.")]
     BuffSizeError,
     /// the string contained a null byte and thus could not be converted to a c string.
     #[error("{0}")]
diff --git a/llama-cpp-2/src/model.rs b/llama-cpp-2/src/model.rs
index b1f0ea47..7da0136c 100644
--- a/llama-cpp-2/src/model.rs
+++ b/llama-cpp-2/src/model.rs
@@ -26,7 +26,7 @@ pub struct LlamaModel {
 }
 
 /// A safe wrapper around `llama_chat_message`
-#[derive(Debug)]
+#[derive(Debug, Eq, PartialEq, Clone)]
 pub struct LlamaChatMessage {
     role: CString,
     content: CString,
@@ -408,6 +408,8 @@ impl LlamaModel {
     /// Apply the model's chat template to some messages.
     /// See <https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template>
     ///
+    /// A `tmpl` of `None` means to use the default template provided by llama.cpp for the model.
+    ///
     /// # Errors
     /// There are many ways this can fail. See [`ApplyChatTemplateError`] for more information.
     #[tracing::instrument(skip_all)]
@@ -431,7 +433,7 @@ impl LlamaModel {
             })
             .collect();
         // Set the tmpl pointer
-        let tmpl = tmpl.map(|v| CString::new(v));
+        let tmpl = tmpl.map(CString::new);
         let tmpl_ptr = match tmpl {
             Some(str) => str?.as_ptr(),
             None => std::ptr::null(),
@@ -446,13 +448,14 @@ impl LlamaModel {
                 buff.as_mut_ptr(),
                 buff.len() as i32,
             );
-            // This should never happen
+            // A buffer twice the size should be sufficient for all models; if this is not the case for a new model, we can increase it
+            // The error message informs the user to contact a maintainer
             if res > buff.len() as i32 {
                 return Err(ApplyChatTemplateError::BuffSizeError);
             }
             String::from_utf8(buff.iter().filter(|c| **c > 0).map(|&c| c as u8).collect())
-        };
-        Ok(formatted_chat?)
+        }?;
+        Ok(formatted_chat)
     }
 }
 

From 6f9fa32b4b4f9eea9b6ea1079a03950cd8623d4d Mon Sep 17 00:00:00 2001
From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com>
Date: Sat, 6 Apr 2024 11:42:08 -0700
Subject: [PATCH 4/4] Add pointer cast

---
 llama-cpp-2/src/model.rs | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/llama-cpp-2/src/model.rs b/llama-cpp-2/src/model.rs
index 7da0136c..a39e70e1 100644
--- a/llama-cpp-2/src/model.rs
+++ b/llama-cpp-2/src/model.rs
@@ -424,6 +424,7 @@ impl LlamaModel {
             acc + c.role.to_bytes().len() + c.content.to_bytes().len()
         });
         let mut buff: Vec<i8> = vec![0_i8; message_length * 2];
+
         // Build our llama_cpp_sys_2 chat messages
         let chat: Vec<llama_cpp_sys_2::llama_chat_message> = chat
             .iter()
@@ -445,7 +446,7 @@ impl LlamaModel {
                 chat.as_ptr(),
                 chat.len(),
                 add_ass,
-                buff.as_mut_ptr(),
+                buff.as_mut_ptr().cast::<std::os::raw::c_char>(),
                 buff.len() as i32,
             );
             // A buffer twice the size should be sufficient for all models; if this is not the case for a new model, we can increase it
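For reference, a minimal usage sketch of the API this series adds. It assumes the crate's existing loading path (`LlamaBackend::init`, `LlamaModelParams::default`, `LlamaModel::load_from_file`); the model path and the message contents are placeholders, not part of the patches.

use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::model::{LlamaChatMessage, LlamaModel};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let backend = LlamaBackend::init()?;
    let params = LlamaModelParams::default();
    // "model.gguf" is a placeholder; any chat-tuned GGUF model should work.
    let model = LlamaModel::load_from_file(&backend, "model.gguf", &params)?;

    let chat = vec![
        LlamaChatMessage::new("system".to_string(), "You are a helpful assistant.".to_string())?,
        LlamaChatMessage::new("user".to_string(), "Hello!".to_string())?,
    ];

    // `None` falls back to the model's built-in chat template; `add_ass = true`
    // appends the assistant prefix so the result is ready for generation.
    let formatted = model.apply_chat_template(None, chat, true)?;
    println!("{formatted}");
    Ok(())
}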