Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sampling API back to LlamaTokenDataArray; Add DRY and XTC Samplers #594

Merged
merged 19 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions examples/simple/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use llama_cpp_2::model::params::kv_overrides::ParamOverrideValue;
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::model::LlamaModel;
use llama_cpp_2::model::{AddBos, Special};
use llama_cpp_2::sampling::params::LlamaSamplerChainParams;
use llama_cpp_2::sampling::params::LlamaSamplerParams;
use llama_cpp_2::sampling::LlamaSampler;

use std::ffi::CString;
Expand Down Expand Up @@ -246,10 +246,12 @@ either reduce n_len or increase n_ctx"
// The `Decoder`
let mut decoder = encoding_rs::UTF_8.new_decoder();

let sampler_params = LlamaSamplerChainParams::default();
let mut sampler = LlamaSampler::new(sampler_params)?
.add_dist(seed.unwrap_or(1234))
.add_greedy();
let mut sampler = LlamaSampler::new(LlamaSamplerParams::chain(&[
LlamaSamplerParams::Dist {
seed: seed.unwrap_or(1234),
},
LlamaSamplerParams::Greedy,
]));

while n_cur <= n_len {
// sample the next token
Expand Down
32 changes: 32 additions & 0 deletions llama-cpp-2/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::llama_batch::LlamaBatch;
use crate::model::{LlamaLoraAdapter, LlamaModel};
use crate::timing::LlamaTimings;
use crate::token::data::LlamaTokenData;
use crate::token::data_array::LlamaTokenDataArray;
use crate::token::LlamaToken;
use crate::{
DecodeError, EmbeddingsError, EncodeError, LlamaLoraAdapterRemoveError,
Expand Down Expand Up @@ -202,6 +203,21 @@ impl<'model> LlamaContext<'model> {
})
}

/// Get the token data array for the last token in the context.
///
/// This is a convience method that implements:
/// ```ignore
/// LlamaTokenDataArray::from_iter(ctx.candidates(), false)
/// ```
///
/// # Panics
///
/// - underlying logits data is null
#[must_use]
pub fn token_data_array(&self) -> LlamaTokenDataArray {
LlamaTokenDataArray::from_iter(self.candidates(), false)
}

/// Token logits obtained from the last call to `decode()`.
/// The logits for which `batch.logits[i] != 0` are stored contiguously
/// in the order they have appeared in the batch.
Expand All @@ -217,6 +233,7 @@ impl<'model> LlamaContext<'model> {
///
/// - `n_vocab` does not fit into a usize
/// - token data returned is null
#[must_use]
pub fn get_logits(&self) -> &[f32] {
let data = unsafe { llama_cpp_sys_2::llama_get_logits(self.context.as_ptr()) };
assert!(!data.is_null(), "logits data for last token is null");
Expand All @@ -237,6 +254,21 @@ impl<'model> LlamaContext<'model> {
})
}

/// Get the token data array for the ith token in the context.
///
/// This is a convience method that implements:
/// ```ignore
/// LlamaTokenDataArray::from_iter(ctx.candidates_ith(i), false)
/// ```
///
/// # Panics
///
/// - logit `i` is not initialized.
#[must_use]
pub fn token_data_array_ith(&self, i: i32) -> LlamaTokenDataArray {
LlamaTokenDataArray::from_iter(self.candidates_ith(i), false)
}

/// Get the logits for the ith token in the context.
///
/// # Panics
Expand Down
8 changes: 0 additions & 8 deletions llama-cpp-2/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,14 +195,6 @@ pub enum LlamaLoraAdapterRemoveError {
ErrorResult(i32),
}

/// An error that can occur when initializing a sampler.
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum LlamaSamplerError {
/// llama.cpp returned null
#[error("null reference from llama.cpp")]
NullReturn,
}

/// get the time (in microseconds) according to llama.cpp
/// ```
/// # use llama_cpp_2::llama_time_us;
Expand Down
Loading
Loading