From fc3c3b5fdd53eefc8970e256d9a42ed2bfc324a3 Mon Sep 17 00:00:00 2001 From: Nathan Koppel Date: Wed, 4 Dec 2024 10:28:09 -0600 Subject: [PATCH 01/16] Add sampling API back to LlamaTokenDataArray --- llama-cpp-2/src/sampling.rs | 87 +++++++- llama-cpp-2/src/token/data_array.rs | 334 +++++++++++++++++++++++++++- llama-cpp-sys-2/build.rs | 13 +- 3 files changed, 422 insertions(+), 12 deletions(-) diff --git a/llama-cpp-2/src/sampling.rs b/llama-cpp-2/src/sampling.rs index 7181e149..89b9bc1c 100644 --- a/llama-cpp-2/src/sampling.rs +++ b/llama-cpp-2/src/sampling.rs @@ -7,6 +7,7 @@ use std::ptr::NonNull; use crate::context::LlamaContext; use crate::model::LlamaModel; +use crate::token::data_array::LlamaTokenDataArray; use crate::token::LlamaToken; use crate::LlamaSamplerError; @@ -132,11 +133,22 @@ impl LlamaSampler { self } + /// XTC sampling as described in . + #[must_use] + #[allow(unused_mut)] + pub fn add_xtc(mut self, p: f32, t: f32, min_keep: usize, seed: u32) -> Self { + unsafe { + let xtc_sampler = llama_cpp_sys_2::llama_sampler_init_xtc(p, t, min_keep, seed); + llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), xtc_sampler); + } + + self + } + /// Mirostat 1.0 algorithm described in the paper . Uses tokens instead of words. /// /// # Arguments /// - /// * `candidates` - A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. /// * `tau` - The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. /// * `eta` - The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. /// * `m` - The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. @@ -157,7 +169,6 @@ impl LlamaSampler { /// /// # Arguments /// - /// * `candidates` - A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. /// * `tau` - The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. /// * `eta` - The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. /// * `mu` - Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. @@ -231,6 +242,74 @@ impl LlamaSampler { self } + /// Adds penalties to the sampler. This can be used to penalize certain patterns in the generated text, such as repeating the same token multiple times or using the same token too frequently. 
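+    ///
+    /// # Example
+    ///
+    /// A minimal sketch (hypothetical values; `model` is assumed to be a loaded
+    /// `LlamaModel`): penalize the last 64 tokens with a 1.1x repeat penalty and no
+    /// frequency or presence penalties.
+    ///
+    /// ```ignore
+    /// let sampler = sampler.add_penalties_simple(&model, 64, 1.1, 0.0, 0.0);
+    /// ```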
+ #[allow(unused_mut)] + #[must_use] + pub fn add_penalties_simple( + mut self, + model: &LlamaModel, + penalty_last_n: i32, + penalty_repeat: f32, + penalty_freq: f32, + penalty_present: f32, + ) -> Self { + self.add_penalties( + model.n_vocab(), + model.token_eos().0, + model.token_nl().0, + penalty_last_n, + penalty_repeat, + penalty_freq, + penalty_present, + false, + true, + ) + } + + /// Adds DRY repetition penalty to the sampler. + /// + /// DRY sampler, designed by p-e-w, as described in: , porting Koboldcpp implementation authored by pi6am: + #[allow(unused_mut)] + #[must_use] + pub fn add_dry( + mut self, + model: &LlamaModel, + dry_multiplier: f32, + dry_base: f32, + dry_allowed_length: i32, + dry_penalty_last_n: i32, + seq_breakers: &[impl AsRef<[u8]>], + ) -> Self { + let seq_breakers: Vec = seq_breakers + .iter() + .map(|s| { + let bytes = s.as_ref(); + let null_byte = bytes.iter().position(|b| *b == 0).unwrap_or(bytes.len()); + CString::new(&bytes[..null_byte]).expect("Failed to slice away null bytes!") + }) + .collect(); + + let mut seq_breaker_pointers: Vec<*const i8> = + seq_breakers.iter().map(|s| s.as_ptr()).collect(); + + unsafe { + // Memory safety: llama_sampler_init_dry does not hold a reference to + // seq_breaker_pointers, so this will not UAF in future operations. + let dry_sampler = llama_cpp_sys_2::llama_sampler_init_dry( + model.model.as_ptr(), + dry_multiplier, + dry_base, + dry_allowed_length, + dry_penalty_last_n, + seq_breaker_pointers.as_mut_ptr(), + seq_breaker_pointers.len(), + ); + llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), dry_sampler); + } + + self + } + /// Sample and accept a token from the idx-th output of the last evaluation #[must_use] pub fn sample(&self, ctx: &LlamaContext, idx: i32) -> LlamaToken { @@ -241,6 +320,10 @@ impl LlamaSampler { LlamaToken(token) } + pub fn apply(&mut self, data_array: &mut LlamaTokenDataArray) { + unsafe { data_array.apply_sampler(self.sampler.as_ptr()) } + } + /// Accepts a token from the sampler, possibly updating the internal state of certain samplers (e.g. grammar, repetition, etc.) pub fn accept(&mut self, token: LlamaToken) { unsafe { llama_cpp_sys_2::llama_sampler_accept(self.sampler.as_ptr(), token.0) } diff --git a/llama-cpp-2/src/token/data_array.rs b/llama-cpp-2/src/token/data_array.rs index d9693049..01b24329 100644 --- a/llama-cpp-2/src/token/data_array.rs +++ b/llama-cpp-2/src/token/data_array.rs @@ -1,5 +1,9 @@ //! an rusty equivalent of `llama_token_data`. -use crate::token::data::LlamaTokenData; +use std::{ffi::CString, ptr}; + +use crate::{model::LlamaModel, token::data::LlamaTokenData}; + +use super::LlamaToken; /// a safe wrapper around `llama_token_data_array`. #[derive(Debug, Clone, PartialEq)] @@ -7,12 +11,14 @@ use crate::token::data::LlamaTokenData; pub struct LlamaTokenDataArray { /// the underlying data pub data: Vec, + /// the selected token + pub selected: i64, /// is the data sorted? pub sorted: bool, } impl LlamaTokenDataArray { - /// Create a new `LlamaTokenDataArray` from a vector and weather or not the data is sorted. + /// Create a new `LlamaTokenDataArray` from a vector and whether or not the data is sorted. /// /// ``` /// # use llama_cpp_2::token::data::LlamaTokenData; @@ -27,10 +33,14 @@ impl LlamaTokenDataArray { /// ``` #[must_use] pub fn new(data: Vec, sorted: bool) -> Self { - Self { data, sorted } + Self { + data, + selected: -1, + sorted, + } } - /// Create a new `LlamaTokenDataArray` from an iterator and weather or not the data is sorted. 
+ /// Create a new `LlamaTokenDataArray` from an iterator and whether or not the data is sorted. /// ``` /// # use llama_cpp_2::token::data::LlamaTokenData; /// # use llama_cpp_2::token::data_array::LlamaTokenDataArray; @@ -47,4 +57,320 @@ impl LlamaTokenDataArray { { Self::new(data.into_iter().collect(), sorted) } + + #[must_use] + pub fn selected_token(&self) -> Option { + self.data + .get(usize::try_from(self.selected).ok()?) + .map(LlamaTokenData::id) + } +} + +impl LlamaTokenDataArray { + /// Modify the underlying data as a `llama_token_data_array`. and reconstruct the `LlamaTokenDataArray`. + /// + /// # Panics + /// + /// Panics if some of the safety conditions are not met. (we cannot check all of them at runtime so breaking them is UB) + /// + /// SAFETY: + /// [modify] cannot change the data pointer. + /// if the data is not sorted, sorted must be false. + /// the size of the data can only decrease (i.e you cannot add new elements). + pub(crate) unsafe fn modify_as_c_llama_token_data_array( + &mut self, + modify: impl FnOnce(&mut llama_cpp_sys_2::llama_token_data_array) -> T, + ) -> T { + let size = self.data.len(); + let data = self.data.as_mut_ptr().cast(); + let mut c_llama_token_data_array = llama_cpp_sys_2::llama_token_data_array { + data, + size, + selected: self.selected, + sorted: self.sorted, + }; + let result = modify(&mut c_llama_token_data_array); + assert!( + ptr::eq(data, c_llama_token_data_array.data), + "data pointer changed" + ); + assert!(c_llama_token_data_array.size <= size, "size increased"); + self.data.set_len(c_llama_token_data_array.size); + self.sorted = c_llama_token_data_array.sorted; + self.selected = c_llama_token_data_array.selected; + result + } + + pub(crate) unsafe fn apply_sampler(&mut self, sampler: *mut llama_cpp_sys_2::llama_sampler) { + self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { + llama_cpp_sys_2::llama_sampler_apply(sampler, c_llama_token_data_array); + }) + } + + pub(crate) unsafe fn apply_and_free_sampler( + &mut self, + sampler_fn: impl FnOnce() -> *mut llama_cpp_sys_2::llama_sampler, + ) { + let sampler = sampler_fn(); + self.apply_sampler(sampler); + llama_cpp_sys_2::llama_sampler_free(sampler); + } + + /// Modify the logits of [`Self`] in place using temperature sampling. + /// + /// # Example + /// + /// ```rust + /// # use llama_cpp_2::token::data::LlamaTokenData; + /// # use llama_cpp_2::token::data_array::LlamaTokenDataArray; + /// # use llama_cpp_2::token::LlamaToken; + /// + /// let candidates = vec![ + /// LlamaTokenData::new(LlamaToken::new(0), 0.1, 0.0), + /// LlamaTokenData::new(LlamaToken::new(1), 0.2, 0.0), + /// LlamaTokenData::new(LlamaToken::new(2), 0.7, 0.0) + /// ]; + /// let mut candidates = LlamaTokenDataArray::from_iter(candidates, false); + /// + /// candidates.sample_temp(0.5); + /// + /// assert_eq!(candidates.data[0].logit(), 0.2); + /// assert_eq!(candidates.data[1].logit(), 0.4); + /// assert_eq!(candidates.data[2].logit(), 1.4); + /// ``` + pub fn sample_temp(&mut self, temperature: f32) { + unsafe { + self.apply_and_free_sampler(|| llama_cpp_sys_2::llama_sampler_init_temp(temperature)); + } + } + + /// Dynamic temperature implementation (a.k.a. entropy) described in the paper . 
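+    ///
+    /// # Example
+    ///
+    /// A minimal sketch with illustrative values: a base temperature of 0.8 that may
+    /// drift by up to 0.4 in either direction depending on the entropy of the
+    /// candidates.
+    ///
+    /// ```ignore
+    /// candidates.sample_temp_ext(0.8, 0.4, 1.0);
+    /// ```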
+ pub fn sample_temp_ext(&mut self, t: f32, delta: f32, exponent: f32) { + unsafe { + self.apply_and_free_sampler(|| { + llama_cpp_sys_2::llama_sampler_init_temp_ext(t, delta, exponent) + }); + } + } + + /// Top-K sampling described in academic paper [The Curious Case of Neural Text Degeneration](https://arxiv.org/abs/1904.09751) + pub fn sample_top_k(&mut self, k: i32) { + unsafe { + self.apply_and_free_sampler(|| llama_cpp_sys_2::llama_sampler_init_top_k(k)); + } + } + + /// Locally Typical Sampling implementation described in the [paper](https://arxiv.org/abs/2202.00666). + /// + /// # Example + /// + /// ```rust + /// # use llama_cpp_2::token::data::LlamaTokenData; + /// # use llama_cpp_2::token::data_array::LlamaTokenDataArray; + /// # use llama_cpp_2::token::LlamaToken; + /// + /// let candidates = vec![ + /// LlamaTokenData::new(LlamaToken::new(0), 0.1, 0.0), + /// LlamaTokenData::new(LlamaToken::new(1), 0.2, 0.0), + /// LlamaTokenData::new(LlamaToken::new(2), 0.7, 0.0), + /// ]; + /// let mut candidates = LlamaTokenDataArray::from_iter(candidates, false); + /// candidates.sample_typical(0.5, 1); + /// ``` + pub fn sample_typical(&mut self, p: f32, min_keep: usize) { + unsafe { + self.apply_and_free_sampler(|| { + llama_cpp_sys_2::llama_sampler_init_typical(p, min_keep) + }); + } + } + + /// Nucleus sampling described in academic paper [The Curious Case of Neural Text Degeneration](https://arxiv.org/abs/1904.09751) + /// + /// # Example + /// + /// ```rust + /// + /// # use llama_cpp_2::token::data::LlamaTokenData; + /// # use llama_cpp_2::token::data_array::LlamaTokenDataArray; + /// # use llama_cpp_2::token::LlamaToken; + /// + /// let candidates = vec![ + /// LlamaTokenData::new(LlamaToken::new(0), 0.1, 0.0), + /// LlamaTokenData::new(LlamaToken::new(1), 0.2, 0.0), + /// LlamaTokenData::new(LlamaToken::new(2), 0.7, 0.0), + /// ]; + /// + /// let mut candidates = LlamaTokenDataArray::from_iter(candidates, false); + /// candidates.sample_top_p(0.5, 1); + /// + /// assert_eq!(candidates.data.len(), 2); + /// assert_eq!(candidates.data[0].id(), LlamaToken::new(2)); + /// assert_eq!(candidates.data[1].id(), LlamaToken::new(1)); + /// ``` + pub fn sample_top_p(&mut self, p: f32, min_keep: usize) { + unsafe { + self.apply_and_free_sampler(|| llama_cpp_sys_2::llama_sampler_init_top_p(p, min_keep)); + } + } + + /// Minimum P sampling as described in [#3841](https://github.com/ggerganov/llama.cpp/pull/3841) + /// + /// # Example + /// + /// ``` + /// # use llama_cpp_2::token::data::LlamaTokenData; + /// # use llama_cpp_2::token::data_array::LlamaTokenDataArray; + /// # use llama_cpp_2::token::LlamaToken; + /// + /// let candidates = vec![ + /// LlamaTokenData::new(LlamaToken::new(4), -2., 0.0), + /// LlamaTokenData::new(LlamaToken::new(0), 0.1, 0.0), + /// LlamaTokenData::new(LlamaToken::new(1), 0.2, 0.0), + /// LlamaTokenData::new(LlamaToken::new(2), 0.7, 0.0), + /// ]; + /// let mut candidates = LlamaTokenDataArray::from_iter(candidates, false); + /// candidates.sample_min_p(0.1, 1); + /// + /// assert_eq!(candidates.data.len(), 3); + /// ``` + pub fn sample_min_p(&mut self, p: f32, min_keep: usize) { + unsafe { + self.apply_and_free_sampler(|| llama_cpp_sys_2::llama_sampler_init_min_p(p, min_keep)); + } + } + + /// XTC sampling as described in . 
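+    ///
+    /// # Example
+    ///
+    /// A minimal sketch with illustrative values: with probability 0.5 (seeded with
+    /// 1234), drop every candidate above the 0.1 probability threshold except the
+    /// least likely of them, keeping at least one token.
+    ///
+    /// ```ignore
+    /// candidates.sample_xtc(0.5, 0.1, 1, 1234);
+    /// ```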
+ pub fn sample_xtc(&mut self, p: f32, t: f32, min_keep: usize, seed: u32) { + unsafe { + self.apply_and_free_sampler(|| { + llama_cpp_sys_2::llama_sampler_init_xtc(p, t, min_keep, seed) + }); + } + } + + /// This can be used to penalize certain patterns in the generated text, such as repeating the same token multiple times or using the same token too frequently. + #[allow(clippy::too_many_arguments)] + pub fn sample_penalties( + &mut self, + tokens: &[LlamaToken], + n_vocab: i32, + special_eos_id: i32, + linefeed_id: i32, + penalty_last_n: i32, + penalty_repeat: f32, + penalty_freq: f32, + penalty_present: f32, + penalize_nl: bool, + ignore_eos: bool, + ) { + unsafe { + self.apply_and_free_sampler(|| { + let sampler = llama_cpp_sys_2::llama_sampler_init_penalties( + n_vocab, + special_eos_id, + linefeed_id, + penalty_last_n, + penalty_repeat, + penalty_freq, + penalty_present, + penalize_nl, + ignore_eos, + ); + + for token in tokens { + llama_cpp_sys_2::llama_sampler_accept(sampler, token.0); + } + + sampler + }); + } + } + + /// This can be used to penalize certain patterns in the generated text, such as repeating the same token multiple times or using the same token too frequently. + pub fn sample_penalties_simple( + &mut self, + tokens: &[LlamaToken], + model: &LlamaModel, + penalty_last_n: i32, + penalty_repeat: f32, + penalty_freq: f32, + penalty_present: f32, + ) { + self.sample_penalties( + tokens, + model.n_vocab(), + model.token_eos().0, + model.token_nl().0, + penalty_last_n, + penalty_repeat, + penalty_freq, + penalty_present, + false, + true, + ); + } + + /// DRY sampler, designed by p-e-w, as described in: , porting Koboldcpp implementation authored by pi6am: + #[allow(clippy::too_many_arguments)] + pub fn sample_dry( + &mut self, + tokens: &[LlamaToken], + model: &LlamaModel, + dry_multiplier: f32, + dry_base: f32, + dry_allowed_length: i32, + dry_penalty_last_n: i32, + seq_breakers: &[impl AsRef<[u8]>], + ) { + let seq_breakers: Vec = seq_breakers + .iter() + .map(|s| { + let bytes = s.as_ref(); + let null_byte = bytes.iter().position(|b| *b == 0).unwrap_or(bytes.len()); + CString::new(&bytes[..null_byte]).expect("Failed to slice away null bytes!") + }) + .collect(); + + let mut seq_breaker_pointers: Vec<*const i8> = + seq_breakers.iter().map(|s| s.as_ptr()).collect(); + + unsafe { + self.apply_and_free_sampler(|| { + let sampler = llama_cpp_sys_2::llama_sampler_init_dry( + model.model.as_ptr(), + dry_multiplier, + dry_base, + dry_allowed_length, + dry_penalty_last_n, + seq_breaker_pointers.as_mut_ptr(), + seq_breaker_pointers.len(), + ); + + for token in tokens { + llama_cpp_sys_2::llama_sampler_accept(sampler, token.0); + } + + sampler + }); + } + } + + /// Randomly selects a token from the candidates based on their probabilities. + pub fn sample_token(&mut self, seed: u32) -> LlamaToken { + unsafe { + self.apply_and_free_sampler(|| llama_cpp_sys_2::llama_sampler_init_dist(seed)); + } + self.selected_token() + .expect("Dist sampler failed to select a token!") + } + + /// Selects the token with the highest probability. 
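+    ///
+    /// # Example
+    ///
+    /// A minimal sketch (assumes `ctx` is a `LlamaContext` that has just decoded a
+    /// batch):
+    ///
+    /// ```ignore
+    /// let mut candidates = LlamaTokenDataArray::from_iter(ctx.candidates(), false);
+    /// let token = candidates.sample_token_greedy();
+    /// ```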
+    pub fn sample_token_greedy(&mut self) -> LlamaToken {
+        unsafe {
+            self.apply_and_free_sampler(|| llama_cpp_sys_2::llama_sampler_init_greedy());
+        }
+        self.selected_token()
+            .expect("Greedy sampler failed to select a token!")
+    }
 }
diff --git a/llama-cpp-sys-2/build.rs b/llama-cpp-sys-2/build.rs
index fe3d8aa9..f5473769 100644
--- a/llama-cpp-sys-2/build.rs
+++ b/llama-cpp-sys-2/build.rs
@@ -143,7 +143,6 @@ fn macos_link_search_path() -> Option<String> {
 }
 
 fn main() {
-
     let target = env::var("TARGET").unwrap();
     let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
 
@@ -196,7 +195,6 @@ fn main() {
         .generate()
         .expect("Failed to generate bindings");
 
-
     // Write the generated bindings to an output file
     let bindings_path = out_dir.join("bindings.rs");
     bindings
@@ -231,12 +229,12 @@ fn main() {
     if cfg!(windows) {
         config.static_crt(static_crt);
     }
-
     if cfg!(feature = "vulkan") {
         config.define("GGML_VULKAN", "ON");
         if cfg!(windows) {
-            let vulkan_path = env::var("VULKAN_SDK").expect("Please install Vulkan SDK and ensure that VULKAN_SDK env variable is set");
+            let vulkan_path = env::var("VULKAN_SDK")
+                .expect("Please install Vulkan SDK and ensure that VULKAN_SDK env variable is set");
             let vulkan_lib_path = Path::new(&vulkan_path).join("Lib");
             println!("cargo:rustc-link-search={}", vulkan_lib_path.display());
             println!("cargo:rustc-link-lib=vulkan-1");
@@ -265,7 +263,10 @@ fn main() {
 
     // Search paths
     println!("cargo:rustc-link-search={}", out_dir.join("lib").display());
-    println!("cargo:rustc-link-search={}", out_dir.join("lib64").display());
+    println!(
+        "cargo:rustc-link-search={}",
+        out_dir.join("lib64").display()
+    );
     println!("cargo:rustc-link-search={}", build_dir.display());
 
     // Link libraries
@@ -332,7 +333,7 @@ fn main() {
         debug_log!("HARD LINK {} TO {}", asset.display(), dst.display());
         if !dst.exists() {
             std::fs::hard_link(asset.clone(), dst).unwrap();
-        } 
+        }
 
         // Copy DLLs to examples as well
         if target_dir.join("examples").exists() {

From 25c8e1d0ca94b1c5fbdfd5afefbdc6ae8fc14c92 Mon Sep 17 00:00:00 2001
From: Nathan Koppel
Date: Wed, 4 Dec 2024 10:45:28 -0600
Subject: [PATCH 02/16] Add convenience methods for getting
 LlamaTokenDataArrays from LlamaContexts

---
 llama-cpp-2/src/context.rs | 44 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 44 insertions(+)

diff --git a/llama-cpp-2/src/context.rs b/llama-cpp-2/src/context.rs
index cdebb88a..549a1559 100644
--- a/llama-cpp-2/src/context.rs
+++ b/llama-cpp-2/src/context.rs
@@ -9,6 +9,7 @@ use crate::llama_batch::LlamaBatch;
 use crate::model::{LlamaLoraAdapter, LlamaModel};
 use crate::timing::LlamaTimings;
 use crate::token::data::LlamaTokenData;
+use crate::token::data_array::LlamaTokenDataArray;
 use crate::token::LlamaToken;
 use crate::{
     DecodeError, EmbeddingsError, EncodeError, LlamaLoraAdapterRemoveError,
@@ -202,6 +203,27 @@ impl<'model> LlamaContext<'model> {
         })
     }
 
+    /// Get the token data array for the last token in the context.
+    ///
+    /// This is a convenience method that implements:
+    /// ```no_run
+    /// LlamaTokenDataArray::from_iter(
+    ///     self.candidates(),
+    ///     false,
+    /// )
+    /// ```
+    ///
+    /// # Panics
+    ///
+    /// - underlying logits data is null
+    #[must_use]
+    pub fn token_data_array(&self) -> LlamaTokenDataArray {
+        LlamaTokenDataArray::from_iter(
+            self.candidates(),
+            false,
+        )
+    }
+
     /// Token logits obtained from the last call to `decode()`.
     /// The logits for which `batch.logits[i] != 0` are stored contiguously
     /// in the order they have appeared in the batch.
@@ -217,6 +239,7 @@ impl<'model> LlamaContext<'model> {
     ///
     /// - `n_vocab` does not fit into a usize
     /// - token data returned is null
+    #[must_use]
     pub fn get_logits(&self) -> &[f32] {
         let data = unsafe { llama_cpp_sys_2::llama_get_logits(self.context.as_ptr()) };
         assert!(!data.is_null(), "logits data for last token is null");
@@ -237,6 +260,27 @@ impl<'model> LlamaContext<'model> {
         })
     }
 
+    /// Get the token data array for the ith token in the context.
+    ///
+    /// This is a convenience method that implements:
+    /// ```no_run
+    /// LlamaTokenDataArray::from_iter(
+    ///     self.candidates_ith(i),
+    ///     false,
+    /// )
+    /// ```
+    ///
+    /// # Panics
+    ///
+    /// - logit `i` is not initialized.
+    #[must_use]
+    pub fn token_data_array_ith(&self, i: i32) -> LlamaTokenDataArray {
+        LlamaTokenDataArray::from_iter(
+            self.candidates_ith(i),
+            false,
+        )
+    }
+
     /// Get the logits for the ith token in the context.
     ///
     /// # Panics

From d61858a0e41795edd43ee5192d4aede6d78505d4 Mon Sep 17 00:00:00 2001
From: Nathan Koppel
Date: Wed, 4 Dec 2024 10:46:50 -0600
Subject: [PATCH 03/16] Run cargo fmt

---
 llama-cpp-2/src/context.rs | 10 ++--------
 1 file changed, 2 insertions(+), 8 deletions(-)

diff --git a/llama-cpp-2/src/context.rs b/llama-cpp-2/src/context.rs
index 549a1559..fd8b5feb 100644
--- a/llama-cpp-2/src/context.rs
+++ b/llama-cpp-2/src/context.rs
@@ -218,10 +218,7 @@ impl<'model> LlamaContext<'model> {
     /// - underlying logits data is null
     #[must_use]
     pub fn token_data_array(&self) -> LlamaTokenDataArray {
-        LlamaTokenDataArray::from_iter(
-            self.candidates(),
-            false,
-        )
+        LlamaTokenDataArray::from_iter(self.candidates(), false)
     }
 
     /// Token logits obtained from the last call to `decode()`.
@@ -275,10 +272,7 @@ impl<'model> LlamaContext<'model> {
     /// - logit `i` is not initialized.
     #[must_use]
     pub fn token_data_array_ith(&self, i: i32) -> LlamaTokenDataArray {
-        LlamaTokenDataArray::from_iter(
-            self.candidates_ith(i),
-            false,
-        )
+        LlamaTokenDataArray::from_iter(self.candidates_ith(i), false)
    }
 
     /// Get the logits for the ith token in the context.
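With these convenience methods in place, grabbing the candidates for a given output and sampling from them becomes a two-liner. A minimal sketch of the intended call site (hypothetical surrounding program: `ctx` is assumed to be a `LlamaContext` that has just decoded a batch, `i` one of its output indices, and `1234` an arbitrary RNG seed):

```rust
// Sketch only: `ctx` and `i` come from the surrounding decode loop.
let mut candidates = ctx.token_data_array_ith(i);
let token = candidates.sample_token(1234);
```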
From 6d3dec96e58994ec10144d2a02e42cb4a91919f8 Mon Sep 17 00:00:00 2001
From: Nathan Koppel
Date: Thu, 5 Dec 2024 12:47:41 -0600
Subject: [PATCH 04/16] Small documentation improvement

---
 llama-cpp-2/src/context.rs | 10 ++--------
 1 file changed, 2 insertions(+), 8 deletions(-)

diff --git a/llama-cpp-2/src/context.rs b/llama-cpp-2/src/context.rs
index fd8b5feb..d7078dd7 100644
--- a/llama-cpp-2/src/context.rs
+++ b/llama-cpp-2/src/context.rs
@@ -207,10 +207,7 @@ impl<'model> LlamaContext<'model> {
     ///
     /// This is a convenience method that implements:
     /// ```no_run
-    /// LlamaTokenDataArray::from_iter(
-    ///     self.candidates(),
-    ///     false,
-    /// )
+    /// LlamaTokenDataArray::from_iter(ctx.candidates(), false)
     /// ```
     ///
     /// # Panics
@@ -261,10 +258,7 @@ impl<'model> LlamaContext<'model> {
     ///
     /// This is a convenience method that implements:
     /// ```no_run
-    /// LlamaTokenDataArray::from_iter(
-    ///     self.candidates_ith(i),
-    ///     false,
-    /// )
+    /// LlamaTokenDataArray::from_iter(ctx.candidates_ith(i), false)
     /// ```
     ///
     /// # Panics

From ca071709f4c77b911328b1d3509cd3d8026e4d09 Mon Sep 17 00:00:00 2001
From: Nathan Koppel
Date: Sat, 7 Dec 2024 10:39:15 -0600
Subject: [PATCH 05/16] Make LlamaTokenDataArray::selected an Option

---
 llama-cpp-2/src/context.rs          |  4 ++--
 llama-cpp-2/src/token/data_array.rs | 31 +++++++++++++++++++----------
 2 files changed, 22 insertions(+), 13 deletions(-)

diff --git a/llama-cpp-2/src/context.rs b/llama-cpp-2/src/context.rs
index d7078dd7..8946da2b 100644
--- a/llama-cpp-2/src/context.rs
+++ b/llama-cpp-2/src/context.rs
@@ -206,7 +206,7 @@ impl<'model> LlamaContext<'model> {
     /// Get the token data array for the last token in the context.
     ///
     /// This is a convenience method that implements:
-    /// ```no_run
+    /// ```ignore
     /// LlamaTokenDataArray::from_iter(ctx.candidates(), false)
     /// ```
     ///
@@ -257,7 +257,7 @@ impl<'model> LlamaContext<'model> {
     /// Get the token data array for the ith token in the context.
     ///
     /// This is a convenience method that implements:
-    /// ```no_run
+    /// ```ignore
     /// LlamaTokenDataArray::from_iter(ctx.candidates_ith(i), false)
     /// ```
     ///
diff --git a/llama-cpp-2/src/token/data_array.rs b/llama-cpp-2/src/token/data_array.rs
index 01b24329..090c866e 100644
--- a/llama-cpp-2/src/token/data_array.rs
+++ b/llama-cpp-2/src/token/data_array.rs
@@ -1,4 +1,4 @@
-//! an rusty equivalent of `llama_token_data`.
+//! an rusty equivalent of `llama_token_data_array`.
 use std::{ffi::CString, ptr};
 
 use crate::{model::LlamaModel, token::data::LlamaTokenData};
@@ -11,8 +11,8 @@ pub struct LlamaTokenDataArray {
     /// the underlying data
     pub data: Vec<LlamaTokenData>,
-    /// the selected token
-    pub selected: i64,
+    /// the index of the selected token in ``data``
+    pub selected: Option<usize>,
     /// is the data sorted?
     pub sorted: bool,
 }
@@ -35,7 +35,7 @@ impl LlamaTokenDataArray {
     pub fn new(data: Vec<LlamaTokenData>, sorted: bool) -> Self {
         Self {
             data,
-            selected: -1,
+            selected: None,
             sorted,
         }
     }
@@ -60,9 +60,7 @@ impl LlamaTokenDataArray {
 
     #[must_use]
     pub fn selected_token(&self) -> Option<LlamaToken> {
-        self.data
-            .get(usize::try_from(self.selected).ok()?)
- .map(LlamaTokenData::id) + self.data.get(self.selected?).map(LlamaTokenData::id) } } @@ -82,29 +80,40 @@ impl LlamaTokenDataArray { modify: impl FnOnce(&mut llama_cpp_sys_2::llama_token_data_array) -> T, ) -> T { let size = self.data.len(); - let data = self.data.as_mut_ptr().cast(); + let data = self + .data + .as_mut_ptr() + .cast::(); + let mut c_llama_token_data_array = llama_cpp_sys_2::llama_token_data_array { data, size, - selected: self.selected, + selected: self.selected.and_then(|s| s.try_into().ok()).unwrap_or(-1), sorted: self.sorted, }; + let result = modify(&mut c_llama_token_data_array); assert!( ptr::eq(data, c_llama_token_data_array.data), "data pointer changed" ); assert!(c_llama_token_data_array.size <= size, "size increased"); + self.data.set_len(c_llama_token_data_array.size); self.sorted = c_llama_token_data_array.sorted; - self.selected = c_llama_token_data_array.selected; + self.selected = c_llama_token_data_array + .selected + .try_into() + .ok() + .filter(|&s| s < self.data.len()); + result } pub(crate) unsafe fn apply_sampler(&mut self, sampler: *mut llama_cpp_sys_2::llama_sampler) { self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { llama_cpp_sys_2::llama_sampler_apply(sampler, c_llama_token_data_array); - }) + }); } pub(crate) unsafe fn apply_and_free_sampler( From 27ebd829b946f2f4b08046287a4ea0524b38f6ad Mon Sep 17 00:00:00 2001 From: Nathan Koppel Date: Sat, 7 Dec 2024 13:16:34 -0600 Subject: [PATCH 06/16] Overhaul sampling API --- examples/simple/src/main.rs | 10 +- llama-cpp-2/src/lib.rs | 8 - llama-cpp-2/src/sampling.rs | 374 ++++++++-------------------- llama-cpp-2/src/sampling/params.rs | 180 +++++++++++-- llama-cpp-2/src/token/data_array.rs | 271 ++------------------ 5 files changed, 280 insertions(+), 563 deletions(-) diff --git a/examples/simple/src/main.rs b/examples/simple/src/main.rs index 73932d37..f31a83c4 100644 --- a/examples/simple/src/main.rs +++ b/examples/simple/src/main.rs @@ -17,7 +17,7 @@ use llama_cpp_2::model::params::kv_overrides::ParamOverrideValue; use llama_cpp_2::model::params::LlamaModelParams; use llama_cpp_2::model::LlamaModel; use llama_cpp_2::model::{AddBos, Special}; -use llama_cpp_2::sampling::params::LlamaSamplerChainParams; +use llama_cpp_2::sampling::params::LlamaSamplerParams; use llama_cpp_2::sampling::LlamaSampler; use std::ffi::CString; @@ -246,10 +246,10 @@ either reduce n_len or increase n_ctx" // The `Decoder` let mut decoder = encoding_rs::UTF_8.new_decoder(); - let sampler_params = LlamaSamplerChainParams::default(); - let mut sampler = LlamaSampler::new(sampler_params)? - .add_dist(seed.unwrap_or(1234)) - .add_greedy(); + let mut sampler = LlamaSampler::new(LlamaSamplerParams::chain(&[ + LlamaSamplerParams::Dist { seed: seed.unwrap_or(1234) }, + LlamaSamplerParams::Greedy, + ])); while n_cur <= n_len { // sample the next token diff --git a/llama-cpp-2/src/lib.rs b/llama-cpp-2/src/lib.rs index 424572bd..8e09608f 100644 --- a/llama-cpp-2/src/lib.rs +++ b/llama-cpp-2/src/lib.rs @@ -195,14 +195,6 @@ pub enum LlamaLoraAdapterRemoveError { ErrorResult(i32), } -/// An error that can occur when initializing a sampler. 
-#[derive(Debug, Eq, PartialEq, thiserror::Error)] -pub enum LlamaSamplerError { - /// llama.cpp returned null - #[error("null reference from llama.cpp")] - NullReturn, -} - /// get the time (in microseconds) according to llama.cpp /// ``` /// # use llama_cpp_2::llama_time_us; diff --git a/llama-cpp-2/src/sampling.rs b/llama-cpp-2/src/sampling.rs index 89b9bc1c..e0313f4e 100644 --- a/llama-cpp-2/src/sampling.rs +++ b/llama-cpp-2/src/sampling.rs @@ -3,17 +3,16 @@ pub mod params; use std::ffi::CString; use std::fmt::{Debug, Formatter}; -use std::ptr::NonNull; use crate::context::LlamaContext; -use crate::model::LlamaModel; use crate::token::data_array::LlamaTokenDataArray; use crate::token::LlamaToken; -use crate::LlamaSamplerError; + +use params::LlamaSamplerParams; /// A safe wrapper around `llama_sampler`. pub struct LlamaSampler { - pub(crate) sampler: NonNull, + pub(crate) sampler: *mut llama_cpp_sys_2::llama_sampler, } impl Debug for LlamaSampler { @@ -22,318 +21,145 @@ impl Debug for LlamaSampler { } } -impl LlamaSampler { - /// Create a new `LlamaSampler` from the given parameters. - /// # Errors - /// Returns an error if the underlying C++ code returns a null pointer. - pub fn new(params: params::LlamaSamplerChainParams) -> Result { - let sampler = unsafe { - NonNull::new(llama_cpp_sys_2::llama_sampler_chain_init( - params.sampler_chain_params, - )) - .ok_or(LlamaSamplerError::NullReturn) - }?; - - Ok(Self { sampler }) - } - - /// Samples the token with the largest probability. - #[must_use] - #[allow(unused_mut)] - pub fn add_greedy(mut self) -> Self { - unsafe { - let greedy_sampler = llama_cpp_sys_2::llama_sampler_init_greedy(); - llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), greedy_sampler); - } - - self - } - - /// Samples according to the probability distribution of the tokens. 
- #[must_use] - #[allow(unused_mut)] - pub fn add_dist(mut self, seed: u32) -> Self { - unsafe { - let dist_sampler = llama_cpp_sys_2::llama_sampler_init_dist(seed); - llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), dist_sampler); - } - - self - } - - /// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" - #[must_use] - #[allow(unused_mut)] - pub fn add_top_k(mut self, k: i32) -> Self { - unsafe { - let top_k_sampler = llama_cpp_sys_2::llama_sampler_init_top_k(k); - llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), top_k_sampler); - } - - self - } - - /// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" - #[must_use] - #[allow(unused_mut)] - pub fn add_top_p(mut self, p: f32, min_keep: usize) -> Self { - unsafe { - let top_p_sampler = llama_cpp_sys_2::llama_sampler_init_top_p(p, min_keep); - llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), top_p_sampler); - } +unsafe fn new_inner(params: LlamaSamplerParams) -> *mut llama_cpp_sys_2::llama_sampler { + match params { + LlamaSamplerParams::Chain { no_perf, stages } => { + let chain = llama_cpp_sys_2::llama_sampler_chain_init( + llama_cpp_sys_2::llama_sampler_chain_params { no_perf }, + ); - self - } + for stage in stages { + llama_cpp_sys_2::llama_sampler_chain_add(chain, new_inner(*stage)); + } - /// Minimum P sampling as described in - #[must_use] - #[allow(unused_mut)] - pub fn add_min_p(mut self, p: f32, min_keep: usize) -> Self { - unsafe { - let min_p_sampler = llama_cpp_sys_2::llama_sampler_init_min_p(p, min_keep); - llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), min_p_sampler); + chain } - - self - } - - /// Locally Typical Sampling implementation described in the paper . - #[must_use] - #[allow(unused_mut)] - pub fn add_typical(mut self, p: f32, min_keep: usize) -> Self { - unsafe { - let typical_sampler = llama_cpp_sys_2::llama_sampler_init_typical(p, min_keep); - llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), typical_sampler); + LlamaSamplerParams::Temp(p) => llama_cpp_sys_2::llama_sampler_init_temp(p), + LlamaSamplerParams::TempExt { t, delta, exponent } => { + llama_cpp_sys_2::llama_sampler_init_temp_ext(t, delta, exponent) } - - self - } - - /// Updates the logits l_i` = l_i/t. When t <= 0.0f, the maximum logit is kept at it's original value, the rest are set to -inf - #[must_use] - #[allow(unused_mut)] - pub fn add_temp(mut self, t: f32) -> Self { - unsafe { - let temp_sampler = llama_cpp_sys_2::llama_sampler_init_temp(t); - llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), temp_sampler); + LlamaSamplerParams::TopK(k) => llama_cpp_sys_2::llama_sampler_init_top_k(k), + LlamaSamplerParams::Typical { p, min_keep } => { + llama_cpp_sys_2::llama_sampler_init_typical(p, min_keep) } - - self - } - - /// Dynamic temperature implementation (a.k.a. entropy) described in the paper . - #[must_use] - #[allow(unused_mut)] - pub fn add_temp_ext(mut self, t: f32, delta: f32, exponent: f32) -> Self { - unsafe { - let temp_ext_sampler = llama_cpp_sys_2::llama_sampler_init_temp_ext(t, delta, exponent); - llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), temp_ext_sampler); + LlamaSamplerParams::TopP { p, min_keep } => { + llama_cpp_sys_2::llama_sampler_init_top_p(p, min_keep) } - - self - } - - /// XTC sampling as described in . 
- #[must_use] - #[allow(unused_mut)] - pub fn add_xtc(mut self, p: f32, t: f32, min_keep: usize, seed: u32) -> Self { - unsafe { - let xtc_sampler = llama_cpp_sys_2::llama_sampler_init_xtc(p, t, min_keep, seed); - llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), xtc_sampler); + LlamaSamplerParams::MinP { p, min_keep } => { + llama_cpp_sys_2::llama_sampler_init_min_p(p, min_keep) } - - self - } - - /// Mirostat 1.0 algorithm described in the paper . Uses tokens instead of words. - /// - /// # Arguments - /// - /// * `tau` - The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. - /// * `eta` - The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. - /// * `m` - The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. - /// * `mu` - Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. - #[must_use] - #[allow(unused_mut)] - pub fn add_mirostat(mut self, n_vocab: i32, seed: u32, tau: f32, eta: f32, m: i32) -> Self { - unsafe { - let temp_ext_sampler = - llama_cpp_sys_2::llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m); - llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), temp_ext_sampler); - } - - self - } - - /// Mirostat 2.0 algorithm described in the paper . Uses tokens instead of words. - /// - /// # Arguments - /// - /// * `tau` - The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. - /// * `eta` - The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. - /// * `mu` - Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. - #[must_use] - #[allow(unused_mut)] - pub fn add_mirostat_v2(mut self, seed: u32, tau: f32, eta: f32) -> Self { - unsafe { - let temp_ext_sampler = llama_cpp_sys_2::llama_sampler_init_mirostat_v2(seed, tau, eta); - llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), temp_ext_sampler); - } - - self - } - - /// Samples constrained by a context-free grammar in the GGML BNF (GBNF) format. - /// - /// # Panics - /// Panics if a provided string contains a null byte. 
- #[must_use] - #[allow(unused_mut)] - pub fn add_grammar( - mut self, - model: &LlamaModel, - grammar_str: &str, - grammar_root: &str, - ) -> Self { - unsafe { - let grammar_str = CString::new(grammar_str).unwrap(); - let grammar_root = CString::new(grammar_root).unwrap(); - let grammar_sampler = llama_cpp_sys_2::llama_sampler_init_grammar( + LlamaSamplerParams::Xtc { + p, + t, + min_keep, + seed, + } => llama_cpp_sys_2::llama_sampler_init_xtc(p, t, min_keep, seed), + LlamaSamplerParams::Grammar { + model, + string, + root, + } => { + let grammar_str = CString::new(string).unwrap(); + let grammar_root = CString::new(root).unwrap(); + llama_cpp_sys_2::llama_sampler_init_grammar( model.model.as_ptr(), grammar_str.as_ptr(), grammar_root.as_ptr(), - ); - llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), grammar_sampler); + ) } - - self - } - - /// Adds penalties to the sampler. This can be used to penalize certain patterns in the generated text, such as repeating the same token multiple times or using the same token too frequently. - #[allow(unused_mut, clippy::too_many_arguments)] - #[must_use] - pub fn add_penalties( - mut self, - n_vocab: i32, - special_eos_id: i32, - linefeed_id: i32, - penalty_last_n: i32, - penalty_repeat: f32, - penalty_freq: f32, - penalty_present: f32, - penalize_nl: bool, - ignore_eos: bool, - ) -> Self { - unsafe { - let temp_ext_sampler = llama_cpp_sys_2::llama_sampler_init_penalties( - n_vocab, - special_eos_id, - linefeed_id, + LlamaSamplerParams::Dry { + model, + multiplier, + base, + allowed_length, + penalty_last_n, + seq_breakers, + } => { + let seq_breakers: Vec = seq_breakers + .iter() + .map(|s| CString::new(*s).unwrap()) + .collect(); + let mut seq_breaker_pointers: Vec<*const i8> = + seq_breakers.iter().map(|s| s.as_ptr()).collect(); + llama_cpp_sys_2::llama_sampler_init_dry( + model.model.as_ptr(), + multiplier, + base, + allowed_length, penalty_last_n, - penalty_repeat, - penalty_freq, - penalty_present, - penalize_nl, - ignore_eos, - ); - llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), temp_ext_sampler); + seq_breaker_pointers.as_mut_ptr(), + seq_breaker_pointers.len(), + ) } - - self - } - - /// Adds penalties to the sampler. This can be used to penalize certain patterns in the generated text, such as repeating the same token multiple times or using the same token too frequently. - #[allow(unused_mut)] - #[must_use] - pub fn add_penalties_simple( - mut self, - model: &LlamaModel, - penalty_last_n: i32, - penalty_repeat: f32, - penalty_freq: f32, - penalty_present: f32, - ) -> Self { - self.add_penalties( - model.n_vocab(), - model.token_eos().0, - model.token_nl().0, + LlamaSamplerParams::Penalties { + n_vocab, + special_eos_id, + linefeed_id, penalty_last_n, penalty_repeat, penalty_freq, penalty_present, - false, - true, - ) + penalize_nl, + ignore_eos, + } => llama_cpp_sys_2::llama_sampler_init_penalties( + n_vocab, + special_eos_id, + linefeed_id, + penalty_last_n, + penalty_repeat, + penalty_freq, + penalty_present, + penalize_nl, + ignore_eos, + ), + LlamaSamplerParams::Dist { seed } => llama_cpp_sys_2::llama_sampler_init_dist(seed), + LlamaSamplerParams::Greedy => llama_cpp_sys_2::llama_sampler_init_greedy(), } +} - /// Adds DRY repetition penalty to the sampler. - /// - /// DRY sampler, designed by p-e-w, as described in: , porting Koboldcpp implementation authored by pi6am: - #[allow(unused_mut)] +impl LlamaSampler { + /// Create a new `LlamaSampler` from the given parameters. 
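+    ///
+    /// # Example
+    ///
+    /// A minimal sketch: a chain that tempers the logits at 0.7 and then draws a
+    /// token at random from what remains (the seed is arbitrary).
+    ///
+    /// ```ignore
+    /// let sampler = LlamaSampler::new(LlamaSamplerParams::chain(&[
+    ///     LlamaSamplerParams::Temp(0.7),
+    ///     LlamaSamplerParams::Dist { seed: 1234 },
+    /// ]));
+    /// ```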
#[must_use] - pub fn add_dry( - mut self, - model: &LlamaModel, - dry_multiplier: f32, - dry_base: f32, - dry_allowed_length: i32, - dry_penalty_last_n: i32, - seq_breakers: &[impl AsRef<[u8]>], - ) -> Self { - let seq_breakers: Vec = seq_breakers - .iter() - .map(|s| { - let bytes = s.as_ref(); - let null_byte = bytes.iter().position(|b| *b == 0).unwrap_or(bytes.len()); - CString::new(&bytes[..null_byte]).expect("Failed to slice away null bytes!") - }) - .collect(); - - let mut seq_breaker_pointers: Vec<*const i8> = - seq_breakers.iter().map(|s| s.as_ptr()).collect(); - - unsafe { - // Memory safety: llama_sampler_init_dry does not hold a reference to - // seq_breaker_pointers, so this will not UAF in future operations. - let dry_sampler = llama_cpp_sys_2::llama_sampler_init_dry( - model.model.as_ptr(), - dry_multiplier, - dry_base, - dry_allowed_length, - dry_penalty_last_n, - seq_breaker_pointers.as_mut_ptr(), - seq_breaker_pointers.len(), - ); - llama_cpp_sys_2::llama_sampler_chain_add(self.sampler.as_ptr(), dry_sampler); + pub fn new(params: LlamaSamplerParams) -> Self { + Self { + sampler: unsafe { new_inner(params) }, } - - self } /// Sample and accept a token from the idx-th output of the last evaluation #[must_use] pub fn sample(&self, ctx: &LlamaContext, idx: i32) -> LlamaToken { let token = unsafe { - llama_cpp_sys_2::llama_sampler_sample(self.sampler.as_ptr(), ctx.context.as_ptr(), idx) + llama_cpp_sys_2::llama_sampler_sample(self.sampler, ctx.context.as_ptr(), idx) }; LlamaToken(token) } + /// Applies this sampler to a [`LlamaTokenDataArray`]. pub fn apply(&mut self, data_array: &mut LlamaTokenDataArray) { - unsafe { data_array.apply_sampler(self.sampler.as_ptr()) } + data_array.apply_sampler(self); } - /// Accepts a token from the sampler, possibly updating the internal state of certain samplers (e.g. grammar, repetition, etc.) + /// Accepts a token from the sampler, possibly updating the internal state of certain samplers + /// (e.g. grammar, repetition, etc.) pub fn accept(&mut self, token: LlamaToken) { - unsafe { llama_cpp_sys_2::llama_sampler_accept(self.sampler.as_ptr(), token.0) } + unsafe { llama_cpp_sys_2::llama_sampler_accept(self.sampler, token.0) } + } + + /// Accepts several tokens from the sampler or context, possibly updating the internal state of + /// certain samplers (e.g. grammar, repetition, etc.) + pub fn accept_many(&mut self, tokens: &[LlamaToken]) { + for token in tokens { + unsafe { llama_cpp_sys_2::llama_sampler_accept(self.sampler, token.0) } + } } } impl Drop for LlamaSampler { fn drop(&mut self) { unsafe { - llama_cpp_sys_2::llama_sampler_free(self.sampler.as_ptr()); + llama_cpp_sys_2::llama_sampler_free(self.sampler); } } } diff --git a/llama-cpp-2/src/sampling/params.rs b/llama-cpp-2/src/sampling/params.rs index 0e67c1fa..fe5d23e2 100644 --- a/llama-cpp-2/src/sampling/params.rs +++ b/llama-cpp-2/src/sampling/params.rs @@ -1,39 +1,171 @@ -//! Safe wrapper around `llama_sampler_chain_params`. +//! Safe parameters used to construct [`super::LlamaSampler`] -use std::fmt::{Debug, Formatter}; +/// Safe parameters used to construct [`super::LlamaSampler`] +#[derive(Debug, Clone, Copy)] +pub enum LlamaSamplerParams<'a> { + /// A chain of samplers, applied one after another + #[allow(missing_docs)] + Chain { + no_perf: bool, + stages: &'a [LlamaSamplerParams<'a>], + }, -/// A safe wrapper around `llama_sampler`. 
-pub struct LlamaSamplerChainParams {
-    pub(crate) sampler_chain_params: llama_cpp_sys_2::llama_sampler_chain_params,
-}
+    /// Updates the logits l_i' = l_i/t. When t <= 0.0f, the maximum logit is kept at its original
+    /// value, the rest are set to -inf
+    Temp(f32),
 
-impl Debug for LlamaSamplerChainParams {
-    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
-        f.debug_struct("LlamaSamplerChainParams").finish()
-    }
+    /// Dynamic temperature implementation (a.k.a. entropy) described in the paper <https://arxiv.org/abs/2309.02772>.
+    #[allow(missing_docs)]
+    TempExt { t: f32, delta: f32, exponent: f32 },
+    /// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration"
+    /// <https://arxiv.org/abs/1904.09751>
+    TopK(i32),
+    /// Locally Typical Sampling implementation described in the paper <https://arxiv.org/abs/2202.00666>.
+    #[allow(missing_docs)]
+    Typical { p: f32, min_keep: usize },
+    /// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration"
+    /// <https://arxiv.org/abs/1904.09751>
+    #[allow(missing_docs)]
+    TopP { p: f32, min_keep: usize },
+    /// Minimum P sampling as described in <https://github.com/ggerganov/llama.cpp/pull/3841>
+    #[allow(missing_docs)]
+    MinP { p: f32, min_keep: usize },
+
+    /// XTC sampler as described in <https://github.com/oobabooga/text-generation-webui/pull/6335>
+    #[allow(missing_docs)]
+    Xtc {
+        /// The probability of this sampler being applied.
+        p: f32,
+        t: f32,
+        min_keep: usize,
+        /// Seed to use when selecting whether to apply this sampler or not
+        seed: u32,
+    },
+
+    /// Grammar sampler
+    #[allow(missing_docs)]
+    Grammar {
+        model: &'a crate::model::LlamaModel,
+        string: &'a str,
+        root: &'a str,
+    },
+
+    /// @details DRY sampler, designed by p-e-w, as described in:
+    /// <https://github.com/oobabooga/text-generation-webui/pull/5677>, porting Koboldcpp
+    /// implementation authored by pi6am: <https://github.com/LostRuins/koboldcpp/pull/982>
+    #[allow(missing_docs)]
+    Dry {
+        model: &'a crate::model::LlamaModel,
+        multiplier: f32,
+        base: f32,
+        allowed_length: i32,
+        penalty_last_n: i32,
+        seq_breakers: &'a [&'a str],
+    },
+
+    /// Penalizes tokens for being present in the context.
+    Penalties {
+        /// ``model.n_vocab()``
+        n_vocab: i32,
+        /// ``model.token_eos()``
+        special_eos_id: i32,
+        /// ``model.token_nl()``
+        linefeed_id: i32,
+        /// last n tokens to penalize (0 = disable penalty, -1 = context size)
+        penalty_last_n: i32,
+        /// 1.0 = disabled
+        penalty_repeat: f32,
+        /// 0.0 = disabled
+        penalty_freq: f32,
+        /// 0.0 = disabled
+        penalty_present: f32,
+        /// consider newlines as a repeatable token
+        penalize_nl: bool,
+        /// ignore the end-of-sequence token
+        ignore_eos: bool,
+    },
+
+    /// Select a token at random based on each token's probabilities
+    Dist {
+        /// Seed to initialize random generation with
+        seed: u32,
+    },
+
+    /// Select the most likely token
+    Greedy,
 }
 
-impl Default for LlamaSamplerChainParams {
-    fn default() -> Self {
-        let sampler_chain_params = unsafe { llama_cpp_sys_2::llama_sampler_chain_default_params() };
+impl<'a> LlamaSamplerParams<'a> {
+    /// Easily create a chain of samplers with performance metrics enabled.
+    #[must_use]
+    pub fn chain(stages: &'a [Self]) -> Self {
+        LlamaSamplerParams::Chain {
+            no_perf: false,
+            stages,
+        }
+    }
 
-        Self {
-            sampler_chain_params,
+    /// Easily create a [`LlamaSamplerParams::Penalties`] sampler using a model. This sets
+    /// `penalize_nl` to false and `ignore_eos` to true as reasonable defaults.
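+    ///
+    /// # Example
+    ///
+    /// A minimal sketch (hypothetical values; `model` is assumed to be a loaded
+    /// [`crate::model::LlamaModel`]):
+    ///
+    /// ```ignore
+    /// let params = LlamaSamplerParams::penalties(&model, 64, 1.1, 0.0, 0.0);
+    /// ```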
+ #[must_use] + pub fn penalties( + model: &'a crate::model::LlamaModel, + penalty_last_n: i32, + penalty_repeat: f32, + penalty_freq: f32, + penalty_present: f32, + ) -> Self { + Self::Penalties { + n_vocab: model.n_vocab(), + special_eos_id: model.token_eos().0, + linefeed_id: model.token_nl().0, + penalty_last_n, + penalty_repeat, + penalty_freq, + penalty_present, + penalize_nl: false, + ignore_eos: true, } } -} -impl LlamaSamplerChainParams { - /// Set whether to measure performance timings + /// Easily define a [`LlamaSamplerParams::Typical`] with `min_keep == 1` + #[must_use] + pub fn typical(p: f32) -> Self { + Self::Typical { p, min_keep: 1 } + } + + /// Easily define a [`LlamaSamplerParams::TopP`] with `min_keep == 1` #[must_use] - pub fn with_no_perf(mut self, no_perf: bool) -> Self { - self.sampler_chain_params.no_perf = no_perf; - self + pub fn top_p(p: f32) -> Self { + Self::TopP { p, min_keep: 1 } } - /// Get whether to measure performance timings + /// Easily define a [`LlamaSamplerParams::MinP`] with `min_keep == 1` #[must_use] - pub fn no_perf(&self) -> bool { - self.sampler_chain_params.no_perf + pub fn min_p(p: f32) -> Self { + Self::MinP { p, min_keep: 1 } + } + + /// Whether this sampler's outputs are dependent on the tokens in the model's context. + pub(crate) fn uses_context_tokens(&self) -> bool { + match self { + LlamaSamplerParams::Chain { stages, .. } => { + stages.iter().any(LlamaSamplerParams::uses_context_tokens) + } + + LlamaSamplerParams::Grammar { .. } + | LlamaSamplerParams::Penalties { .. } + | LlamaSamplerParams::Dry { .. } => true, + + LlamaSamplerParams::Temp(_) + | LlamaSamplerParams::TempExt { .. } + | LlamaSamplerParams::TopK(_) + | LlamaSamplerParams::Typical { .. } + | LlamaSamplerParams::TopP { .. } + | LlamaSamplerParams::MinP { .. } + | LlamaSamplerParams::Xtc { .. } + | LlamaSamplerParams::Dist { .. } + | LlamaSamplerParams::Greedy => false, + } } } diff --git a/llama-cpp-2/src/token/data_array.rs b/llama-cpp-2/src/token/data_array.rs index 090c866e..97a1a600 100644 --- a/llama-cpp-2/src/token/data_array.rs +++ b/llama-cpp-2/src/token/data_array.rs @@ -1,7 +1,11 @@ //! an rusty equivalent of `llama_token_data_array`. use std::{ffi::CString, ptr}; -use crate::{model::LlamaModel, token::data::LlamaTokenData}; +use crate::{ + model::LlamaModel, + sampling::{params::LlamaSamplerParams, LlamaSampler}, + token::data::LlamaTokenData, +}; use super::LlamaToken; @@ -58,6 +62,7 @@ impl LlamaTokenDataArray { Self::new(data.into_iter().collect(), sorted) } + /// Returns the current selected token, if one exists. #[must_use] pub fn selected_token(&self) -> Option { self.data.get(self.selected?).map(LlamaTokenData::id) @@ -110,275 +115,37 @@ impl LlamaTokenDataArray { result } - pub(crate) unsafe fn apply_sampler(&mut self, sampler: *mut llama_cpp_sys_2::llama_sampler) { - self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { - llama_cpp_sys_2::llama_sampler_apply(sampler, c_llama_token_data_array); - }); - } - - pub(crate) unsafe fn apply_and_free_sampler( - &mut self, - sampler_fn: impl FnOnce() -> *mut llama_cpp_sys_2::llama_sampler, - ) { - let sampler = sampler_fn(); - self.apply_sampler(sampler); - llama_cpp_sys_2::llama_sampler_free(sampler); - } - - /// Modify the logits of [`Self`] in place using temperature sampling. 
- /// - /// # Example - /// - /// ```rust - /// # use llama_cpp_2::token::data::LlamaTokenData; - /// # use llama_cpp_2::token::data_array::LlamaTokenDataArray; - /// # use llama_cpp_2::token::LlamaToken; - /// - /// let candidates = vec![ - /// LlamaTokenData::new(LlamaToken::new(0), 0.1, 0.0), - /// LlamaTokenData::new(LlamaToken::new(1), 0.2, 0.0), - /// LlamaTokenData::new(LlamaToken::new(2), 0.7, 0.0) - /// ]; - /// let mut candidates = LlamaTokenDataArray::from_iter(candidates, false); - /// - /// candidates.sample_temp(0.5); - /// - /// assert_eq!(candidates.data[0].logit(), 0.2); - /// assert_eq!(candidates.data[1].logit(), 0.4); - /// assert_eq!(candidates.data[2].logit(), 1.4); - /// ``` - pub fn sample_temp(&mut self, temperature: f32) { - unsafe { - self.apply_and_free_sampler(|| llama_cpp_sys_2::llama_sampler_init_temp(temperature)); - } - } + /// Applies a sampler constructed from [`LlamaSamplerParams`]. This will call + /// [`LlamaSampler::accept_many`] on the provided tokens if the sampler uses tokens. + pub fn apply_sampler_from_params(&mut self, params: LlamaSamplerParams, tokens: &[LlamaToken]) { + let mut sampler = LlamaSampler::new(params); - /// Dynamic temperature implementation (a.k.a. entropy) described in the paper . - pub fn sample_temp_ext(&mut self, t: f32, delta: f32, exponent: f32) { - unsafe { - self.apply_and_free_sampler(|| { - llama_cpp_sys_2::llama_sampler_init_temp_ext(t, delta, exponent) - }); + if params.uses_context_tokens() { + sampler.accept_many(tokens); } - } - /// Top-K sampling described in academic paper [The Curious Case of Neural Text Degeneration](https://arxiv.org/abs/1904.09751) - pub fn sample_top_k(&mut self, k: i32) { - unsafe { - self.apply_and_free_sampler(|| llama_cpp_sys_2::llama_sampler_init_top_k(k)); - } + self.apply_sampler(&mut sampler); } - /// Locally Typical Sampling implementation described in the [paper](https://arxiv.org/abs/2202.00666). 
- /// - /// # Example - /// - /// ```rust - /// # use llama_cpp_2::token::data::LlamaTokenData; - /// # use llama_cpp_2::token::data_array::LlamaTokenDataArray; - /// # use llama_cpp_2::token::LlamaToken; - /// - /// let candidates = vec![ - /// LlamaTokenData::new(LlamaToken::new(0), 0.1, 0.0), - /// LlamaTokenData::new(LlamaToken::new(1), 0.2, 0.0), - /// LlamaTokenData::new(LlamaToken::new(2), 0.7, 0.0), - /// ]; - /// let mut candidates = LlamaTokenDataArray::from_iter(candidates, false); - /// candidates.sample_typical(0.5, 1); - /// ``` - pub fn sample_typical(&mut self, p: f32, min_keep: usize) { + /// Modifies the data array by applying a sampler to it + pub fn apply_sampler(&mut self, sampler: &mut LlamaSampler) { unsafe { - self.apply_and_free_sampler(|| { - llama_cpp_sys_2::llama_sampler_init_typical(p, min_keep) - }); - } - } - - /// Nucleus sampling described in academic paper [The Curious Case of Neural Text Degeneration](https://arxiv.org/abs/1904.09751) - /// - /// # Example - /// - /// ```rust - /// - /// # use llama_cpp_2::token::data::LlamaTokenData; - /// # use llama_cpp_2::token::data_array::LlamaTokenDataArray; - /// # use llama_cpp_2::token::LlamaToken; - /// - /// let candidates = vec![ - /// LlamaTokenData::new(LlamaToken::new(0), 0.1, 0.0), - /// LlamaTokenData::new(LlamaToken::new(1), 0.2, 0.0), - /// LlamaTokenData::new(LlamaToken::new(2), 0.7, 0.0), - /// ]; - /// - /// let mut candidates = LlamaTokenDataArray::from_iter(candidates, false); - /// candidates.sample_top_p(0.5, 1); - /// - /// assert_eq!(candidates.data.len(), 2); - /// assert_eq!(candidates.data[0].id(), LlamaToken::new(2)); - /// assert_eq!(candidates.data[1].id(), LlamaToken::new(1)); - /// ``` - pub fn sample_top_p(&mut self, p: f32, min_keep: usize) { - unsafe { - self.apply_and_free_sampler(|| llama_cpp_sys_2::llama_sampler_init_top_p(p, min_keep)); - } - } - - /// Minimum P sampling as described in [#3841](https://github.com/ggerganov/llama.cpp/pull/3841) - /// - /// # Example - /// - /// ``` - /// # use llama_cpp_2::token::data::LlamaTokenData; - /// # use llama_cpp_2::token::data_array::LlamaTokenDataArray; - /// # use llama_cpp_2::token::LlamaToken; - /// - /// let candidates = vec![ - /// LlamaTokenData::new(LlamaToken::new(4), -2., 0.0), - /// LlamaTokenData::new(LlamaToken::new(0), 0.1, 0.0), - /// LlamaTokenData::new(LlamaToken::new(1), 0.2, 0.0), - /// LlamaTokenData::new(LlamaToken::new(2), 0.7, 0.0), - /// ]; - /// let mut candidates = LlamaTokenDataArray::from_iter(candidates, false); - /// candidates.sample_min_p(0.1, 1); - /// - /// assert_eq!(candidates.data.len(), 3); - /// ``` - pub fn sample_min_p(&mut self, p: f32, min_keep: usize) { - unsafe { - self.apply_and_free_sampler(|| llama_cpp_sys_2::llama_sampler_init_min_p(p, min_keep)); - } - } - - /// XTC sampling as described in . - pub fn sample_xtc(&mut self, p: f32, t: f32, min_keep: usize, seed: u32) { - unsafe { - self.apply_and_free_sampler(|| { - llama_cpp_sys_2::llama_sampler_init_xtc(p, t, min_keep, seed) - }); - } - } - - /// This can be used to penalize certain patterns in the generated text, such as repeating the same token multiple times or using the same token too frequently. 
- #[allow(clippy::too_many_arguments)] - pub fn sample_penalties( - &mut self, - tokens: &[LlamaToken], - n_vocab: i32, - special_eos_id: i32, - linefeed_id: i32, - penalty_last_n: i32, - penalty_repeat: f32, - penalty_freq: f32, - penalty_present: f32, - penalize_nl: bool, - ignore_eos: bool, - ) { - unsafe { - self.apply_and_free_sampler(|| { - let sampler = llama_cpp_sys_2::llama_sampler_init_penalties( - n_vocab, - special_eos_id, - linefeed_id, - penalty_last_n, - penalty_repeat, - penalty_freq, - penalty_present, - penalize_nl, - ignore_eos, - ); - - for token in tokens { - llama_cpp_sys_2::llama_sampler_accept(sampler, token.0); - } - - sampler - }); - } - } - - /// This can be used to penalize certain patterns in the generated text, such as repeating the same token multiple times or using the same token too frequently. - pub fn sample_penalties_simple( - &mut self, - tokens: &[LlamaToken], - model: &LlamaModel, - penalty_last_n: i32, - penalty_repeat: f32, - penalty_freq: f32, - penalty_present: f32, - ) { - self.sample_penalties( - tokens, - model.n_vocab(), - model.token_eos().0, - model.token_nl().0, - penalty_last_n, - penalty_repeat, - penalty_freq, - penalty_present, - false, - true, - ); - } - - /// DRY sampler, designed by p-e-w, as described in: , porting Koboldcpp implementation authored by pi6am: - #[allow(clippy::too_many_arguments)] - pub fn sample_dry( - &mut self, - tokens: &[LlamaToken], - model: &LlamaModel, - dry_multiplier: f32, - dry_base: f32, - dry_allowed_length: i32, - dry_penalty_last_n: i32, - seq_breakers: &[impl AsRef<[u8]>], - ) { - let seq_breakers: Vec = seq_breakers - .iter() - .map(|s| { - let bytes = s.as_ref(); - let null_byte = bytes.iter().position(|b| *b == 0).unwrap_or(bytes.len()); - CString::new(&bytes[..null_byte]).expect("Failed to slice away null bytes!") - }) - .collect(); - - let mut seq_breaker_pointers: Vec<*const i8> = - seq_breakers.iter().map(|s| s.as_ptr()).collect(); - - unsafe { - self.apply_and_free_sampler(|| { - let sampler = llama_cpp_sys_2::llama_sampler_init_dry( - model.model.as_ptr(), - dry_multiplier, - dry_base, - dry_allowed_length, - dry_penalty_last_n, - seq_breaker_pointers.as_mut_ptr(), - seq_breaker_pointers.len(), - ); - - for token in tokens { - llama_cpp_sys_2::llama_sampler_accept(sampler, token.0); - } - - sampler + self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { + llama_cpp_sys_2::llama_sampler_apply(sampler.sampler, c_llama_token_data_array); }); } } /// Randomly selects a token from the candidates based on their probabilities. pub fn sample_token(&mut self, seed: u32) -> LlamaToken { - unsafe { - self.apply_and_free_sampler(|| llama_cpp_sys_2::llama_sampler_init_dist(seed)); - } + self.apply_sampler_from_params(LlamaSamplerParams::Dist { seed }, &[]); self.selected_token() .expect("Dist sampler failed to select a token!") } /// Selects the token with the highest probability. 
pub fn sample_token_greedy(&mut self) -> LlamaToken { - unsafe { - self.apply_and_free_sampler(|| llama_cpp_sys_2::llama_sampler_init_greedy()); - } + self.apply_sampler_from_params(LlamaSamplerParams::Greedy, &[]); self.selected_token() .expect("Greedy sampler failed to select a token!") } From d44deef3d9c5e50aaf86ca9a0e707c50389ccd2d Mon Sep 17 00:00:00 2001 From: Nathan Koppel Date: Sat, 7 Dec 2024 13:25:33 -0600 Subject: [PATCH 07/16] Add Mirostat to new API --- examples/simple/src/main.rs | 4 +++- llama-cpp-2/src/sampling.rs | 10 ++++++++++ llama-cpp-2/src/sampling/params.rs | 28 +++++++++++++++++++++++++++- 3 files changed, 40 insertions(+), 2 deletions(-) diff --git a/examples/simple/src/main.rs b/examples/simple/src/main.rs index f31a83c4..e13274f1 100644 --- a/examples/simple/src/main.rs +++ b/examples/simple/src/main.rs @@ -247,7 +247,9 @@ either reduce n_len or increase n_ctx" let mut decoder = encoding_rs::UTF_8.new_decoder(); let mut sampler = LlamaSampler::new(LlamaSamplerParams::chain(&[ - LlamaSamplerParams::Dist { seed: seed.unwrap_or(1234) }, + LlamaSamplerParams::Dist { + seed: seed.unwrap_or(1234), + }, LlamaSamplerParams::Greedy, ])); diff --git a/llama-cpp-2/src/sampling.rs b/llama-cpp-2/src/sampling.rs index e0313f4e..2c945326 100644 --- a/llama-cpp-2/src/sampling.rs +++ b/llama-cpp-2/src/sampling.rs @@ -112,6 +112,16 @@ unsafe fn new_inner(params: LlamaSamplerParams) -> *mut llama_cpp_sys_2::llama_s penalize_nl, ignore_eos, ), + LlamaSamplerParams::Mirostat { + n_vocab, + tau, + eta, + m, + seed, + } => llama_cpp_sys_2::llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m), + LlamaSamplerParams::MirostatV2 { tau, eta, seed } => { + llama_cpp_sys_2::llama_sampler_init_mirostat_v2(seed, tau, eta) + } LlamaSamplerParams::Dist { seed } => llama_cpp_sys_2::llama_sampler_init_dist(seed), LlamaSamplerParams::Greedy => llama_cpp_sys_2::llama_sampler_init_greedy(), } diff --git a/llama-cpp-2/src/sampling/params.rs b/llama-cpp-2/src/sampling/params.rs index fe5d23e2..84cdbc3f 100644 --- a/llama-cpp-2/src/sampling/params.rs +++ b/llama-cpp-2/src/sampling/params.rs @@ -85,6 +85,30 @@ pub enum LlamaSamplerParams<'a> { ignore_eos: bool, }, + /// Mirostat 1.0 algorithm described in the paper . Uses tokens instead of words. + Mirostat { + /// ``model.n_vocab()`` + n_vocab: i32, + /// The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. + tau: f32, + /// The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. + eta: f32, + /// The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. + m: i32, + /// Seed to initialize random generation with + seed: u32, + }, + + /// Mirostat 2.0 algorithm described in the paper . Uses tokens instead of words. + MirostatV2 { + /// The target cross-entropy (or surprise) value you want to achieve for the generated text. 
A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. + tau: f32, + /// The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. + eta: f32, + /// Seed to initialize random generation with + seed: u32, + }, + /// Select a token at random based on each token's probabilities Dist { /// Seed to initialize random generation with @@ -146,7 +170,7 @@ impl<'a> LlamaSamplerParams<'a> { Self::MinP { p, min_keep: 1 } } - /// Whether this sampler's outputs are dependent on the tokens in the model's context. + /// Whether this sampler's outputs are dependent on the tokens in the model's context. pub(crate) fn uses_context_tokens(&self) -> bool { match self { LlamaSamplerParams::Chain { stages, .. } => { @@ -164,6 +188,8 @@ impl<'a> LlamaSamplerParams<'a> { | LlamaSamplerParams::TopP { .. } | LlamaSamplerParams::MinP { .. } | LlamaSamplerParams::Xtc { .. } + | LlamaSamplerParams::Mirostat { .. } + | LlamaSamplerParams::MirostatV2 { .. } | LlamaSamplerParams::Dist { .. } | LlamaSamplerParams::Greedy => false, } From 4a334a44f5c43dd27d5bf5832d6e176222c5f57f Mon Sep 17 00:00:00 2001 From: Nathan Koppel Date: Sat, 7 Dec 2024 14:19:54 -0600 Subject: [PATCH 08/16] Remove unused imports --- llama-cpp-2/src/token/data_array.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/llama-cpp-2/src/token/data_array.rs b/llama-cpp-2/src/token/data_array.rs index 97a1a600..3e34dc69 100644 --- a/llama-cpp-2/src/token/data_array.rs +++ b/llama-cpp-2/src/token/data_array.rs @@ -1,8 +1,7 @@ //! an rusty equivalent of `llama_token_data_array`. -use std::{ffi::CString, ptr}; +use std::ptr; use crate::{ - model::LlamaModel, sampling::{params::LlamaSamplerParams, LlamaSampler}, token::data::LlamaTokenData, }; From 32cadf765e1125aa6f67bfff76f6f9a333571b89 Mon Sep 17 00:00:00 2001 From: Nathan Koppel Date: Sat, 7 Dec 2024 16:13:19 -0600 Subject: [PATCH 09/16] Fix crash when running XTC sampler --- llama-cpp-2/src/token/data_array.rs | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/llama-cpp-2/src/token/data_array.rs b/llama-cpp-2/src/token/data_array.rs index 3e34dc69..8d3266a0 100644 --- a/llama-cpp-2/src/token/data_array.rs +++ b/llama-cpp-2/src/token/data_array.rs @@ -73,12 +73,14 @@ impl LlamaTokenDataArray { /// /// # Panics /// - /// Panics if some of the safety conditions are not met. (we cannot check all of them at runtime so breaking them is UB) + /// Panics if some of the safety conditions are not met. (we cannot check all of them at + /// runtime so breaking them is UB) /// /// SAFETY: - /// [modify] cannot change the data pointer. + /// The returned array formed by the data pointer and the length must entirely consist of + /// initialized token data and the length must be less than the capacity of this array's data + /// buffer. /// if the data is not sorted, sorted must be false. - /// the size of the data can only decrease (i.e you cannot add new elements). 
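The reworked contract above exists because some llama.cpp samplers (XTC among them, per the commit message) leave the array's `data` pointer aimed at a freshly allocated buffer instead of editing in place. A self-contained sketch of the copy-back discipline the patched function below adopts (plain `Vec`, no llama.cpp involved):

```rust
use std::ptr;

fn main() {
    // Our owned buffer, standing in for `self.data`.
    let mut data: Vec<i32> = vec![10, 20, 30, 40];

    // A sampler that reallocates, as XTC can: it hands back a different
    // pointer and a smaller size rather than shrinking `data` in place.
    let replacement = vec![30, 40];
    let (out_ptr, out_size) = (replacement.as_ptr(), replacement.len());

    unsafe {
        // The new safety condition: the returned length must fit in our buffer.
        assert!(out_size <= data.capacity());
        // If the pointer moved, copy the elements back into our own buffer
        // before adjusting the length, mirroring the fix below.
        if !ptr::eq(out_ptr, data.as_ptr()) {
            ptr::copy(out_ptr, data.as_mut_ptr(), out_size);
        }
        data.set_len(out_size);
    }

    assert_eq!(data, [30, 40]);
}
```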
pub(crate) unsafe fn modify_as_c_llama_token_data_array( &mut self, modify: impl FnOnce(&mut llama_cpp_sys_2::llama_token_data_array) -> T, @@ -97,13 +99,20 @@ impl LlamaTokenDataArray { }; let result = modify(&mut c_llama_token_data_array); + assert!( - ptr::eq(data, c_llama_token_data_array.data), - "data pointer changed" + c_llama_token_data_array.size <= self.data.capacity(), + "Size of the returned array exceeds the data buffer's capacity!" ); - assert!(c_llama_token_data_array.size <= size, "size increased"); - + if !ptr::eq(c_llama_token_data_array.data, data) { + ptr::copy( + c_llama_token_data_array.data, + data, + c_llama_token_data_array.size, + ); + } self.data.set_len(c_llama_token_data_array.size); + self.sorted = c_llama_token_data_array.sorted; self.selected = c_llama_token_data_array .selected From 73ef067bc0cf73379cd38592cfbf768f06b4ce14 Mon Sep 17 00:00:00 2001 From: Nathan Koppel Date: Sun, 8 Dec 2024 13:51:40 -0600 Subject: [PATCH 10/16] Yet another API overhaul --- examples/simple/src/main.rs | 11 +- examples/usage.rs | 8 +- llama-cpp-2/src/sampling.rs | 287 ++++++++++++++++++---------- llama-cpp-2/src/sampling/params.rs | 197 ------------------- llama-cpp-2/src/token/data_array.rs | 18 +- 5 files changed, 190 insertions(+), 331 deletions(-) delete mode 100644 llama-cpp-2/src/sampling/params.rs diff --git a/examples/simple/src/main.rs b/examples/simple/src/main.rs index e13274f1..f67a5309 100644 --- a/examples/simple/src/main.rs +++ b/examples/simple/src/main.rs @@ -17,7 +17,6 @@ use llama_cpp_2::model::params::kv_overrides::ParamOverrideValue; use llama_cpp_2::model::params::LlamaModelParams; use llama_cpp_2::model::LlamaModel; use llama_cpp_2::model::{AddBos, Special}; -use llama_cpp_2::sampling::params::LlamaSamplerParams; use llama_cpp_2::sampling::LlamaSampler; use std::ffi::CString; @@ -246,12 +245,10 @@ either reduce n_len or increase n_ctx" // The `Decoder` let mut decoder = encoding_rs::UTF_8.new_decoder(); - let mut sampler = LlamaSampler::new(LlamaSamplerParams::chain(&[ - LlamaSamplerParams::Dist { - seed: seed.unwrap_or(1234), - }, - LlamaSamplerParams::Greedy, - ])); + let mut sampler = LlamaSampler::chain(vec![ + LlamaSampler::dist(seed.unwrap_or(1234)), + LlamaSampler::greedy(), + ]); while n_cur <= n_len { // sample the next token diff --git a/examples/usage.rs b/examples/usage.rs index 2b7f1915..323ad6c2 100644 --- a/examples/usage.rs +++ b/examples/usage.rs @@ -14,9 +14,7 @@ use llama_cpp_2::llama_batch::LlamaBatch; use llama_cpp_2::model::params::LlamaModelParams; use llama_cpp_2::model::LlamaModel; use llama_cpp_2::model::{AddBos, Special}; -use llama_cpp_2::sampling::params::LlamaSamplerChainParams; use llama_cpp_2::sampling::LlamaSampler; -use llama_cpp_2::token::data_array::LlamaTokenDataArray; use std::io::Write; #[allow(clippy::cast_possible_wrap, clippy::cast_possible_truncation)] @@ -55,11 +53,7 @@ fn main() { // The `Decoder` let mut decoder = encoding_rs::UTF_8.new_decoder(); - - let sampler_params = LlamaSamplerChainParams::default(); - let mut sampler = LlamaSampler::new(sampler_params) - .expect("Failed to create sampler") - .add_greedy(); + let mut sampler = LlamaSampler::greedy(); while n_cur <= n_len { // sample the next token diff --git a/llama-cpp-2/src/sampling.rs b/llama-cpp-2/src/sampling.rs index 2c945326..f52ea1a8 100644 --- a/llama-cpp-2/src/sampling.rs +++ b/llama-cpp-2/src/sampling.rs @@ -1,15 +1,13 @@ //! Safe wrapper around `llama_sampler`. 
-pub mod params; use std::ffi::CString; use std::fmt::{Debug, Formatter}; use crate::context::LlamaContext; +use crate::model::LlamaModel; use crate::token::data_array::LlamaTokenDataArray; use crate::token::LlamaToken; -use params::LlamaSamplerParams; - /// A safe wrapper around `llama_sampler`. pub struct LlamaSampler { pub(crate) sampler: *mut llama_cpp_sys_2::llama_sampler, @@ -21,63 +19,138 @@ impl Debug for LlamaSampler { } } -unsafe fn new_inner(params: LlamaSamplerParams) -> *mut llama_cpp_sys_2::llama_sampler { - match params { - LlamaSamplerParams::Chain { no_perf, stages } => { +impl LlamaSampler { + /// Sample and accept a token from the idx-th output of the last evaluation + #[must_use] + pub fn sample(&self, ctx: &LlamaContext, idx: i32) -> LlamaToken { + let token = unsafe { + llama_cpp_sys_2::llama_sampler_sample(self.sampler, ctx.context.as_ptr(), idx) + }; + + LlamaToken(token) + } + + /// Applies this sampler to a [`LlamaTokenDataArray`]. + pub fn apply(&mut self, data_array: &mut LlamaTokenDataArray) { + data_array.apply_sampler(self); + } + + /// Accepts a token from the sampler, possibly updating the internal state of certain samplers + /// (e.g. grammar, repetition, etc.) + pub fn accept(&mut self, token: LlamaToken) { + unsafe { llama_cpp_sys_2::llama_sampler_accept(self.sampler, token.0) } + } + + /// Accepts several tokens from the sampler or context, possibly updating the internal state of + /// certain samplers (e.g. grammar, repetition, etc.) + pub fn accept_many(&mut self, tokens: &[LlamaToken]) { + for token in tokens { + unsafe { llama_cpp_sys_2::llama_sampler_accept(self.sampler, token.0) } + } + } + + /// Accepts several tokens from the sampler or context, possibly updating the internal state of + /// certain samplers (e.g. grammar, repetition, etc.) 
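Because `accept_many` simply folds `llama_sampler_accept` over the tokens, priming a stateful sampler with existing context is a single call. A sketch using constructors defined further down in this same overhaul (`model`, `context_tokens`, and `candidates` are assumed to exist; penalty values are illustrative):

```rust
// Penalty samplers only penalize tokens they have been shown, so the
// prompt/context tokens are fed in before the sampler touches candidates.
let mut sampler = LlamaSampler::penalties_simple(&model, 64, 1.1, 0.0, 0.0);
sampler.accept_many(&context_tokens);
sampler.apply(&mut candidates);
```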
+ #[must_use] + pub fn with_tokens(mut self, tokens: &[LlamaToken]) -> Self { + self.accept_many(tokens); + self + } + + #[must_use] + pub fn chain_with_no_perf(samplers: Vec, no_perf: bool) -> Self { + unsafe { let chain = llama_cpp_sys_2::llama_sampler_chain_init( llama_cpp_sys_2::llama_sampler_chain_params { no_perf }, ); - for stage in stages { - llama_cpp_sys_2::llama_sampler_chain_add(chain, new_inner(*stage)); + for sampler in samplers { + llama_cpp_sys_2::llama_sampler_chain_add(chain, sampler.sampler); + + // Do not call `llama_sampler_free` on the sampler, as the internal sampler is now + // owned by the chain + std::mem::forget(sampler); } - chain - } - LlamaSamplerParams::Temp(p) => llama_cpp_sys_2::llama_sampler_init_temp(p), - LlamaSamplerParams::TempExt { t, delta, exponent } => { - llama_cpp_sys_2::llama_sampler_init_temp_ext(t, delta, exponent) - } - LlamaSamplerParams::TopK(k) => llama_cpp_sys_2::llama_sampler_init_top_k(k), - LlamaSamplerParams::Typical { p, min_keep } => { - llama_cpp_sys_2::llama_sampler_init_typical(p, min_keep) - } - LlamaSamplerParams::TopP { p, min_keep } => { - llama_cpp_sys_2::llama_sampler_init_top_p(p, min_keep) + Self { sampler: chain } } - LlamaSamplerParams::MinP { p, min_keep } => { - llama_cpp_sys_2::llama_sampler_init_min_p(p, min_keep) - } - LlamaSamplerParams::Xtc { - p, - t, - min_keep, - seed, - } => llama_cpp_sys_2::llama_sampler_init_xtc(p, t, min_keep, seed), - LlamaSamplerParams::Grammar { - model, - string, - root, - } => { - let grammar_str = CString::new(string).unwrap(); - let grammar_root = CString::new(root).unwrap(); + } + + #[must_use] + pub fn chain(samplers: Vec) -> Self { + Self::chain_with_no_perf(samplers, false) + } + + #[must_use] + pub fn temp(t: f32) -> Self { + let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_temp(t) }; + Self { sampler } + } + + #[must_use] + pub fn temp_ext(t: f32, delta: f32, exponent: f32) -> Self { + let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_temp_ext(t, delta, exponent) }; + Self { sampler } + } + + #[must_use] + pub fn top_k(k: i32) -> Self { + let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_top_k(k) }; + Self { sampler } + } + + #[must_use] + pub fn typical(p: f32, min_keep: usize) -> Self { + let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_typical(p, min_keep) }; + Self { sampler } + } + + #[must_use] + pub fn top_p(p: f32, min_keep: usize) -> Self { + let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_top_p(p, min_keep) }; + Self { sampler } + } + + #[must_use] + pub fn min_p(p: f32, min_keep: usize) -> Self { + let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_min_p(p, min_keep) }; + Self { sampler } + } + + #[must_use] + pub fn xtc(p: f32, t: f32, min_keep: usize, seed: u32) -> Self { + let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_xtc(p, t, min_keep, seed) }; + Self { sampler } + } + + #[must_use] + pub fn grammar(model: &LlamaModel, grammar_str: &str, grammar_root: &str) -> Self { + let grammar_str = CString::new(grammar_str).unwrap(); + let grammar_root = CString::new(grammar_root).unwrap(); + + let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_grammar( model.model.as_ptr(), grammar_str.as_ptr(), grammar_root.as_ptr(), ) - } - LlamaSamplerParams::Dry { - model, - multiplier, - base, - allowed_length, - penalty_last_n, - seq_breakers, - } => { + }; + Self { sampler } + } + + #[must_use] + pub fn dry( + model: &LlamaModel, + multiplier: f32, + base: f32, + allowed_length: i32, + penalty_last_n: 
i32, + seq_breakers: &[impl AsRef<[u8]>], + ) -> Self { + let sampler = unsafe { let seq_breakers: Vec = seq_breakers .iter() - .map(|s| CString::new(*s).unwrap()) + .map(|s| CString::new(s.as_ref()).unwrap()) .collect(); let mut seq_breaker_pointers: Vec<*const i8> = seq_breakers.iter().map(|s| s.as_ptr()).collect(); @@ -90,79 +163,83 @@ unsafe fn new_inner(params: LlamaSamplerParams) -> *mut llama_cpp_sys_2::llama_s seq_breaker_pointers.as_mut_ptr(), seq_breaker_pointers.len(), ) - } - LlamaSamplerParams::Penalties { - n_vocab, - special_eos_id, - linefeed_id, - penalty_last_n, - penalty_repeat, - penalty_freq, - penalty_present, - penalize_nl, - ignore_eos, - } => llama_cpp_sys_2::llama_sampler_init_penalties( - n_vocab, - special_eos_id, - linefeed_id, + }; + Self { sampler } + } + + #[allow(clippy::too_many_arguments)] + #[must_use] + pub fn penalties( + n_vocab: i32, + special_eos_id: i32, + linefeed_id: i32, + penalty_last_n: i32, + penalty_repeat: f32, + penalty_freq: f32, + penalty_present: f32, + penalize_nl: bool, + ignore_eos: bool, + ) -> Self { + let sampler = unsafe { + llama_cpp_sys_2::llama_sampler_init_penalties( + n_vocab, + special_eos_id, + linefeed_id, + penalty_last_n, + penalty_repeat, + penalty_freq, + penalty_present, + penalize_nl, + ignore_eos, + ) + }; + Self { sampler } + } + + #[must_use] + pub fn penalties_simple( + model: &LlamaModel, + penalty_last_n: i32, + penalty_repeat: f32, + penalty_freq: f32, + penalty_present: f32, + ) -> Self { + Self::penalties( + model.n_vocab(), + model.token_eos().0, + model.token_nl().0, penalty_last_n, penalty_repeat, penalty_freq, penalty_present, - penalize_nl, - ignore_eos, - ), - LlamaSamplerParams::Mirostat { - n_vocab, - tau, - eta, - m, - seed, - } => llama_cpp_sys_2::llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m), - LlamaSamplerParams::MirostatV2 { tau, eta, seed } => { - llama_cpp_sys_2::llama_sampler_init_mirostat_v2(seed, tau, eta) - } - LlamaSamplerParams::Dist { seed } => llama_cpp_sys_2::llama_sampler_init_dist(seed), - LlamaSamplerParams::Greedy => llama_cpp_sys_2::llama_sampler_init_greedy(), + false, + true, + ) } -} -impl LlamaSampler { - /// Create a new `LlamaSampler` from the given parameters. #[must_use] - pub fn new(params: LlamaSamplerParams) -> Self { - Self { - sampler: unsafe { new_inner(params) }, - } + pub fn mirostat(n_vocab: i32, seed: u32, tau: f32, eta: f32, m: i32) -> Self { + let sampler = + unsafe { llama_cpp_sys_2::llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m) }; + Self { sampler } } - /// Sample and accept a token from the idx-th output of the last evaluation #[must_use] - pub fn sample(&self, ctx: &LlamaContext, idx: i32) -> LlamaToken { - let token = unsafe { - llama_cpp_sys_2::llama_sampler_sample(self.sampler, ctx.context.as_ptr(), idx) - }; - - LlamaToken(token) - } - - /// Applies this sampler to a [`LlamaTokenDataArray`]. - pub fn apply(&mut self, data_array: &mut LlamaTokenDataArray) { - data_array.apply_sampler(self); + pub fn mirostat_v2(seed: u32, tau: f32, eta: f32) -> Self { + let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_mirostat_v2(seed, tau, eta) }; + Self { sampler } } - /// Accepts a token from the sampler, possibly updating the internal state of certain samplers - /// (e.g. grammar, repetition, etc.) 
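The constructors added in this hunk (`dry`, `penalties`, `mirostat`, `mirostat_v2`, `dist`) are now plain expressions. A sketch of a DRY call with commonly used breaker strings (`model` and `candidates` assumed; the numeric values are illustrative, not taken from this patch):

```rust
// multiplier, base, allowed_length, penalty_last_n, then the breakers; the
// breakers mark positions where repetition matching is allowed to restart.
let mut dry = LlamaSampler::dry(&model, 0.8, 1.75, 2, -1, &["\n", ":", "\"", "*"]);
candidates.apply_sampler(&mut dry);
```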
- pub fn accept(&mut self, token: LlamaToken) { - unsafe { llama_cpp_sys_2::llama_sampler_accept(self.sampler, token.0) } + #[must_use] + pub fn dist(seed: u32) -> Self { + let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_dist(seed) }; + Self { sampler } } - /// Accepts several tokens from the sampler or context, possibly updating the internal state of - /// certain samplers (e.g. grammar, repetition, etc.) - pub fn accept_many(&mut self, tokens: &[LlamaToken]) { - for token in tokens { - unsafe { llama_cpp_sys_2::llama_sampler_accept(self.sampler, token.0) } - } + #[must_use] + pub fn greedy() -> Self { + let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_greedy() }; + Self { sampler } } } diff --git a/llama-cpp-2/src/sampling/params.rs b/llama-cpp-2/src/sampling/params.rs deleted file mode 100644 index 84cdbc3f..00000000 --- a/llama-cpp-2/src/sampling/params.rs +++ /dev/null @@ -1,197 +0,0 @@ -//! Safe parameters used to construct [`super::LlamaSampler`] - -/// Safe parameters used to construct [`super::LlamaSampler`] -#[derive(Debug, Clone, Copy)] -pub enum LlamaSamplerParams<'a> { - /// A chain of samplers, applied one after another - #[allow(missing_docs)] - Chain { - no_perf: bool, - stages: &'a [LlamaSamplerParams<'a>], - }, - - /// Updates the logits l_i` = l_i/t. When t <= 0.0f, the maximum logit is kept at it's original - /// value, the rest are set to -inf - Temp(f32), - - /// Dynamic temperature implementation (a.k.a. entropy) described in the paper . - #[allow(missing_docs)] - TempExt { t: f32, delta: f32, exponent: f32 }, - /// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" - /// - TopK(i32), - /// Locally Typical Sampling implementation described in the paper . - #[allow(missing_docs)] - Typical { p: f32, min_keep: usize }, - /// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" - /// - #[allow(missing_docs)] - TopP { p: f32, min_keep: usize }, - /// Minimum P sampling as described in - #[allow(missing_docs)] - MinP { p: f32, min_keep: usize }, - - /// XTC sampler as described in - #[allow(missing_docs)] - Xtc { - /// The probability of this sampler being applied. - p: f32, - t: f32, - min_keep: usize, - /// Seed to use when selecting whether to apply this sampler or not - seed: u32, - }, - - /// Grammar sampler - #[allow(missing_docs)] - Grammar { - model: &'a crate::model::LlamaModel, - string: &'a str, - root: &'a str, - }, - - /// @details DRY sampler, designed by p-e-w, as described in: - /// , porting Koboldcpp - /// implementation authored by pi6am: - #[allow(missing_docs)] - Dry { - model: &'a crate::model::LlamaModel, - multiplier: f32, - base: f32, - allowed_length: i32, - penalty_last_n: i32, - seq_breakers: &'a [&'a str], - }, - - /// Penalizes tokens for being present in the context. - Penalties { - /// ``model.n_vocab()`` - n_vocab: i32, - /// ``model.token_eos()`` - special_eos_id: i32, - /// ``model.token_nl()`` - linefeed_id: i32, - /// last n tokens to penalize (0 = disable penalty, -1 = context size) - penalty_last_n: i32, - /// 1.0 = disabled - penalty_repeat: f32, - /// 0.0 = disabled - penalty_freq: f32, - /// 0.0 = disabled - penalty_present: f32, - /// consider newlines as a repeatable token - penalize_nl: bool, - /// ignore the end-of-sequence token - ignore_eos: bool, - }, - - /// Mirostat 1.0 algorithm described in the paper . Uses tokens instead of words. 
- Mirostat { - /// ``model.n_vocab()`` - n_vocab: i32, - /// The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. - tau: f32, - /// The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. - eta: f32, - /// The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. - m: i32, - /// Seed to initialize random generation with - seed: u32, - }, - - /// Mirostat 2.0 algorithm described in the paper . Uses tokens instead of words. - MirostatV2 { - /// The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. - tau: f32, - /// The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. - eta: f32, - /// Seed to initialize random generation with - seed: u32, - }, - - /// Select a token at random based on each token's probabilities - Dist { - /// Seed to initialize random generation with - seed: u32, - }, - - /// Select the most likely token - Greedy, -} - -impl<'a> LlamaSamplerParams<'a> { - /// Easily create a chain of samplers with performance metrics enabled. - #[must_use] - pub fn chain(stages: &'a [Self]) -> Self { - LlamaSamplerParams::Chain { - no_perf: false, - stages, - } - } - - /// Easily create a [`LlamaSamplerParams::Penalties`] sampler using a model. This sets - /// `penalize_nl` to false and `ignore_eos` to true as reasonable defaults. - #[must_use] - pub fn penalties( - model: &'a crate::model::LlamaModel, - penalty_last_n: i32, - penalty_repeat: f32, - penalty_freq: f32, - penalty_present: f32, - ) -> Self { - Self::Penalties { - n_vocab: model.n_vocab(), - special_eos_id: model.token_eos().0, - linefeed_id: model.token_nl().0, - penalty_last_n, - penalty_repeat, - penalty_freq, - penalty_present, - penalize_nl: false, - ignore_eos: true, - } - } - - /// Easily define a [`LlamaSamplerParams::Typical`] with `min_keep == 1` - #[must_use] - pub fn typical(p: f32) -> Self { - Self::Typical { p, min_keep: 1 } - } - - /// Easily define a [`LlamaSamplerParams::TopP`] with `min_keep == 1` - #[must_use] - pub fn top_p(p: f32) -> Self { - Self::TopP { p, min_keep: 1 } - } - - /// Easily define a [`LlamaSamplerParams::MinP`] with `min_keep == 1` - #[must_use] - pub fn min_p(p: f32) -> Self { - Self::MinP { p, min_keep: 1 } - } - - /// Whether this sampler's outputs are dependent on the tokens in the model's context. - pub(crate) fn uses_context_tokens(&self) -> bool { - match self { - LlamaSamplerParams::Chain { stages, .. } => { - stages.iter().any(LlamaSamplerParams::uses_context_tokens) - } - - LlamaSamplerParams::Grammar { .. } - | LlamaSamplerParams::Penalties { .. } - | LlamaSamplerParams::Dry { .. 
} => true, - - LlamaSamplerParams::Temp(_) - | LlamaSamplerParams::TempExt { .. } - | LlamaSamplerParams::TopK(_) - | LlamaSamplerParams::Typical { .. } - | LlamaSamplerParams::TopP { .. } - | LlamaSamplerParams::MinP { .. } - | LlamaSamplerParams::Xtc { .. } - | LlamaSamplerParams::Mirostat { .. } - | LlamaSamplerParams::MirostatV2 { .. } - | LlamaSamplerParams::Dist { .. } - | LlamaSamplerParams::Greedy => false, - } - } -} diff --git a/llama-cpp-2/src/token/data_array.rs b/llama-cpp-2/src/token/data_array.rs index 8d3266a0..0912af8a 100644 --- a/llama-cpp-2/src/token/data_array.rs +++ b/llama-cpp-2/src/token/data_array.rs @@ -2,7 +2,7 @@ use std::ptr; use crate::{ - sampling::{params::LlamaSamplerParams, LlamaSampler}, + sampling::LlamaSampler, token::data::LlamaTokenData, }; @@ -123,18 +123,6 @@ impl LlamaTokenDataArray { result } - /// Applies a sampler constructed from [`LlamaSamplerParams`]. This will call - /// [`LlamaSampler::accept_many`] on the provided tokens if the sampler uses tokens. - pub fn apply_sampler_from_params(&mut self, params: LlamaSamplerParams, tokens: &[LlamaToken]) { - let mut sampler = LlamaSampler::new(params); - - if params.uses_context_tokens() { - sampler.accept_many(tokens); - } - - self.apply_sampler(&mut sampler); - } - /// Modifies the data array by applying a sampler to it pub fn apply_sampler(&mut self, sampler: &mut LlamaSampler) { unsafe { @@ -146,14 +134,14 @@ impl LlamaTokenDataArray { /// Randomly selects a token from the candidates based on their probabilities. pub fn sample_token(&mut self, seed: u32) -> LlamaToken { - self.apply_sampler_from_params(LlamaSamplerParams::Dist { seed }, &[]); + self.apply_sampler(&mut LlamaSampler::dist(seed)); self.selected_token() .expect("Dist sampler failed to select a token!") } /// Selects the token with the highest probability. pub fn sample_token_greedy(&mut self) -> LlamaToken { - self.apply_sampler_from_params(LlamaSamplerParams::Greedy, &[]); + self.apply_sampler(&mut LlamaSampler::greedy()); self.selected_token() .expect("Greedy sampler failed to select a token!") } From bacad6574ec0836b6c35e6c59c147a219265e69f Mon Sep 17 00:00:00 2001 From: Nathan Koppel Date: Sun, 8 Dec 2024 14:00:15 -0600 Subject: [PATCH 11/16] Small tweaks --- llama-cpp-2/src/sampling.rs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/llama-cpp-2/src/sampling.rs b/llama-cpp-2/src/sampling.rs index f52ea1a8..67a2b57b 100644 --- a/llama-cpp-2/src/sampling.rs +++ b/llama-cpp-2/src/sampling.rs @@ -58,7 +58,7 @@ impl LlamaSampler { } #[must_use] - pub fn chain_with_no_perf(samplers: Vec, no_perf: bool) -> Self { + pub fn chain(samplers: Vec, no_perf: bool) -> Self { unsafe { let chain = llama_cpp_sys_2::llama_sampler_chain_init( llama_cpp_sys_2::llama_sampler_chain_params { no_perf }, @@ -76,9 +76,10 @@ impl LlamaSampler { } } + /// Same as [`Self::chain`] with `no_perf = false`. #[must_use] - pub fn chain(samplers: Vec) -> Self { - Self::chain_with_no_perf(samplers, false) + pub fn chain_simple(samplers: Vec) -> Self { + Self::chain(samplers, false) } #[must_use] @@ -196,6 +197,8 @@ impl LlamaSampler { Self { sampler } } + /// Same as [`Self::penalties`], but with `n_vocab`, `special_eos_id`, and `linefeed_id` + /// initialized from `model`, `penalize_nl = false`, and `ignore_eos = true`. 
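The doc comment above pins down the defaults; concretely, the two constructions in this sketch build the same sampler (`model` assumed; penalty values illustrative):

```rust
let simple = LlamaSampler::penalties_simple(&model, 64, 1.1, 0.0, 0.0);

// ...expands to the full form with `penalize_nl = false` and `ignore_eos = true`:
let full = LlamaSampler::penalties(
    model.n_vocab(),
    model.token_eos().0,
    model.token_nl().0,
    64,    // penalty_last_n
    1.1,   // penalty_repeat
    0.0,   // penalty_freq
    0.0,   // penalty_present
    false, // penalize_nl
    true,  // ignore_eos
);
```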
 #[must_use]
     pub fn penalties_simple(
         model: &LlamaModel,
         penalty_last_n: i32,
         penalty_repeat: f32,
         penalty_freq: f32,
         penalty_present: f32,
     ) -> Self {

From 95c2c87a33dcfa87d5470340f48526be41c70bce Mon Sep 17 00:00:00 2001
From: Nathan Koppel
Date: Sun, 8 Dec 2024 14:03:50 -0600
Subject: [PATCH 12/16] Generalize the arguments to LlamaSampler::chain

---
 examples/simple/src/main.rs | 2 +-
 llama-cpp-2/src/sampling.rs | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/examples/simple/src/main.rs b/examples/simple/src/main.rs
index f67a5309..f276ac24 100644
--- a/examples/simple/src/main.rs
+++ b/examples/simple/src/main.rs
@@ -245,7 +245,7 @@ either reduce n_len or increase n_ctx"
     // The `Decoder`
     let mut decoder = encoding_rs::UTF_8.new_decoder();
 
-    let mut sampler = LlamaSampler::chain(vec![
+    let mut sampler = LlamaSampler::chain_simple([
         LlamaSampler::dist(seed.unwrap_or(1234)),
         LlamaSampler::greedy(),
     ]);
diff --git a/llama-cpp-2/src/sampling.rs b/llama-cpp-2/src/sampling.rs
index 67a2b57b..4ca1db89 100644
--- a/llama-cpp-2/src/sampling.rs
+++ b/llama-cpp-2/src/sampling.rs
@@ -58,7 +58,7 @@ impl LlamaSampler {
     }
 
     #[must_use]
-    pub fn chain(samplers: Vec<Self>, no_perf: bool) -> Self {
+    pub fn chain(samplers: impl IntoIterator<Item = Self>, no_perf: bool) -> Self {
         unsafe {
             let chain = llama_cpp_sys_2::llama_sampler_chain_init(
                 llama_cpp_sys_2::llama_sampler_chain_params { no_perf },
@@ -78,7 +78,7 @@ impl LlamaSampler {
 
     /// Same as [`Self::chain`] with `no_perf = false`.
     #[must_use]
-    pub fn chain_simple(samplers: Vec<Self>) -> Self {
+    pub fn chain_simple(samplers: impl IntoIterator<Item = Self>) -> Self {
         Self::chain(samplers, false)
     }

From 76fd77647174129c94593fe3513caa0d6aa0b801 Mon Sep 17 00:00:00 2001
From: Nathan Koppel
Date: Sun, 8 Dec 2024 19:53:03 -0600
Subject: [PATCH 13/16] Generalize the arguments to accept_many and with_tokens

---
 llama-cpp-2/src/sampling.rs | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/llama-cpp-2/src/sampling.rs b/llama-cpp-2/src/sampling.rs
index 4ca1db89..7a64991f 100644
--- a/llama-cpp-2/src/sampling.rs
+++ b/llama-cpp-2/src/sampling.rs
@@ -43,16 +43,16 @@ impl LlamaSampler {
 
     /// Accepts several tokens from the sampler or context, possibly updating the internal state of
     /// certain samplers (e.g. grammar, repetition, etc.)
-    pub fn accept_many(&mut self, tokens: &[LlamaToken]) {
+    pub fn accept_many(&mut self, tokens: impl IntoIterator<Item = impl AsRef<LlamaToken>>) {
         for token in tokens {
-            unsafe { llama_cpp_sys_2::llama_sampler_accept(self.sampler, token.0) }
+            unsafe { llama_cpp_sys_2::llama_sampler_accept(self.sampler, token.as_ref().0) }
         }
     }
 
     /// Accepts several tokens from the sampler or context, possibly updating the internal state of
     /// certain samplers (e.g. grammar, repetition, etc.)
     #[must_use]
-    pub fn with_tokens(mut self, tokens: &[LlamaToken]) -> Self {
+    pub fn with_tokens(mut self, tokens: impl IntoIterator<Item = impl AsRef<LlamaToken>>) -> Self {
         self.accept_many(tokens);
         self
     }

From 8096e79b8ad449393d8db6ed826d6dc9aa32d3b4 Mon Sep 17 00:00:00 2001
From: Nathan Koppel
Date: Sun, 8 Dec 2024 20:35:43 -0600
Subject: [PATCH 14/16] Documentation for sampler creation methods

---
 llama-cpp-2/src/sampling.rs | 156 ++++++++++++++++++++++++++++++++++--
 1 file changed, 148 insertions(+), 8 deletions(-)

diff --git a/llama-cpp-2/src/sampling.rs b/llama-cpp-2/src/sampling.rs
index 7a64991f..abe67352 100644
--- a/llama-cpp-2/src/sampling.rs
+++ b/llama-cpp-2/src/sampling.rs
@@ -57,6 +57,12 @@ impl LlamaSampler {
         self
     }
 
+    /// Combines a list of samplers into a single sampler that applies each component sampler one
+    /// after another.
+    ///
+    /// If you are using a chain to select a token, the chain should always end with one of
+    /// [`LlamaSampler::greedy`], [`LlamaSampler::dist`], [`LlamaSampler::mirostat`], and
+    /// [`LlamaSampler::mirostat_v2`].
     #[must_use]
     pub fn chain(samplers: impl IntoIterator<Item = Self>, no_perf: bool) -> Self {
         unsafe {
@@ -82,48 +88,108 @@ impl LlamaSampler {
         Self::chain(samplers, false)
     }
 
+    /// Updates the logits l_i' = l_i/t. When t <= 0.0, the maximum logit is kept at its original
+    /// value, the rest are set to -inf
+    ///
+    /// # Example:
+    /// ```rust
+    /// use llama_cpp_2::token::{
+    ///    LlamaToken,
+    ///    data::LlamaTokenData,
+    ///    data_array::LlamaTokenDataArray
+    /// };
+    /// use llama_cpp_2::sampling::LlamaSampler;
+    ///
+    /// let mut data_array = LlamaTokenDataArray::new(vec![
+    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
+    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
+    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
+    /// ], false);
+    ///
+    /// data_array.apply_sampler(&mut LlamaSampler::temp(0.5));
+    ///
+    /// assert_eq!(data_array.data[0].logit(), 0.);
+    /// assert_eq!(data_array.data[1].logit(), 2.);
+    /// assert_eq!(data_array.data[2].logit(), 4.);
+    /// ```
     #[must_use]
     pub fn temp(t: f32) -> Self {
         let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_temp(t) };
         Self { sampler }
     }
 
+    /// Dynamic temperature implementation (a.k.a. entropy) described in the paper
+    /// <https://arxiv.org/abs/2309.02772>.
     #[must_use]
     pub fn temp_ext(t: f32, delta: f32, exponent: f32) -> Self {
         let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_temp_ext(t, delta, exponent) };
         Self { sampler }
     }
 
+    /// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration"
+    /// <https://arxiv.org/abs/1904.09751>
+    ///
+    /// # Example:
+    /// ```rust
+    /// use llama_cpp_2::token::{
+    ///    LlamaToken,
+    ///    data::LlamaTokenData,
+    ///    data_array::LlamaTokenDataArray
+    /// };
+    /// use llama_cpp_2::sampling::LlamaSampler;
+    ///
+    /// let mut data_array = LlamaTokenDataArray::new(vec![
+    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
+    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
+    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
+    ///     LlamaTokenData::new(LlamaToken(3), 3., 0.),
+    /// ], false);
+    ///
+    /// data_array.apply_sampler(&mut LlamaSampler::top_k(2));
+    ///
+    /// assert_eq!(data_array.data.len(), 2);
+    /// assert_eq!(data_array.data[0].id(), LlamaToken(3));
+    /// assert_eq!(data_array.data[1].id(), LlamaToken(2));
+    /// ```
     #[must_use]
     pub fn top_k(k: i32) -> Self {
         let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_top_k(k) };
         Self { sampler }
     }
 
+    /// Locally Typical Sampling implementation described in the paper
+    /// <https://arxiv.org/abs/2202.00666>.
     #[must_use]
     pub fn typical(p: f32, min_keep: usize) -> Self {
         let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_typical(p, min_keep) };
         Self { sampler }
     }
 
+    /// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration"
+    /// <https://arxiv.org/abs/1904.09751>
     #[must_use]
     pub fn top_p(p: f32, min_keep: usize) -> Self {
         let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_top_p(p, min_keep) };
         Self { sampler }
     }
 
+    /// Minimum P sampling as described in <https://github.com/ggerganov/llama.cpp/pull/3841>
     #[must_use]
     pub fn min_p(p: f32, min_keep: usize) -> Self {
         let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_min_p(p, min_keep) };
         Self { sampler }
     }
 
+    /// XTC sampler as described in <https://github.com/oobabooga/text-generation-webui/pull/6335>
     #[must_use]
     pub fn xtc(p: f32, t: f32, min_keep: usize, seed: u32) -> Self {
         let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_xtc(p, t, min_keep, seed) };
         Self { sampler }
     }
 
+    /// Grammar sampler
+    ///
+    /// # Panics
+    /// If either of ``grammar_str`` or ``grammar_root`` contains null bytes.
     #[must_use]
     pub fn grammar(model: &LlamaModel, grammar_str: &str, grammar_root: &str) -> Self {
         let grammar_str = CString::new(grammar_str).unwrap();
@@ -139,6 +205,13 @@ impl LlamaSampler {
         Self { sampler }
     }
 
+    /// DRY sampler, designed by p-e-w, as described in:
+    /// <https://github.com/oobabooga/text-generation-webui/pull/5677>, porting Koboldcpp
+    /// implementation authored by pi6am: <https://github.com/LostRuins/koboldcpp/pull/982>
+    ///
+    /// # Panics
+    /// If any string in ``seq_breakers`` contains null bytes.
+    #[allow(missing_docs)]
     #[must_use]
     pub fn dry(
         model: &LlamaModel,
@@ -146,15 +219,16 @@ impl LlamaSampler {
         base: f32,
         allowed_length: i32,
         penalty_last_n: i32,
-        seq_breakers: &[impl AsRef<[u8]>],
+        seq_breakers: impl IntoIterator<Item = impl AsRef<[u8]>>,
     ) -> Self {
+        let seq_breakers: Vec<CString> = seq_breakers
+            .into_iter()
+            .map(|s| CString::new(s.as_ref()).unwrap())
+            .collect();
+        let mut seq_breaker_pointers: Vec<*const i8> =
+            seq_breakers.iter().map(|s| s.as_ptr()).collect();
+
         let sampler = unsafe {
-            let seq_breakers: Vec<CString> = seq_breakers
-                .iter()
-                .map(|s| CString::new(s.as_ref()).unwrap())
-                .collect();
-            let mut seq_breaker_pointers: Vec<*const i8> =
-                seq_breakers.iter().map(|s| s.as_ptr()).collect();
             llama_cpp_sys_2::llama_sampler_init_dry(
                 model.model.as_ptr(),
                 multiplier,
@@ -168,6 +242,18 @@ impl LlamaSampler {
         Self { sampler }
     }
 
+    /// Penalizes tokens for being present in the context.
+    ///
+    /// Parameters:
+    /// - ``n_vocab``: [`LlamaModel::n_vocab`]
+    /// - ``special_eos_id``: [`LlamaModel::token_eos`]
+    /// - ``linefeed_id``: [`LlamaModel::token_nl`]
+    /// - ``penalty_last_n``: last n tokens to penalize (0 = disable penalty, -1 = context size)
+    /// - ``penalty_repeat``: 1.0 = disabled
+    /// - ``penalty_freq``: 0.0 = disabled
+    /// - ``penalty_present``: 0.0 = disabled
+    /// - ``penalize_nl``: consider newlines as a repeatable token
+    /// - ``ignore_eos``: ignore the end-of-sequence token
     #[allow(clippy::too_many_arguments)]
     #[must_use]
     pub fn penalties(
@@ -199,6 +285,13 @@ impl LlamaSampler {
 
     /// Same as [`Self::penalties`], but with `n_vocab`, `special_eos_id`, and `linefeed_id`
     /// initialized from `model`, `penalize_nl = false`, and `ignore_eos = true`.
+    ///
+    /// Parameters:
+    /// - ``model``: The model's tokenizer to use to initialize the sampler.
+    /// - ``penalty_last_n``: last n tokens to penalize (0 = disable penalty, -1 = context size)
+    /// - ``penalty_repeat``: 1.0 = disabled
+    /// - ``penalty_freq``: 0.0 = disabled
+    /// - ``penalty_present``: 0.0 = disabled
     #[must_use]
     pub fn penalties_simple(
         model: &LlamaModel,
@@ -220,6 +313,21 @@ impl LlamaSampler {
         )
     }
 
+    /// Mirostat 1.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
+    ///
+    /// # Parameters:
+    /// - ``n_vocab``: [`LlamaModel::n_vocab`]
+    /// - ``seed``: Seed to initialize random generation with.
+    /// - ``tau``: The target cross-entropy (or surprise) value you want to achieve for the
+    ///   generated text. A higher value corresponds to more surprising or less predictable text,
+    ///   while a lower value corresponds to less surprising or more predictable text.
+    /// - ``eta``: The learning rate used to update `mu` based on the error between the target and
+    ///   observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
+    ///   updated more quickly, while a smaller learning rate will result in slower updates.
+    /// - ``m``: The number of tokens considered in the estimation of `s_hat`. This is an arbitrary
+    ///   value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`.
+    ///   In the paper, they use `m = 100`, but you can experiment with different values to see how
+    ///   it affects the performance of the algorithm.
     #[must_use]
     pub fn mirostat(n_vocab: i32, seed: u32, tau: f32, eta: f32, m: i32) -> Self {
         let sampler =
             unsafe { llama_cpp_sys_2::llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m) };
         Self { sampler }
     }
 
+    /// Mirostat 2.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
+    ///
+    /// # Parameters:
+    /// - ``seed``: Seed to initialize random generation with.
+    /// - ``tau``: The target cross-entropy (or surprise) value you want to achieve for the
+    ///   generated text. A higher value corresponds to more surprising or less predictable text,
+    ///   while a lower value corresponds to less surprising or more predictable text.
+    /// - ``eta``: The learning rate used to update `mu` based on the error between the target and
+    ///   observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
+    ///   updated more quickly, while a smaller learning rate will result in slower updates.
     #[must_use]
     pub fn mirostat_v2(seed: u32, tau: f32, eta: f32) -> Self {
         let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_mirostat_v2(seed, tau, eta) };
         Self { sampler }
     }
 
+    /// Selects a token at random based on each token's probabilities
     #[must_use]
     pub fn dist(seed: u32) -> Self {
         let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_dist(seed) };
         Self { sampler }
     }
-
+
+    /// Selects the most likely token
+    ///
+    /// # Example:
+    /// ```rust
+    /// use llama_cpp_2::token::{
+    ///    LlamaToken,
+    ///    data::LlamaTokenData,
+    ///    data_array::LlamaTokenDataArray
+    /// };
+    /// use llama_cpp_2::sampling::LlamaSampler;
+    ///
+    /// let mut data_array = LlamaTokenDataArray::new(vec![
+    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
+    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
+    /// ], false);
+    ///
+    /// data_array.apply_sampler(&mut LlamaSampler::greedy());
+    ///
+    /// assert_eq!(data_array.data.len(), 2);
+    /// assert_eq!(data_array.selected_token(), Some(LlamaToken(1)));
+    /// ```
     #[must_use]
     pub fn greedy() -> Self {
         let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_greedy() };
         Self { sampler }
     }
 }

From aeb76dceb622eea90764f9a9ad89b666c8583274 Mon Sep 17 00:00:00 2001
From: Nathan Koppel
Date: Sun, 8 Dec 2024 20:49:51 -0600
Subject: [PATCH 15/16] Add LlamaTokenDataArray::with_sampler; use Borrow instead of AsRef for LlamaToken

---
 llama-cpp-2/src/sampling.rs         |  7 ++++---
 llama-cpp-2/src/token/data_array.rs | 13 +++++++++++++
 2 files changed, 17 insertions(+), 3 deletions(-)

diff --git a/llama-cpp-2/src/sampling.rs b/llama-cpp-2/src/sampling.rs
index abe67352..8781ff48 100644
--- a/llama-cpp-2/src/sampling.rs
+++ b/llama-cpp-2/src/sampling.rs
@@ -1,5 +1,6 @@
 //! Safe wrapper around `llama_sampler`.
 
+use std::borrow::Borrow;
 use std::ffi::CString;
 use std::fmt::{Debug, Formatter};
 
@@ -43,16 +44,16 @@ impl LlamaSampler {
 
     /// Accepts several tokens from the sampler or context, possibly updating the internal state of
     /// certain samplers (e.g. grammar, repetition, etc.)
-    pub fn accept_many(&mut self, tokens: impl IntoIterator<Item = impl AsRef<LlamaToken>>) {
+    pub fn accept_many(&mut self, tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>) {
         for token in tokens {
-            unsafe { llama_cpp_sys_2::llama_sampler_accept(self.sampler, token.as_ref().0) }
+            unsafe { llama_cpp_sys_2::llama_sampler_accept(self.sampler, token.borrow().0) }
         }
     }
 
     /// Accepts several tokens from the sampler or context, possibly updating the internal state of
     /// certain samplers (e.g. grammar, repetition, etc.)
     #[must_use]
-    pub fn with_tokens(mut self, tokens: impl IntoIterator<Item = impl AsRef<LlamaToken>>) -> Self {
+    pub fn with_tokens(mut self, tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>) -> Self {
         self.accept_many(tokens);
         self
     }
diff --git a/llama-cpp-2/src/token/data_array.rs b/llama-cpp-2/src/token/data_array.rs
index 0912af8a..3f75ee8f 100644
--- a/llama-cpp-2/src/token/data_array.rs
+++ b/llama-cpp-2/src/token/data_array.rs
@@ -132,7 +132,17 @@ impl LlamaTokenDataArray {
         }
     }
 
+    /// Applies a sampler to the data array, returning the modified array.
+    #[must_use]
+    pub fn with_sampler(mut self, sampler: &mut LlamaSampler) -> Self {
+        self.apply_sampler(sampler);
+        self
+    }
+
     /// Randomly selects a token from the candidates based on their probabilities.
+    ///
+    /// # Panics
+    /// If the internal llama.cpp sampler fails to select a token.
     pub fn sample_token(&mut self, seed: u32) -> LlamaToken {
         self.apply_sampler(&mut LlamaSampler::dist(seed));
         self.selected_token()
             .expect("Dist sampler failed to select a token!")
     }
 
     /// Selects the token with the highest probability.
+    ///
+    /// # Panics
+    /// If the internal llama.cpp sampler fails to select a token.
     pub fn sample_token_greedy(&mut self) -> LlamaToken {
         self.apply_sampler(&mut LlamaSampler::greedy());
         self.selected_token()
             .expect("Greedy sampler failed to select a token!")
     }

From 67ea6889c8ce4ed03423e9cbd4684d697a26fa0f Mon Sep 17 00:00:00 2001
From: Nathan Koppel
Date: Sun, 8 Dec 2024 22:15:42 -0600
Subject: [PATCH 16/16] Test for LlamaSampler::chain_simple

---
 llama-cpp-2/src/sampling.rs | 28 ++++++++++++++++++++++++++++
 1 file changed, 28 insertions(+)

diff --git a/llama-cpp-2/src/sampling.rs b/llama-cpp-2/src/sampling.rs
index 8781ff48..69a8554f 100644
--- a/llama-cpp-2/src/sampling.rs
+++ b/llama-cpp-2/src/sampling.rs
@@ -84,6 +84,34 @@ impl LlamaSampler {
     }
 
     /// Same as [`Self::chain`] with `no_perf = false`.
+    ///
+    /// # Example
+    /// ```rust
+    /// use llama_cpp_2::token::{
+    ///    LlamaToken,
+    ///    data::LlamaTokenData,
+    ///    data_array::LlamaTokenDataArray
+    /// };
+    /// use llama_cpp_2::sampling::LlamaSampler;
+    ///
+    /// let mut data_array = LlamaTokenDataArray::new(vec![
+    ///     LlamaTokenData::new(LlamaToken(0), 0., 0.),
+    ///     LlamaTokenData::new(LlamaToken(1), 1., 0.),
+    ///     LlamaTokenData::new(LlamaToken(2), 2., 0.),
+    /// ], false);
+    ///
+    /// data_array.apply_sampler(&mut LlamaSampler::chain_simple([
+    ///     LlamaSampler::temp(0.5),
+    ///     LlamaSampler::greedy(),
+    /// ]));
+    ///
+    /// assert_eq!(data_array.data[0].logit(), 0.);
+    /// assert_eq!(data_array.data[1].logit(), 2.);
+    /// assert_eq!(data_array.data[2].logit(), 4.);
+    ///
+    /// assert_eq!(data_array.data.len(), 3);
+    /// assert_eq!(data_array.selected_token(), Some(LlamaToken(2)));
+    /// ```
     #[must_use]
     pub fn chain_simple(samplers: impl IntoIterator<Item = Self>) -> Self {
         Self::chain(samplers, false)
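Where the series lands: sampler construction is a plain expression, chains own their components, and token selection runs through `sample`. A closing sketch of a typical setup with the final API (`model`, `ctx`, and `batch` are assumed from the example programs touched earlier in the series; parameter values are illustrative):

```rust
let mut sampler = LlamaSampler::chain_simple([
    LlamaSampler::penalties_simple(&model, 64, 1.1, 0.0, 0.0),
    LlamaSampler::top_k(40),
    LlamaSampler::top_p(0.95, 1),
    LlamaSampler::temp(0.8),
    LlamaSampler::dist(1234),
]);

// `sample` both picks a token from the idx-th output of the last decode and
// accepts it into the chain's stateful components, so no separate `accept`
// call is needed here.
let token = sampler.sample(&ctx, batch.n_tokens() - 1);
```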