From 25e87b6fba62bdaee42a03e85176e7846168ee66 Mon Sep 17 00:00:00 2001
From: Britt Lewis
Date: Tue, 24 Sep 2024 21:56:59 -0400
Subject: [PATCH] Expose flash attention

---
 llama-cpp-2/src/context/params.rs | 30 ++++++++++++++++++++++++++++++
 1 file changed, 30 insertions(+)

diff --git a/llama-cpp-2/src/context/params.rs b/llama-cpp-2/src/context/params.rs
index 9f77834..5048010 100644
--- a/llama-cpp-2/src/context/params.rs
+++ b/llama-cpp-2/src/context/params.rs
@@ -197,6 +197,36 @@ impl LlamaContextParams {
         self.context_params.n_ubatch
     }
 
+    /// Set the `flash_attention` parameter
+    ///
+    /// # Examples
+    ///
+    /// ```rust
+    /// use llama_cpp_2::context::params::LlamaContextParams;
+    /// let params = LlamaContextParams::default()
+    ///     .with_flash_attention(true);
+    /// assert_eq!(params.flash_attention(), true);
+    /// ```
+    #[must_use]
+    pub fn with_flash_attention(mut self, enabled: bool) -> Self {
+        self.context_params.flash_attn = enabled;
+        self
+    }
+
+    /// Get the `flash_attention` parameter
+    ///
+    /// # Examples
+    ///
+    /// ```rust
+    /// use llama_cpp_2::context::params::LlamaContextParams;
+    /// let params = LlamaContextParams::default();
+    /// assert_eq!(params.flash_attention(), false);
+    /// ```
+    #[must_use]
+    pub fn flash_attention(&self) -> bool {
+        self.context_params.flash_attn
+    }
+
     /// Set the type of rope scaling.
     ///
     /// # Examples
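
For reviewers, a minimal downstream usage sketch of the builder/getter pair added here. It uses only the methods introduced in this patch plus `LlamaContextParams::default()`; the `main` wrapper is illustrative and not part of the crate.

```rust
use llama_cpp_2::context::params::LlamaContextParams;

fn main() {
    // Enable flash attention on the context parameters via the new builder method.
    let params = LlamaContextParams::default().with_flash_attention(true);
    assert!(params.flash_attention());

    // Without the builder call, the underlying `flash_attn` field keeps its default (false).
    assert!(!LlamaContextParams::default().flash_attention());
}
```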