diff --git a/llama-cpp-2/src/context/params.rs b/llama-cpp-2/src/context/params.rs index 80c6561a..ac6350a0 100644 --- a/llama-cpp-2/src/context/params.rs +++ b/llama-cpp-2/src/context/params.rs @@ -241,6 +241,33 @@ impl LlamaContextParams { pub fn rope_freq_scale(&self) -> f32 { self.context_params.rope_freq_scale } + + /// Get the number of threads. + /// + /// # Examples + /// + /// ```rust + /// let params = llama_cpp_2::context::params::LlamaContextParams::default(); + /// assert_eq!(params.n_threads(), 4); + /// ``` + pub fn n_threads(&self) -> u32 { + self.context_params.n_threads + } + + /// Set the number of threads. + /// + /// # Examples + /// + /// ```rust + /// use llama_cpp_2::context::params::LlamaContextParams; + /// let params = LlamaContextParams::default() + /// .with_n_threads(8); + /// assert_eq!(params.n_threads(), 8); + /// ``` + pub fn with_n_threads(mut self, n_threads: u32) -> Self { + self.context_params.n_threads = n_threads; + self + } } /// Default parameters for `LlamaContext`. (as defined in llama.cpp by `llama_context_default_params`)