Commit
Add LlamaTokenDataArray::with_sampler; use Borrow instead of AsRef for LlamaToken
nkoppel committed Dec 9, 2024
1 parent 7aa4367 commit aeb76dc
Showing 2 changed files with 17 additions and 3 deletions.
7 changes: 4 additions & 3 deletions llama-cpp-2/src/sampling.rs
@@ -1,5 +1,6 @@
 //! Safe wrapper around `llama_sampler`.
+use std::borrow::Borrow;
 use std::ffi::CString;
 use std::fmt::{Debug, Formatter};
 
@@ -43,16 +44,16 @@ impl LlamaSampler
 
     /// Accepts several tokens from the sampler or context, possibly updating the internal state of
     /// certain samplers (e.g. grammar, repetition, etc.)
-    pub fn accept_many(&mut self, tokens: impl IntoIterator<Item = impl AsRef<LlamaToken>>) {
+    pub fn accept_many(&mut self, tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>) {
         for token in tokens {
-            unsafe { llama_cpp_sys_2::llama_sampler_accept(self.sampler, token.as_ref().0) }
+            unsafe { llama_cpp_sys_2::llama_sampler_accept(self.sampler, token.borrow().0) }
         }
     }
 
     /// Accepts several tokens from the sampler or context, possibly updating the internal state of
     /// certain samplers (e.g. grammar, repetition, etc.)
     #[must_use]
-    pub fn with_tokens(mut self, tokens: impl IntoIterator<Item = impl AsRef<LlamaToken>>) -> Self {
+    pub fn with_tokens(mut self, tokens: impl IntoIterator<Item = impl Borrow<LlamaToken>>) -> Self {
         self.accept_many(tokens);
         self
     }
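Why `Borrow` rather than `AsRef`: the standard library ships blanket impls `impl<T> Borrow<T> for T` and `impl<T> Borrow<T> for &T`, so both owned `LlamaToken` values and references to them satisfy the new bound without the crate needing an `AsRef<LlamaToken>` impl on `LlamaToken` itself. A minimal sketch of what the relaxed bound accepts; the `LlamaToken::new` constructor and the module paths are assumptions here, not taken from this diff:

```rust
use llama_cpp_2::sampling::LlamaSampler;
use llama_cpp_2::token::LlamaToken;

fn main() {
    // Hypothetical token ids, purely for illustration; `LlamaToken::new` is assumed here.
    let owned: Vec<LlamaToken> = vec![LlamaToken::new(1), LlamaToken::new(2)];
    let borrowed = [LlamaToken::new(3), LlamaToken::new(4)];

    let mut sampler = LlamaSampler::greedy();

    // Owned items compile: `LlamaToken: Borrow<LlamaToken>` via the blanket `impl<T> Borrow<T> for T`.
    sampler.accept_many(owned);

    // References compile too: iterating `&[LlamaToken; 2]` yields `&LlamaToken`,
    // and `&T: Borrow<T>` is also covered by a blanket impl.
    sampler.accept_many(&borrowed);

    // The consuming builder form takes the same kinds of iterators.
    let _sampler = LlamaSampler::greedy().with_tokens(&borrowed);
}
```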
13 changes: 13 additions & 0 deletions llama-cpp-2/src/token/data_array.rs
@@ -132,14 +132,27 @@ impl LlamaTokenDataArray
         }
     }
 
+    /// Modifies the data array by applying a sampler to it
+    #[must_use]
+    pub fn with_sampler(mut self, sampler: &mut LlamaSampler) -> Self {
+        self.apply_sampler(sampler);
+        self
+    }
+
     /// Randomly selects a token from the candidates based on their probabilities.
+    ///
+    /// # Panics
+    /// If the internal llama.cpp sampler fails to select a token.
     pub fn sample_token(&mut self, seed: u32) -> LlamaToken {
         self.apply_sampler(&mut LlamaSampler::dist(seed));
         self.selected_token()
             .expect("Dist sampler failed to select a token!")
     }
 
     /// Selects the token with the highest probability.
+    ///
+    /// # Panics
+    /// If the internal llama.cpp sampler fails to select a token.
     pub fn sample_token_greedy(&mut self) -> LlamaToken {
         self.apply_sampler(&mut LlamaSampler::greedy());
         self.selected_token()
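The new `with_sampler` is the array-side counterpart of `LlamaSampler::with_tokens`: it applies the sampler to the candidate array and hands the array back, so filtering and the final draw can be chained in builder style. A rough usage sketch; the `LlamaTokenData::new`, `LlamaTokenDataArray::from_iter`, and `LlamaSampler::top_k` constructors (and the module paths) are assumed here rather than shown in this commit:

```rust
use llama_cpp_2::sampling::LlamaSampler;
use llama_cpp_2::token::LlamaToken;
use llama_cpp_2::token::data::LlamaTokenData;
use llama_cpp_2::token::data_array::LlamaTokenDataArray;

fn main() {
    // Hypothetical candidate list with made-up logits; the constructors used here are assumed.
    let candidates = LlamaTokenDataArray::from_iter(
        (0..8).map(|i| LlamaTokenData::new(LlamaToken::new(i), i as f32, 0.0)),
        false, // candidates are not pre-sorted
    );

    // Builder style: keep the two highest-logit candidates, then draw one with a seeded sampler.
    let mut candidates = candidates.with_sampler(&mut LlamaSampler::top_k(2));
    let token = candidates.sample_token(1234);
    println!("sampled token id: {}", token.0);
}
```

Per the new `# Panics` docs, `sample_token` applies a `dist` sampler seeded with the given value and panics only if that sampler fails to mark a selected token; `sample_token_greedy` behaves the same way with a `greedy` sampler.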
