Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Context and model enhancements #510

Merged
4 changes: 2 additions & 2 deletions examples/simple/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,15 +247,15 @@ either reduce n_len or increase n_ctx"
while n_cur <= n_len {
// sample the next token
{
let candidates = ctx.candidates_ith(batch.n_tokens() - 1);
let candidates = ctx.candidates();

let candidates_p = LlamaTokenDataArray::from_iter(candidates, false);

// sample the most likely token
let new_token_id = ctx.sample_token_greedy(candidates_p);

// is it an end of stream?
if new_token_id == model.token_eos() {
if model.is_eog_token(new_token_id) {
eprintln!();
break;
}
Expand Down
108 changes: 88 additions & 20 deletions llama-cpp-2/src/context/kv_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,21 @@

use crate::context::LlamaContext;
use std::ffi::c_int;
use std::num::NonZeroU8;
use std::num::{NonZeroU8, TryFromIntError};

/// Errors that can occur when attempting to prepare values for the kv cache
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum KvCacheConversionError {
/// Sequence id conversion to i32 failed
#[error("Provided sequence id is too large for a i32")]
SeqIdTooLarge(#[source] TryFromIntError),
/// Position 0 conversion to i32 failed
#[error("Provided start position is too large for a i32")]
P0TooLarge(#[source] TryFromIntError),
/// Position 1 conversion to i32 failed
#[error("Provided end position is too large for a i32")]
P1TooLarge(#[source] TryFromIntError),
}

impl LlamaContext<'_> {
/// Copy the cache from one sequence to another.
Expand All @@ -18,33 +32,63 @@ impl LlamaContext<'_> {

/// Copy the cache from one sequence to another.
///
/// # Returns
/// A `Result` indicating whether the operation was successful. If either position exceeds
/// the maximum i32 value, no copy is attempted and an `Err` is returned.
///
/// # Parameters
///
/// * `src` - The sequence id to copy the cache from.
/// * `dest` - The sequence id to copy the cache to.
/// * `p0` - The start position of the cache to clear. If `None`, the entire cache is copied up to `p1`.
/// * `p1` - The end position of the cache to clear. If `None`, the entire cache is copied starting from `p0`.
pub fn copy_kv_cache_seq(&mut self, src: i32, dest: i32, p0: Option<u16>, p1: Option<u16>) {
let p0 = p0.map_or(-1, i32::from);
let p1 = p1.map_or(-1, i32::from);
pub fn copy_kv_cache_seq(
    &mut self,
    src: i32,
    dest: i32,
    p0: Option<u32>,
    p1: Option<u32>,
) -> Result<(), KvCacheConversionError> {
    // `None` maps to -1, which llama.cpp interprets as an open-ended bound.
    // Both conversions happen before the FFI call, so an `Err` means no
    // copy was attempted.
    let p0 = p0
        .map_or(Ok(-1), i32::try_from)
        .map_err(KvCacheConversionError::P0TooLarge)?;
    let p1 = p1
        .map_or(Ok(-1), i32::try_from)
        .map_err(KvCacheConversionError::P1TooLarge)?;
    unsafe {
        llama_cpp_sys_2::llama_kv_cache_seq_cp(self.context.as_ptr(), src, dest, p0, p1);
    }
    Ok(())
}

/// Clear the kv cache for the given sequence.
/// Clear the kv cache for the given sequence within the specified range `[p0, p1)`
/// Returns `false` only when partial sequence removals fail. Full sequence removals always succeed.
///
/// # Returns
/// A `Result` indicating whether the operation was successful. If the sequence id or
/// either position exceeds the maximum i32 value, no removal is attempted and an `Err` is returned.
///
/// # Parameters
///
/// * `src` - The sequence id to clear the cache for.
/// * `src` - The sequence id to clear the cache for. If `None`, matches all sequences
/// * `p0` - The start position of the cache to clear. If `None`, the entire cache is cleared up to `p1`.
/// * `p1` - The end position of the cache to clear. If `None`, the entire cache is cleared from `p0`.
pub fn clear_kv_cache_seq(&mut self, src: i32, p0: Option<u16>, p1: Option<u16>) {
let p0 = p0.map_or(-1, i32::from);
let p1 = p1.map_or(-1, i32::from);
unsafe {
llama_cpp_sys_2::llama_kv_cache_seq_rm(self.context.as_ptr(), src, p0, p1);
}
pub fn clear_kv_cache_seq(
    &mut self,
    src: Option<u32>,
    p0: Option<u32>,
    p1: Option<u32>,
) -> Result<bool, KvCacheConversionError> {
    // `None` maps to -1: "all sequences" for `src`, open-ended bound for
    // `p0`/`p1`. All conversions happen before the FFI call, so an `Err`
    // means no removal was attempted.
    let src = src
        .map_or(Ok(-1), i32::try_from)
        .map_err(KvCacheConversionError::SeqIdTooLarge)?;
    let p0 = p0
        .map_or(Ok(-1), i32::try_from)
        .map_err(KvCacheConversionError::P0TooLarge)?;
    let p1 = p1
        .map_or(Ok(-1), i32::try_from)
        .map_err(KvCacheConversionError::P1TooLarge)?;
    Ok(unsafe { llama_cpp_sys_2::llama_kv_cache_seq_rm(self.context.as_ptr(), src, p0, p1) })
}

/// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
Expand Down Expand Up @@ -73,25 +117,44 @@ impl LlamaContext<'_> {
/// - lazily on next [`LlamaContext::decode`]
/// - explicitly with [`Self::kv_cache_update`]
///
/// # Returns
/// A `Result` indicating whether the operation was successful. If either position
/// exceeds the maximum i32 value, no update is attempted and an `Err` is returned.
///
/// # Parameters
///
/// * `seq_id` - The sequence id to update
/// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`.
/// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`.
/// * `delta` - The relative position to add to the tokens
pub fn kv_cache_seq_add(&mut self, seq_id: i32, p0: Option<u16>, p1: Option<u16>, delta: i32) {
let p0 = p0.map_or(-1, i32::from);
let p1 = p1.map_or(-1, i32::from);
pub fn kv_cache_seq_add(
    &mut self,
    seq_id: i32,
    p0: Option<u32>,
    p1: Option<u32>,
    delta: i32,
) -> Result<(), KvCacheConversionError> {
    // `None` maps to -1, which llama.cpp interprets as an open-ended bound.
    // Both conversions happen before the FFI call, so an `Err` means no
    // update was attempted.
    let p0 = p0
        .map_or(Ok(-1), i32::try_from)
        .map_err(KvCacheConversionError::P0TooLarge)?;
    let p1 = p1
        .map_or(Ok(-1), i32::try_from)
        .map_err(KvCacheConversionError::P1TooLarge)?;
    unsafe {
        llama_cpp_sys_2::llama_kv_cache_seq_add(self.context.as_ptr(), seq_id, p0, p1, delta);
    }
    Ok(())
}

/// Integer division of the positions by factor of `d > 1`
/// If the KV cache is `RoPEd`, the KV data is updated accordingly:
/// - lazily on next [`LlamaContext::decode`]
/// - explicitly with [`Self::kv_cache_update`]
///
/// # Returns
/// A `Result` indicating whether the operation was successful. If either position
/// exceeds the maximum i32 value, no update is attempted and an `Err` is returned.
///
/// # Parameters
///
/// * `seq_id` - The sequence id to update
Expand All @@ -101,14 +164,19 @@ impl LlamaContext<'_> {
pub fn kv_cache_seq_div(
&mut self,
seq_id: i32,
p0: Option<u16>,
p1: Option<u16>,
p0: Option<u32>,
p1: Option<u32>,
d: NonZeroU8,
) {
let p0 = p0.map_or(-1, i32::from);
let p1 = p1.map_or(-1, i32::from);
) -> Result<(), KvCacheConversionError> {
let p0 = p0
.map_or(Ok(-1), i32::try_from)
.map_err(|e| KvCacheConversionError::P0TooLarge(e))?;
let p1 = p1
.map_or(Ok(-1), i32::try_from)
.map_err(|e| KvCacheConversionError::P1TooLarge(e))?;
let d = c_int::from(d.get());
unsafe { llama_cpp_sys_2::llama_kv_cache_seq_div(self.context.as_ptr(), seq_id, p0, p1, d) }
Ok(())
}

/// Returns the largest position present in the KV cache for the specified sequence
Expand Down
60 changes: 60 additions & 0 deletions llama-cpp-2/src/context/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,66 @@ impl LlamaContextParams {
self.context_params.n_ubatch
}

/// Set the `flash_attention` parameter, consuming and returning `self`
/// for builder-style chaining.
///
/// # Examples
///
/// ```rust
/// use llama_cpp_2::context::params::LlamaContextParams;
/// let params = LlamaContextParams::default()
///     .with_flash_attention(true);
/// assert_eq!(params.flash_attention(), true);
/// ```
#[must_use]
pub fn with_flash_attention(self, enabled: bool) -> Self {
    let mut params = self;
    params.context_params.flash_attn = enabled;
    params
}

/// Get the current value of the `flash_attention` parameter.
///
/// # Examples
///
/// ```rust
/// use llama_cpp_2::context::params::LlamaContextParams;
/// let params = LlamaContextParams::default();
/// assert_eq!(params.flash_attention(), false);
/// ```
#[must_use]
pub fn flash_attention(&self) -> bool {
    self.context_params.flash_attn
}

/// Set the `offload_kqv` parameter, which controls offloading the KV cache
/// & KQV ops to the GPU. Consumes and returns `self` for builder-style chaining.
///
/// # Examples
///
/// ```rust
/// use llama_cpp_2::context::params::LlamaContextParams;
/// let params = LlamaContextParams::default()
///     .with_offload_kqv(false);
/// assert_eq!(params.offload_kqv(), false);
/// ```
#[must_use]
pub fn with_offload_kqv(self, enabled: bool) -> Self {
    let mut params = self;
    params.context_params.offload_kqv = enabled;
    params
}

/// Get the current value of the `offload_kqv` parameter.
///
/// # Examples
///
/// ```rust
/// use llama_cpp_2::context::params::LlamaContextParams;
/// let params = LlamaContextParams::default();
/// assert_eq!(params.offload_kqv(), true);
/// ```
#[must_use]
pub fn offload_kqv(&self) -> bool {
    self.context_params.offload_kqv
}

/// Set the type of rope scaling.
///
/// # Examples
Expand Down
6 changes: 6 additions & 0 deletions llama-cpp-2/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,12 @@ impl LlamaModel {
LlamaToken(token)
}

/// Check if a token represents the end of generation (end of turn, end of sequence, etc.)
#[must_use]
pub fn is_eog_token(&self, token: LlamaToken) -> bool {
    // SAFETY: `self.model.as_ptr()` is assumed to be a valid model pointer for
    // the lifetime of `&self` (NOTE(review): confirm the pointer is non-null and
    // live wherever `LlamaModel` is constructed); `token.0` is passed by value.
    unsafe { llama_cpp_sys_2::llama_token_is_eog(self.model.as_ptr(), token.0) }
}

/// Get the decoder start token token.
#[must_use]
pub fn decode_start_token(&self) -> LlamaToken {
Expand Down
33 changes: 33 additions & 0 deletions llama-cpp-2/src/token/data_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -374,4 +374,37 @@ impl LlamaTokenDataArray {
*mu = unsafe { *mu_ptr };
LlamaToken(token)
}

/// Mirostat 1.0 algorithm described in the [paper](https://arxiv.org/abs/2007.14966). Uses tokens instead of words.
///
/// # Parameters
///
/// * `tau` The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
/// * `eta` The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
/// * `m` The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
/// * `mu` Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
pub fn sample_token_mirostat_v1(
    &mut self,
    ctx: &mut LlamaContext,
    tau: f32,
    eta: f32,
    m: i32,
    mu: &mut f32,
) -> LlamaToken {
    // `mu_ptr` points directly at the caller's `mu`, so llama.cpp's in-place
    // update through the pointer is already visible to the caller when the
    // call returns — no explicit read-back is needed (a trailing
    // `*mu = unsafe { *mu_ptr };` would be a self-copy no-op).
    let mu_ptr = ptr::from_mut(mu);
    let token = unsafe {
        self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
            llama_cpp_sys_2::llama_sample_token_mirostat(
                ctx.context.as_ptr(),
                c_llama_token_data_array,
                tau,
                eta,
                m,
                mu_ptr,
            )
        })
    };
    LlamaToken(token)
}
}
Loading