diff --git a/llama-cpp-2/src/context/sample.rs b/llama-cpp-2/src/context/sample.rs index a147caf9..2e4a3f06 100644 --- a/llama-cpp-2/src/context/sample.rs +++ b/llama-cpp-2/src/context/sample.rs @@ -184,6 +184,61 @@ impl LlamaContext<'_> { LlamaToken(token) } + /// Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. + pub fn sample_tail_free(&self, token_data: &mut LlamaTokenDataArray, z: f32, min_keep: usize) { + let ctx = self.context.as_ptr(); + unsafe { + token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { + llama_cpp_sys_2::llama_sample_tail_free(ctx, c_llama_token_data_array, z, min_keep); + }); + } + } + + /// Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. + pub fn sample_typical(&self, token_data: &mut LlamaTokenDataArray, p: f32, min_keep: usize) { + let ctx = self.context.as_ptr(); + unsafe { + token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { + llama_cpp_sys_2::llama_sample_typical(ctx, c_llama_token_data_array, p, min_keep); + }); + } + } + + /// Nucleus sampling described in academic paper \"The Curious Case of Neural Text Degeneration\" https://arxiv.org/abs/1904.09751" + pub fn sample_top_p(&self, token_data: &mut LlamaTokenDataArray, p: f32, min_keep: usize) { + let ctx = self.context.as_ptr(); + unsafe { + token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { + llama_cpp_sys_2::llama_sample_top_p(ctx, c_llama_token_data_array, p, min_keep); + }); + } + } + + /// Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841" + pub fn sample_min_p( + &self, + llama_token_data: &mut LlamaTokenDataArray, + p: f32, + min_keep: usize, + ) { + let ctx = self.context.as_ptr(); + unsafe { + llama_token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { + llama_cpp_sys_2::llama_sample_min_p(ctx, c_llama_token_data_array, p, min_keep); + }); + } + } + + /// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 + pub fn sample_top_k(&self, token_data: &mut LlamaTokenDataArray, k: i32, min_keep: usize) { + let ctx = self.context.as_ptr(); + unsafe { + token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { + llama_cpp_sys_2::llama_sample_top_k(ctx, c_llama_token_data_array, k, min_keep); + }); + } + } + /// Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. pub fn sample_token_softmax(&self, token_data: &mut LlamaTokenDataArray) { let ctx = self.context.as_ptr();