Expose n_ubatch context param
* n_batch is the maximum number of tokens that llama_decode
  can accept in a single call (a single logical "batch").
* n_ubatch is lower level, corresponding to the hardware batch
  size used during decoding. It must be less than or equal to
  n_batch (see the sketch below).
  - ggerganov/llama.cpp#6328 (comment)
  - https://github.com/ggerganov/llama.cpp/blob/557410b8f06380560155ac7fcb8316d71ddc9837/common/common.h#L58
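A minimal sketch of how the two parameters relate when configuring a
context with this crate. The builder-style with_n_batch setter is
assumed to exist alongside the with_n_ubatch added in this commit, and
the values are illustrative rather than defaults:

    use llama_cpp_2::context::params::LlamaContextParams;

    // n_ubatch (physical batch size) must not exceed n_batch (logical batch size).
    let params = LlamaContextParams::default()
        .with_n_batch(2048)   // max tokens accepted per llama_decode call
        .with_n_ubatch(512);  // max tokens decoded per hardware pass
    assert!(params.n_ubatch() <= params.n_batch());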
brittlewis12 committed Sep 22, 2024
1 parent ea798fa commit 56625a6
Showing 2 changed files with 38 additions and 1 deletion.
8 changes: 7 additions & 1 deletion llama-cpp-2/src/context.rs
@@ -52,12 +52,18 @@ impl<'model> LlamaContext<'model> {
         }
     }
 
-    /// Gets the max number of tokens in a batch.
+    /// Gets the max number of logical tokens that can be submitted to decode. Must be greater than or equal to `n_ubatch`.
     #[must_use]
     pub fn n_batch(&self) -> u32 {
         unsafe { llama_cpp_sys_2::llama_n_batch(self.context.as_ptr()) }
     }
 
+    /// Gets the max number of physical tokens (hardware level) to decode in a single batch. Must be less than or equal to `n_batch`.
+    #[must_use]
+    pub fn n_ubatch(&self) -> u32 {
+        unsafe { llama_cpp_sys_2::llama_n_ubatch(self.context.as_ptr()) }
+    }
+
     /// Gets the size of the context.
     #[must_use]
     pub fn n_ctx(&self) -> u32 {
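A hedged sketch of the new getter in use on a live context. The model
path is a placeholder, and the surrounding calls (LlamaBackend::init,
LlamaModel::load_from_file, new_context) are assumed from the crate's
existing API rather than taken from this commit:

    use llama_cpp_2::context::params::LlamaContextParams;
    use llama_cpp_2::llama_backend::LlamaBackend;
    use llama_cpp_2::model::{params::LlamaModelParams, LlamaModel};

    fn main() -> Result<(), Box<dyn std::error::Error>> {
        let backend = LlamaBackend::init()?;
        // "model.gguf" is a hypothetical local path, not part of this commit.
        let model = LlamaModel::load_from_file(&backend, "model.gguf", &LlamaModelParams::default())?;
        let ctx = model.new_context(&backend, LlamaContextParams::default().with_n_ubatch(256))?;
        // The physical batch size can never exceed the logical one.
        assert!(ctx.n_ubatch() <= ctx.n_batch());
        Ok(())
    }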
31 changes: 31 additions & 0 deletions llama-cpp-2/src/context/params.rs
@@ -166,6 +166,37 @@ impl LlamaContextParams {
         self.context_params.n_batch
     }
 
+    /// Set the `n_ubatch`
+    ///
+    /// # Examples
+    ///
+    /// ```rust
+    /// # use std::num::NonZeroU32;
+    /// use llama_cpp_2::context::params::LlamaContextParams;
+    /// let params = LlamaContextParams::default()
+    ///     .with_n_ubatch(512);
+    /// assert_eq!(params.n_ubatch(), 512);
+    /// ```
+    #[must_use]
+    pub fn with_n_ubatch(mut self, n_ubatch: u32) -> Self {
+        self.context_params.n_ubatch = n_ubatch;
+        self
+    }
+
+    /// Get the `n_ubatch`
+    ///
+    /// # Examples
+    ///
+    /// ```rust
+    /// use llama_cpp_2::context::params::LlamaContextParams;
+    /// let params = LlamaContextParams::default();
+    /// assert_eq!(params.n_ubatch(), 512);
+    /// ```
+    #[must_use]
+    pub fn n_ubatch(&self) -> u32 {
+        self.context_params.n_ubatch
+    }
+
     /// Set the type of rope scaling.
     ///
     /// # Examples
