diff --git a/llama-cpp-2/src/context/params.rs b/llama-cpp-2/src/context/params.rs index 06059bd..14eca8b 100644 --- a/llama-cpp-2/src/context/params.rs +++ b/llama-cpp-2/src/context/params.rs @@ -41,6 +41,49 @@ impl From for i32 { } } +/// A rusty wrapper around `LLAMA_POOLING_TYPE`. +#[repr(i8)] +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum LlamaPoolingType { + /// The pooling type is unspecified + Unspecified = -1, + /// No pooling + None = 0, + /// Mean pooling + Mean = 1, + /// CLS pooling + Cls = 2, + /// Last pooling + Last = 3, +} + +/// Create a `LlamaPoolingType` from a `c_int` - returns `LlamaPoolingType::Unspecified` if +/// the value is not recognized. +impl From for LlamaPoolingType { + fn from(value: i32) -> Self { + match value { + 0 => Self::None, + 1 => Self::Mean, + 2 => Self::Cls, + 3 => Self::Last, + _ => Self::Unspecified, + } + } +} + +/// Create a `c_int` from a `LlamaPoolingType`. +impl From for i32 { + fn from(value: LlamaPoolingType) -> Self { + match value { + LlamaPoolingType::None => 0, + LlamaPoolingType::Mean => 1, + LlamaPoolingType::Cls => 2, + LlamaPoolingType::Last => 3, + LlamaPoolingType::Unspecified => -1, + } + } +} + /// A safe wrapper around `llama_context_params`. /// /// Generally this should be created with [`Default::default()`] and then modified with `with_*` methods. @@ -471,6 +514,35 @@ impl LlamaContextParams { self.context_params.cb_eval_user_data = cb_eval_user_data; self } + + /// Set the type of pooling. + /// + /// # Examples + /// + /// ```rust + /// use llama_cpp_2::context::params::{LlamaContextParams, LlamaPoolingType}; + /// let params = LlamaContextParams::default() + /// .with_pooling_type(LlamaPoolingType::Last); + /// assert_eq!(params.pooling_type(), LlamaPoolingType::Last); + /// ``` + #[must_use] + pub fn with_pooling_type(mut self, pooling_type: LlamaPoolingType) -> Self { + self.context_params.pooling_type = i32::from(pooling_type); + self + } + + /// Get the type of pooling. + /// + /// # Examples + /// + /// ```rust + /// let params = llama_cpp_2::context::params::LlamaContextParams::default(); + /// assert_eq!(params.pooling_type(), llama_cpp_2::context::params::LlamaPoolingType::Unspecified); + /// ``` + #[must_use] + pub fn pooling_type(&self) -> LlamaPoolingType { + LlamaPoolingType::from(self.context_params.pooling_type) + } } /// Default parameters for `LlamaContext`. (as defined in llama.cpp by `llama_context_default_params`)