From 56625a6cc91c58a15da87c8c84768a7b3b9a7742 Mon Sep 17 00:00:00 2001
From: Britt Lewis
Date: Sun, 22 Sep 2024 16:58:55 -0400
Subject: [PATCH] Expose n_ubatch context param

* n_batch is the maximum number of tokens llama_decode can accept in a single call (a single "batch")
* n_ubatch is lower level, corresponding to the physical (hardware) batch size used during decoding; it must be less than or equal to n_batch

- https://github.com/ggerganov/llama.cpp/discussions/6328#discussioncomment-8919848
- https://github.com/ggerganov/llama.cpp/blob/557410b8f06380560155ac7fcb8316d71ddc9837/common/common.h#L58
---
 llama-cpp-2/src/context.rs        |  8 +++++++-
 llama-cpp-2/src/context/params.rs | 31 +++++++++++++++++++++++++++++++
 2 files changed, 38 insertions(+), 1 deletion(-)

diff --git a/llama-cpp-2/src/context.rs b/llama-cpp-2/src/context.rs
index d2ed45be..f5373ab0 100644
--- a/llama-cpp-2/src/context.rs
+++ b/llama-cpp-2/src/context.rs
@@ -52,12 +52,18 @@ impl<'model> LlamaContext<'model> {
         }
     }
 
-    /// Gets the max number of tokens in a batch.
+    /// Gets the max number of logical tokens that can be submitted to a single decode call (`n_batch`). Must be greater than or equal to `n_ubatch`.
     #[must_use]
     pub fn n_batch(&self) -> u32 {
         unsafe { llama_cpp_sys_2::llama_n_batch(self.context.as_ptr()) }
     }
 
+    /// Gets the max physical (hardware-level) batch size used during decoding (`n_ubatch`). Must be less than or equal to `n_batch`.
+    #[must_use]
+    pub fn n_ubatch(&self) -> u32 {
+        unsafe { llama_cpp_sys_2::llama_n_ubatch(self.context.as_ptr()) }
+    }
+
     /// Gets the size of the context.
     #[must_use]
     pub fn n_ctx(&self) -> u32 {
diff --git a/llama-cpp-2/src/context/params.rs b/llama-cpp-2/src/context/params.rs
index 93675f8a..9f778343 100644
--- a/llama-cpp-2/src/context/params.rs
+++ b/llama-cpp-2/src/context/params.rs
@@ -166,6 +166,37 @@ impl LlamaContextParams {
         self.context_params.n_batch
     }
 
+    /// Set the `n_ubatch`
+    ///
+    /// # Examples
+    ///
+    /// ```rust
+    /// # use std::num::NonZeroU32;
+    /// use llama_cpp_2::context::params::LlamaContextParams;
+    /// let params = LlamaContextParams::default()
+    ///     .with_n_ubatch(512);
+    /// assert_eq!(params.n_ubatch(), 512);
+    /// ```
+    #[must_use]
+    pub fn with_n_ubatch(mut self, n_ubatch: u32) -> Self {
+        self.context_params.n_ubatch = n_ubatch;
+        self
+    }
+
+    /// Get the `n_ubatch`
+    ///
+    /// # Examples
+    ///
+    /// ```rust
+    /// use llama_cpp_2::context::params::LlamaContextParams;
+    /// let params = LlamaContextParams::default();
+    /// assert_eq!(params.n_ubatch(), 512);
+    /// ```
+    #[must_use]
+    pub fn n_ubatch(&self) -> u32 {
+        self.context_params.n_ubatch
+    }
+
     /// Set the type of rope scaling.
     ///
     /// # Examples
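
A minimal usage sketch of the new parameter against the builder-style params API touched by this patch; only `with_n_ubatch`, `n_ubatch()`, and `n_batch()` are taken from the diff above, while the `256` value and the `main` wrapper are illustrative only:

```rust
use llama_cpp_2::context::params::LlamaContextParams;

fn main() {
    // n_ubatch is the physical (hardware) batch size used during decoding;
    // it must not exceed n_batch, the logical batch size that a single
    // llama_decode call accepts.
    let params = LlamaContextParams::default().with_n_ubatch(256);

    assert!(params.n_ubatch() <= params.n_batch());
    println!(
        "n_batch = {}, n_ubatch = {}",
        params.n_batch(),
        params.n_ubatch()
    );
}
```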