diff --git a/llama-cpp-2/src/llama_backend.rs b/llama-cpp-2/src/llama_backend.rs index e828b73c..5b13087f 100644 --- a/llama-cpp-2/src/llama_backend.rs +++ b/llama-cpp-2/src/llama_backend.rs @@ -43,7 +43,7 @@ impl LlamaBackend { #[tracing::instrument(skip_all)] pub fn init() -> crate::Result { Self::mark_init()?; - unsafe { llama_cpp_sys_2::llama_backend_init(false) } + unsafe { llama_cpp_sys_2::llama_backend_init() } Ok(LlamaBackend {}) } @@ -51,21 +51,78 @@ impl LlamaBackend { /// ``` ///# use llama_cpp_2::llama_backend::LlamaBackend; ///# use std::error::Error; + ///# use llama_cpp_2::llama_backend::NumaStrategy; /// ///# fn main() -> Result<(), Box> { - /// let llama_backend = LlamaBackend::init_numa()?; + /// + /// let llama_backend = LlamaBackend::init_numa(NumaStrategy::MIRROR)?; /// ///# Ok(()) ///# } /// ``` #[tracing::instrument(skip_all)] - pub fn init_numa() -> crate::Result { + pub fn init_numa(strategy: NumaStrategy) -> crate::Result { Self::mark_init()?; - unsafe { llama_cpp_sys_2::llama_backend_init(true) } + unsafe { + llama_cpp_sys_2::llama_numa_init(llama_cpp_sys_2::ggml_numa_strategy::from(strategy)) + } Ok(LlamaBackend {}) } } +/// A rusty wrapper around `numa_strategy`. +#[derive(Debug, Eq, PartialEq, Copy, Clone)] +pub enum NumaStrategy { + /// The numa strategy is disabled. + DISABLED, + /// help wanted: what does this do? + DISTRIBUTE, + /// help wanted: what does this do? + ISOLATE, + /// help wanted: what does this do? + NUMACTL, + /// help wanted: what does this do? + MIRROR, + /// help wanted: what does this do? + COUNT, +} + +/// An invalid numa strategy was provided. +#[derive(Debug, Eq, PartialEq, Copy, Clone)] +pub struct InvalidNumaStrategy( + /// The invalid numa strategy that was provided. + pub llama_cpp_sys_2::ggml_numa_strategy, +); + +impl TryFrom for NumaStrategy { + type Error = InvalidNumaStrategy; + + fn try_from(value: llama_cpp_sys_2::ggml_numa_strategy) -> Result { + match value { + llama_cpp_sys_2::GGML_NUMA_STRATEGY_DISABLED => Ok(Self::DISABLED), + llama_cpp_sys_2::GGML_NUMA_STRATEGY_DISTRIBUTE => Ok(Self::DISTRIBUTE), + llama_cpp_sys_2::GGML_NUMA_STRATEGY_ISOLATE => Ok(Self::ISOLATE), + llama_cpp_sys_2::GGML_NUMA_STRATEGY_NUMACTL => Ok(Self::NUMACTL), + llama_cpp_sys_2::GGML_NUMA_STRATEGY_MIRROR => Ok(Self::MIRROR), + llama_cpp_sys_2::GGML_NUMA_STRATEGY_COUNT => Ok(Self::COUNT), + value => Err(InvalidNumaStrategy(value)), + } + } +} + +impl From for llama_cpp_sys_2::ggml_numa_strategy { + fn from(value: NumaStrategy) -> Self { + match value { + NumaStrategy::DISABLED => llama_cpp_sys_2::GGML_NUMA_STRATEGY_DISABLED, + NumaStrategy::DISTRIBUTE => llama_cpp_sys_2::GGML_NUMA_STRATEGY_DISTRIBUTE, + NumaStrategy::ISOLATE => llama_cpp_sys_2::GGML_NUMA_STRATEGY_ISOLATE, + NumaStrategy::NUMACTL => llama_cpp_sys_2::GGML_NUMA_STRATEGY_NUMACTL, + NumaStrategy::MIRROR => llama_cpp_sys_2::GGML_NUMA_STRATEGY_MIRROR, + NumaStrategy::COUNT => llama_cpp_sys_2::GGML_NUMA_STRATEGY_COUNT, + } + } +} + /// Drops the llama backend. /// ``` /// @@ -92,3 +149,33 @@ impl Drop for LlamaBackend { unsafe { llama_cpp_sys_2::llama_backend_free() } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn numa_from_and_to() { + let numas = [ + NumaStrategy::DISABLED, + NumaStrategy::DISTRIBUTE, + NumaStrategy::ISOLATE, + NumaStrategy::NUMACTL, + NumaStrategy::MIRROR, + NumaStrategy::COUNT, + ]; + + for numa in &numas { + let from = llama_cpp_sys_2::ggml_numa_strategy::from(*numa); + let to = NumaStrategy::try_from(from).expect("Failed to convert from and to"); + assert_eq!(*numa, to); + } + } + + #[test] + fn check_invalid_numa() { + let invalid = 800; + let invalid = NumaStrategy::try_from(invalid); + assert_eq!(invalid, Err(InvalidNumaStrategy(invalid.unwrap_err().0))); + } +} diff --git a/llama-cpp-sys-2/llama.cpp b/llama-cpp-sys-2/llama.cpp index 6560bed3..89febfed 160000 --- a/llama-cpp-sys-2/llama.cpp +++ b/llama-cpp-sys-2/llama.cpp @@ -1 +1 @@ -Subproject commit 6560bed3f066c876682464762cad90f1e28e3f1b +Subproject commit 89febfed9322c8849520dc63c93ee4f5fd72556e