Skip to content

Commit

Permalink
Expose pooling type
Browse files Browse the repository at this point in the history
  • Loading branch information
LucGeven committed Oct 7, 2024
1 parent 4333caa commit ffbd54c
Showing 1 changed file with 72 additions and 0 deletions.
72 changes: 72 additions & 0 deletions llama-cpp-2/src/context/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,49 @@ impl From<RopeScalingType> for i32 {
}
}

/// A rusty wrapper around `LLAMA_POOLING_TYPE`.
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum LlamaPoolingType {
/// The pooling type is unspecified
Unspecified = -1,
/// No pooling
None = 0,
/// Mean pooling
Mean = 1,
/// CLS pooling
Cls = 2,
/// Last pooling
Last = 3,
}

/// Create a `LlamaPoolingType` from a `c_int` - returns `LlamaPoolingType::Unspecified` if
/// the value is not recognized.
impl From<i32> for LlamaPoolingType {
fn from(value: i32) -> Self {
match value {
0 => Self::None,
1 => Self::Mean,
2 => Self::Cls,
3 => Self::Last,
_ => Self::Unspecified,
}
}
}

/// Create a `c_int` from a `LlamaPoolingType`.
impl From<LlamaPoolingType> for i32 {
fn from(value: LlamaPoolingType) -> Self {
match value {
LlamaPoolingType::None => 0,
LlamaPoolingType::Mean => 1,
LlamaPoolingType::Cls => 2,
LlamaPoolingType::Last => 3,
LlamaPoolingType::Unspecified => -1,
}
}
}

/// A safe wrapper around `llama_context_params`.
///
/// Generally this should be created with [`Default::default()`] and then modified with `with_*` methods.
Expand Down Expand Up @@ -471,6 +514,35 @@ impl LlamaContextParams {
self.context_params.cb_eval_user_data = cb_eval_user_data;
self
}

/// Set the type of pooling.
///
/// # Examples
///
/// ```rust
/// use llama_cpp_2::context::params::{LlamaContextParams, LlamaPoolingType};
/// let params = LlamaContextParams::default()
/// .with_pooling_type(LlamaPoolingType::Last);
/// assert_eq!(params.pooling_type(), LlamaPoolingType::Last);
/// ```
#[must_use]
pub fn with_pooling_type(mut self, pooling_type: LlamaPoolingType) -> Self {
self.context_params.pooling_type = i32::from(pooling_type);
self
}

/// Get the type of pooling.
///
/// # Examples
///
/// ```rust
/// let params = llama_cpp_2::context::params::LlamaContextParams::default();
/// assert_eq!(params.pooling_type(), llama_cpp_2::context::params::LlamaPoolingType::Unspecified);
/// ```
#[must_use]
pub fn pooling_type(&self) -> LlamaPoolingType {
LlamaPoolingType::from(self.context_params.pooling_type)
}
}

/// Default parameters for `LlamaContext`. (as defined in llama.cpp by `llama_context_default_params`)
Expand Down

0 comments on commit ffbd54c

Please sign in to comment.