Commit

Merge branch 'main' into dependabot/cargo/thiserror-1.0.64

MarcusDunn authored Sep 28, 2024
2 parents bf65f31 + 4333caa commit 99d8dd1

Showing 8 changed files with 269 additions and 26 deletions.
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
@@ -12,7 +12,7 @@ hf-hub = { version = "0.3.2" }
criterion = "0.5.1"
pprof = "0.13.0"
bindgen = "0.69.4"
cc = "1.1.14"
cc = "1.1.21"
anyhow = "1.0.86"
clap = "4.5.16"
encoding_rs = "0.8.34"
4 changes: 2 additions & 2 deletions examples/simple/src/main.rs
@@ -247,15 +247,15 @@ either reduce n_len or increase n_ctx"
while n_cur <= n_len {
// sample the next token
{
let candidates = ctx.candidates_ith(batch.n_tokens() - 1);
let candidates = ctx.candidates();

let candidates_p = LlamaTokenDataArray::from_iter(candidates, false);

// sample the most likely token
let new_token_id = ctx.sample_token_greedy(candidates_p);

// is it an end of stream?
if new_token_id == model.token_eos() {
if model.is_eog_token(new_token_id) {
eprintln!();
break;
}
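Taken together, the example's decode loop now pulls candidates for the whole vocabulary from the context and stops on any end-of-generation token. A condensed sketch of the loop after this change (abridged; setup of `model`, `ctx`, `batch`, `n_cur`, and `n_len` is as in the full example):

```rust
// Abridged from examples/simple/src/main.rs after this change; setup of
// `model`, `ctx`, `batch`, `n_cur`, and `n_len` is omitted.
while n_cur <= n_len {
    // Candidates now come from the context's logits for the last decoded
    // token, rather than indexing a specific batch position.
    let candidates = ctx.candidates();
    let candidates_p = LlamaTokenDataArray::from_iter(candidates, false);

    // Greedy sampling: take the most likely token.
    let new_token_id = ctx.sample_token_greedy(candidates_p);

    // Stop on any end-of-generation token (EOS, end-of-turn, ...),
    // not just `token_eos`.
    if model.is_eog_token(new_token_id) {
        eprintln!();
        break;
    }
    // ... print the token, push it to the batch, decode, advance n_cur ...
}
```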
47 changes: 46 additions & 1 deletion llama-cpp-2/src/context.rs
@@ -52,12 +52,18 @@ impl<'model> LlamaContext<'model> {
}
}

/// Gets the max number of tokens in a batch.
/// Gets the maximum number of logical tokens that can be submitted to a single `decode` call. Must be greater than or equal to `n_ubatch`.
#[must_use]
pub fn n_batch(&self) -> u32 {
unsafe { llama_cpp_sys_2::llama_n_batch(self.context.as_ptr()) }
}

/// Gets the maximum number of physical tokens (hardware level) decoded in one micro-batch. Must be less than or equal to `n_batch`.
#[must_use]
pub fn n_ubatch(&self) -> u32 {
unsafe { llama_cpp_sys_2::llama_n_ubatch(self.context.as_ptr()) }
}

/// Gets the size of the context.
#[must_use]
pub fn n_ctx(&self) -> u32 {
@@ -181,6 +187,45 @@ impl<'model> LlamaContext<'model> {
}
}

/// Get the logits for the last token in the context.
///
/// # Returns
/// An iterator over unsorted `LlamaTokenData` containing the
/// logits for the last token in the context.
///
/// # Panics
///
/// - underlying logits data is null
pub fn candidates(&self) -> impl Iterator<Item = LlamaTokenData> + '_ {
(0_i32..).zip(self.get_logits()).map(|(i, logit)| {
let token = LlamaToken::new(i);
LlamaTokenData::new(token, *logit, 0_f32)
})
}

/// Token logits obtained from the last call to `decode()`.
/// The logits for which `batch.logits[i] != 0` are stored contiguously
/// in the order they have appeared in the batch.
/// Rows: number of tokens for which `batch.logits[i] != 0`
/// Cols: `n_vocab`
///
/// # Returns
///
/// A slice containing the logits for the last decoded token.
/// The size corresponds to the `n_vocab` parameter of the context's model.
///
/// # Panics
///
/// - `n_vocab` does not fit into a usize
/// - token data returned is null
pub fn get_logits(&self) -> &[f32] {
let data = unsafe { llama_cpp_sys_2::llama_get_logits(self.context.as_ptr()) };
assert!(!data.is_null(), "logits data for last token is null");
let len = usize::try_from(self.model.n_vocab()).expect("n_vocab does not fit into a usize");

unsafe { slice::from_raw_parts(data, len) }
}

/// Get the logits for the ith token in the context.
///
/// # Panics
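As a rough illustration of what the new accessors enable, the greedy pick from the example above can also be done by hand over the raw logits slice; this is a minimal sketch assuming `ctx` is a `LlamaContext` that has just completed a successful `decode()`:

```rust
// Sketch: hand-rolled greedy pick over the raw logits slice returned by
// `get_logits()`. Assumes `ctx` has just decoded a batch whose last token
// had logits requested.
let logits: &[f32] = ctx.get_logits();
let (best_token, best_logit) = logits
    .iter()
    .copied()
    .enumerate()
    .max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("logits should not be NaN"))
    .expect("n_vocab is greater than zero");
println!("most likely token id {best_token} with logit {best_logit}");
```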
108 changes: 88 additions & 20 deletions llama-cpp-2/src/context/kv_cache.rs
@@ -2,7 +2,21 @@
use crate::context::LlamaContext;
use std::ffi::c_int;
use std::num::NonZeroU8;
use std::num::{NonZeroU8, TryFromIntError};

/// Errors that can occur when attempting to prepare values for the kv cache
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum KvCacheConversionError {
/// Sequence id conversion to i32 failed
#[error("Provided sequence id is too large for a i32")]
SeqIdTooLarge(#[source] TryFromIntError),
/// Position 0 conversion to i32 failed
#[error("Provided start position is too large for a i32")]
P0TooLarge(#[source] TryFromIntError),
/// Position 1 conversion to i32 failed
#[error("Provided end position is too large for a i32")]
P1TooLarge(#[source] TryFromIntError),
}

impl LlamaContext<'_> {
/// Copy the cache from one sequence to another.
@@ -18,33 +18,63 @@ impl LlamaContext<'_> {

/// Copy the cache from one sequence to another.
///
/// # Returns
/// A `Result` indicating whether the operation was successful. If either position exceeds
/// the maximum i32 value, no copy is attempted and an `Err` is returned.
///
/// # Parameters
///
/// * `src` - The sequence id to copy the cache from.
/// * `dest` - The sequence id to copy the cache to.
/// * `p0` - The start position of the cache to copy. If `None`, the entire cache is copied up to `p1`.
/// * `p1` - The end position of the cache to copy. If `None`, the entire cache is copied starting from `p0`.
pub fn copy_kv_cache_seq(&mut self, src: i32, dest: i32, p0: Option<u16>, p1: Option<u16>) {
let p0 = p0.map_or(-1, i32::from);
let p1 = p1.map_or(-1, i32::from);
pub fn copy_kv_cache_seq(
&mut self,
src: i32,
dest: i32,
p0: Option<u32>,
p1: Option<u32>,
) -> Result<(), KvCacheConversionError> {
let p0 = p0
.map_or(Ok(-1), i32::try_from)
.map_err(|e| KvCacheConversionError::P0TooLarge(e))?;
let p1 = p1
.map_or(Ok(-1), i32::try_from)
.map_err(|e| KvCacheConversionError::P1TooLarge(e))?;
unsafe {
llama_cpp_sys_2::llama_kv_cache_seq_cp(self.context.as_ptr(), src, dest, p0, p1);
}
Ok(())
}

/// Clear the kv cache for the given sequence.
/// Clear the kv cache for the given sequence within the specified range `[p0, p1)`.
/// Returns `false` only when partial sequence removals fail. Full sequence removals always succeed.
///
/// # Returns
/// A `Result` indicating whether the operation was successful. If the sequence id or
/// either position exceeds the maximum i32 value, no removal is attempted and an `Err` is returned.
///
/// # Parameters
///
/// * `src` - The sequence id to clear the cache for.
/// * `src` - The sequence id to clear the cache for. If `None`, matches all sequences.
/// * `p0` - The start position of the cache to clear. If `None`, the entire cache is cleared up to `p1`.
/// * `p1` - The end position of the cache to clear. If `None`, the entire cache is cleared from `p0`.
pub fn clear_kv_cache_seq(&mut self, src: i32, p0: Option<u16>, p1: Option<u16>) {
let p0 = p0.map_or(-1, i32::from);
let p1 = p1.map_or(-1, i32::from);
unsafe {
llama_cpp_sys_2::llama_kv_cache_seq_rm(self.context.as_ptr(), src, p0, p1);
}
pub fn clear_kv_cache_seq(
&mut self,
src: Option<u32>,
p0: Option<u32>,
p1: Option<u32>,
) -> Result<bool, KvCacheConversionError> {
let src = src
.map_or(Ok(-1), i32::try_from)
.map_err(|e| KvCacheConversionError::SeqIdTooLarge(e))?;
let p0 = p0
.map_or(Ok(-1), i32::try_from)
.map_err(|e| KvCacheConversionError::P0TooLarge(e))?;
let p1 = p1
.map_or(Ok(-1), i32::try_from)
.map_err(|e| KvCacheConversionError::P1TooLarge(e))?;
Ok(unsafe { llama_cpp_sys_2::llama_kv_cache_seq_rm(self.context.as_ptr(), src, p0, p1) })
}

/// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
@@ -73,25 +117,44 @@ impl LlamaContext<'_> {
/// - lazily on next [`LlamaContext::decode`]
/// - explicitly with [`Self::kv_cache_update`]
///
/// # Returns
/// A `Result` indicating whether the operation was successful. If either position
/// exceeds the maximum i32 value, no update is attempted and an `Err` is returned.
///
/// # Parameters
///
/// * `seq_id` - The sequence id to update
/// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`.
/// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`.
/// * `delta` - The relative position to add to the tokens
pub fn kv_cache_seq_add(&mut self, seq_id: i32, p0: Option<u16>, p1: Option<u16>, delta: i32) {
let p0 = p0.map_or(-1, i32::from);
let p1 = p1.map_or(-1, i32::from);
pub fn kv_cache_seq_add(
&mut self,
seq_id: i32,
p0: Option<u32>,
p1: Option<u32>,
delta: i32,
) -> Result<(), KvCacheConversionError> {
let p0 = p0
.map_or(Ok(-1), i32::try_from)
.map_err(|e| KvCacheConversionError::P0TooLarge(e))?;
let p1 = p1
.map_or(Ok(-1), i32::try_from)
.map_err(|e| KvCacheConversionError::P1TooLarge(e))?;
unsafe {
llama_cpp_sys_2::llama_kv_cache_seq_add(self.context.as_ptr(), seq_id, p0, p1, delta);
}
Ok(())
}

/// Integer division of the positions by a factor of `d > 1`.
/// If the KV cache is `RoPEd`, the KV data is updated accordingly:
/// - lazily on next [`LlamaContext::decode`]
/// - explicitly with [`Self::kv_cache_update`]
///
/// # Returns
/// A `Result` indicating whether the operation was successful. If either position
/// exceeds the maximum i32 value, no update is attempted and an `Err` is returned.
///
/// # Parameters
///
/// * `seq_id` - The sequence id to update
Expand All @@ -101,14 +164,19 @@ impl LlamaContext<'_> {
pub fn kv_cache_seq_div(
&mut self,
seq_id: i32,
p0: Option<u16>,
p1: Option<u16>,
p0: Option<u32>,
p1: Option<u32>,
d: NonZeroU8,
) {
let p0 = p0.map_or(-1, i32::from);
let p1 = p1.map_or(-1, i32::from);
) -> Result<(), KvCacheConversionError> {
let p0 = p0
.map_or(Ok(-1), i32::try_from)
.map_err(|e| KvCacheConversionError::P0TooLarge(e))?;
let p1 = p1
.map_or(Ok(-1), i32::try_from)
.map_err(|e| KvCacheConversionError::P1TooLarge(e))?;
let d = c_int::from(d.get());
unsafe { llama_cpp_sys_2::llama_kv_cache_seq_div(self.context.as_ptr(), seq_id, p0, p1, d) }
Ok(())
}

/// Returns the largest position present in the KV cache for the specified sequence
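Callers of the KV-cache helpers now handle a `Result`; the sketch below shows one way the updated signatures might be used. The `llama_cpp_2::context::kv_cache` import path and the specific positions are assumptions for illustration.

```rust
use llama_cpp_2::context::kv_cache::KvCacheConversionError;
use llama_cpp_2::context::LlamaContext;

// Sketch: the positional arguments are now u32 and validated up front, so the
// helpers return a Result instead of silently truncating to i32.
fn trim_and_fork(ctx: &mut LlamaContext<'_>) -> Result<(), KvCacheConversionError> {
    // Drop everything for sequence 0 from position 128 onwards.
    ctx.clear_kv_cache_seq(Some(0), Some(128), None)?;

    // Copy the remaining prefix of sequence 0 into sequence 1.
    ctx.copy_kv_cache_seq(0, 1, None, Some(128))?;

    // Positions that cannot fit in an i32 are rejected before any FFI call.
    let err = ctx
        .copy_kv_cache_seq(0, 1, Some(u32::MAX), None)
        .unwrap_err();
    assert!(matches!(err, KvCacheConversionError::P0TooLarge(_)));
    Ok(())
}
```

Full-range operations (`None` for both positions) still map to `-1` internally, so existing "clear everything" call sites only need the added error handling.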
91 changes: 91 additions & 0 deletions llama-cpp-2/src/context/params.rs
@@ -166,6 +166,97 @@ impl LlamaContextParams {
self.context_params.n_batch
}

/// Set the `n_ubatch`
///
/// # Examples
///
/// ```rust
/// use llama_cpp_2::context::params::LlamaContextParams;
/// let params = LlamaContextParams::default()
/// .with_n_ubatch(512);
/// assert_eq!(params.n_ubatch(), 512);
/// ```
#[must_use]
pub fn with_n_ubatch(mut self, n_ubatch: u32) -> Self {
self.context_params.n_ubatch = n_ubatch;
self
}

/// Get the `n_ubatch`
///
/// # Examples
///
/// ```rust
/// use llama_cpp_2::context::params::LlamaContextParams;
/// let params = LlamaContextParams::default();
/// assert_eq!(params.n_ubatch(), 512);
/// ```
#[must_use]
pub fn n_ubatch(&self) -> u32 {
self.context_params.n_ubatch
}

/// Set the `flash_attention` parameter
///
/// # Examples
///
/// ```rust
/// use llama_cpp_2::context::params::LlamaContextParams;
/// let params = LlamaContextParams::default()
/// .with_flash_attention(true);
/// assert_eq!(params.flash_attention(), true);
/// ```
#[must_use]
pub fn with_flash_attention(mut self, enabled: bool) -> Self {
self.context_params.flash_attn = enabled;
self
}

/// Get the `flash_attention` parameter
///
/// # Examples
///
/// ```rust
/// use llama_cpp_2::context::params::LlamaContextParams;
/// let params = LlamaContextParams::default();
/// assert_eq!(params.flash_attention(), false);
/// ```
#[must_use]
pub fn flash_attention(&self) -> bool {
self.context_params.flash_attn
}

/// Set the `offload_kqv` parameter to control offloading KV cache & KQV ops to GPU
///
/// # Examples
///
/// ```rust
/// use llama_cpp_2::context::params::LlamaContextParams;
/// let params = LlamaContextParams::default()
/// .with_offload_kqv(false);
/// assert_eq!(params.offload_kqv(), false);
/// ```
#[must_use]
pub fn with_offload_kqv(mut self, enabled: bool) -> Self {
self.context_params.offload_kqv = enabled;
self
}

/// Get the `offload_kqv` parameter
///
/// # Examples
///
/// ```rust
/// use llama_cpp_2::context::params::LlamaContextParams;
/// let params = LlamaContextParams::default();
/// assert_eq!(params.offload_kqv(), true);
/// ```
#[must_use]
pub fn offload_kqv(&self) -> bool {
self.context_params.offload_kqv
}

/// Set the type of rope scaling.
///
/// # Examples
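The new context parameters follow the same builder pattern as the existing ones; a small sketch chaining them together (the values are illustrative only):

```rust
use llama_cpp_2::context::params::LlamaContextParams;

// Sketch: chaining the setters added in this change. The chosen values are
// illustrative, not recommendations.
let params = LlamaContextParams::default()
    .with_n_ubatch(256)         // physical micro-batch size
    .with_flash_attention(true) // enable flash attention kernels
    .with_offload_kqv(false);   // keep the KV cache and KQV ops on the CPU
assert_eq!(params.n_ubatch(), 256);
assert!(params.flash_attention());
assert!(!params.offload_kqv());
```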
6 changes: 6 additions & 0 deletions llama-cpp-2/src/model.rs
@@ -118,6 +118,12 @@ impl LlamaModel {
LlamaToken(token)
}

/// Check if a token represents the end of generation (end of turn, end of sequence, etc.)
#[must_use]
pub fn is_eog_token(&self, token: LlamaToken) -> bool {
unsafe { llama_cpp_sys_2::llama_token_is_eog(self.model.as_ptr(), token.0) }
}

/// Get the decoder start token.
#[must_use]
pub fn decode_start_token(&self) -> LlamaToken {