From 16c21f508cb580b2d7dd3376434b8cc90557707d Mon Sep 17 00:00:00 2001 From: Marcus Dunn Date: Wed, 8 May 2024 00:19:30 +0000 Subject: [PATCH 1/2] updated llama.cpp --- llama-cpp-sys-2/llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama-cpp-sys-2/llama.cpp b/llama-cpp-sys-2/llama.cpp index 784e11de..af0a5b61 160000 --- a/llama-cpp-sys-2/llama.cpp +++ b/llama-cpp-sys-2/llama.cpp @@ -1 +1 @@ -Subproject commit 784e11dea1f5ce9638851b2b0dddb107e2a609c8 +Subproject commit af0a5b616359809ce886ea433acedebb39b12969 From e06daf841ab5e0c73ae58816d28c8b28bf5e6ba2 Mon Sep 17 00:00:00 2001 From: marcus Date: Wed, 8 May 2024 12:43:57 -0700 Subject: [PATCH 2/2] changed to new `llama_model_kv_override` API --- llama-cpp-2/src/model/params.rs | 5 +++-- llama-cpp-2/src/model/params/kv_overrides.rs | 23 +++++++++++++++----- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/llama-cpp-2/src/model/params.rs b/llama-cpp-2/src/model/params.rs index 92a2ccdb..9fc9d036 100644 --- a/llama-cpp-2/src/model/params.rs +++ b/llama-cpp-2/src/model/params.rs @@ -96,7 +96,7 @@ impl LlamaModelParams { key: [0; 128], tag: 0, __bindgen_anon_1: llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 { - int_value: 0, + val_i64: 0, }, }); @@ -194,11 +194,12 @@ impl Default for LlamaModelParams { let default_params = unsafe { llama_cpp_sys_2::llama_model_default_params() }; LlamaModelParams { params: default_params, + // push the next one to ensure we maintain the iterator invariant of ending with a 0 kv_overrides: vec![llama_cpp_sys_2::llama_model_kv_override { key: [0; 128], tag: 0, __bindgen_anon_1: llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 { - int_value: 0, + val_i64: 0, }, }], } diff --git a/llama-cpp-2/src/model/params/kv_overrides.rs b/llama-cpp-2/src/model/params/kv_overrides.rs index 5e14af7e..7d10256d 100644 --- a/llama-cpp-2/src/model/params/kv_overrides.rs +++ b/llama-cpp-2/src/model/params/kv_overrides.rs @@ -13,6 +13,8 @@ pub enum ParamOverrideValue { Float(f64), /// A integer value Int(i64), + /// A string value + Str([std::os::raw::c_char; 128]), } impl ParamOverrideValue { @@ -21,21 +23,27 @@ impl ParamOverrideValue { ParamOverrideValue::Bool(_) => llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_BOOL, ParamOverrideValue::Float(_) => llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_FLOAT, ParamOverrideValue::Int(_) => llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_INT, + ParamOverrideValue::Str(_) => llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_STR, } } pub(crate) fn value(&self) -> llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 { match self { ParamOverrideValue::Bool(value) => { - llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 { bool_value: *value } + llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 { val_bool: *value } } ParamOverrideValue::Float(value) => { llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 { - float_value: *value, + val_f64: *value, } } ParamOverrideValue::Int(value) => { - llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 { int_value: *value } + llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 { val_i64: *value } + } + ParamOverrideValue::Str(c_string) => { + llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 { + val_str: *c_string, + } } } } @@ -51,13 +59,16 @@ impl From<&llama_cpp_sys_2::llama_model_kv_override> for ParamOverrideValue { ) -> Self { match *tag { llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_INT => { - ParamOverrideValue::Int(unsafe { __bindgen_anon_1.int_value }) + ParamOverrideValue::Int(unsafe { __bindgen_anon_1.val_i64 }) } llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_FLOAT => { - ParamOverrideValue::Float(unsafe { __bindgen_anon_1.float_value }) + ParamOverrideValue::Float(unsafe { __bindgen_anon_1.val_f64 }) } llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_BOOL => { - ParamOverrideValue::Bool(unsafe { __bindgen_anon_1.bool_value }) + ParamOverrideValue::Bool(unsafe { __bindgen_anon_1.val_bool }) + } + llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_STR => { + ParamOverrideValue::Str(unsafe { __bindgen_anon_1.val_str }) } _ => unreachable!("Unknown tag of {tag}"), }