From aeb76dceb622eea90764f9a9ad89b666c8583274 Mon Sep 17 00:00:00 2001 From: Nathan Koppel Date: Sun, 8 Dec 2024 20:49:51 -0600 Subject: [PATCH] Add LlamaTokenDataArray::with_sampler; use Borrow instead of AsRef for LlamaToken --- llama-cpp-2/src/sampling.rs | 7 ++++--- llama-cpp-2/src/token/data_array.rs | 13 +++++++++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/llama-cpp-2/src/sampling.rs b/llama-cpp-2/src/sampling.rs index abe6735..8781ff4 100644 --- a/llama-cpp-2/src/sampling.rs +++ b/llama-cpp-2/src/sampling.rs @@ -1,5 +1,6 @@ //! Safe wrapper around `llama_sampler`. +use std::borrow::Borrow; use std::ffi::CString; use std::fmt::{Debug, Formatter}; @@ -43,16 +44,16 @@ impl LlamaSampler { /// Accepts several tokens from the sampler or context, possibly updating the internal state of /// certain samplers (e.g. grammar, repetition, etc.) - pub fn accept_many(&mut self, tokens: impl IntoIterator<Item = impl AsRef<LlamaToken>>) { + pub fn accept_many(&mut self, tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>) { for token in tokens { - unsafe { llama_cpp_sys_2::llama_sampler_accept(self.sampler, token.as_ref().0) } + unsafe { llama_cpp_sys_2::llama_sampler_accept(self.sampler, token.borrow().0) } } } /// Accepts several tokens from the sampler or context, possibly updating the internal state of /// certain samplers (e.g. grammar, repetition, etc.) 
#[must_use] - pub fn with_tokens(mut self, tokens: impl IntoIterator<Item = impl AsRef<LlamaToken>>) -> Self { + pub fn with_tokens(mut self, tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>) -> Self { self.accept_many(tokens); self } diff --git a/llama-cpp-2/src/token/data_array.rs b/llama-cpp-2/src/token/data_array.rs index 0912af8..3f75ee8 100644 --- a/llama-cpp-2/src/token/data_array.rs +++ b/llama-cpp-2/src/token/data_array.rs @@ -132,7 +132,17 @@ impl LlamaTokenDataArray { } } + /// Modifies the data array by applying a sampler to it + #[must_use] + pub fn with_sampler(mut self, sampler: &mut LlamaSampler) -> Self { + self.apply_sampler(sampler); + self + } + /// Randomly selects a token from the candidates based on their probabilities. + /// + /// # Panics + /// If the internal llama.cpp sampler fails to select a token. pub fn sample_token(&mut self, seed: u32) -> LlamaToken { self.apply_sampler(&mut LlamaSampler::dist(seed)); self.selected_token() @@ -140,6 +150,9 @@ impl LlamaTokenDataArray { } /// Selects the token with the highest probability. + /// + /// # Panics + /// If the internal llama.cpp sampler fails to select a token. pub fn sample_token_greedy(&mut self) -> LlamaToken { self.apply_sampler(&mut LlamaSampler::greedy()); self.selected_token()