Skip to content

Commit

Permalink
added RopeScalingType
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcusDunn committed Jan 17, 2024
1 parent dd1fbea commit 1e713ee
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 5 deletions.
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

45 changes: 42 additions & 3 deletions llama-cpp-2/src/context/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,45 @@ use llama_cpp_sys_2::{ggml_type, llama_context_params};
use std::fmt::Debug;
use std::num::NonZeroU32;

/// A rusty wrapper around `rope_scaling_type`.
///
/// Represents the RoPE (rotary position embedding) scaling mode carried in
/// `llama_context_params`. `#[repr(i8)]` with explicit discriminants keeps the
/// values in sync with the C-side field this type wraps.
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum RopeScalingType {
    /// The scaling type is unspecified (sentinel value `-1` on the C side)
    Unspecified = -1,
    /// No scaling
    None = 0,
    /// Linear scaling
    Linear = 1,
    /// Yarn scaling
    Yarn = 2,
}

/// Create a `RopeScalingType` from an `i8` — returns
/// [`RopeScalingType::Unspecified`] if the value is not recognized.
impl From<i8> for RopeScalingType {
    fn from(value: i8) -> Self {
        match value {
            0 => Self::None,
            1 => Self::Linear,
            2 => Self::Yarn,
            // Anything else — including the canonical -1 — falls back to
            // `Unspecified` instead of failing, so unknown future values
            // from the C API are tolerated.
            _ => Self::Unspecified,
        }
    }
}

/// Create an `i8` from a `RopeScalingType`, yielding the discriminant value
/// the C API expects (`-1` for `Unspecified`).
impl From<RopeScalingType> for i8 {
    fn from(value: RopeScalingType) -> Self {
        // Exhaustive match (no `_` arm): adding a new variant becomes a
        // compile error here rather than a silently wrong conversion.
        match value {
            RopeScalingType::None => 0,
            RopeScalingType::Linear => 1,
            RopeScalingType::Yarn => 2,
            RopeScalingType::Unspecified => -1,
        }
    }
}

/// A safe wrapper around `llama_context_params`.
#[derive(Debug, Clone, Copy, PartialEq)]
#[allow(
Expand All @@ -18,7 +57,7 @@ pub struct LlamaContextParams {
pub n_batch: u32,
pub n_threads: u32,
pub n_threads_batch: u32,
pub rope_scaling_type: i8,
pub rope_scaling_type: RopeScalingType,
pub rope_freq_base: f32,
pub rope_freq_scale: f32,
pub yarn_ext_factor: f32,
Expand Down Expand Up @@ -83,7 +122,7 @@ impl From<llama_context_params> for LlamaContextParams {
mul_mat_q,
logits_all,
embedding,
rope_scaling_type,
rope_scaling_type: RopeScalingType::from(rope_scaling_type),
yarn_ext_factor,
yarn_attn_factor,
yarn_beta_fast,
Expand Down Expand Up @@ -131,7 +170,7 @@ impl From<LlamaContextParams> for llama_context_params {
mul_mat_q,
logits_all,
embedding,
rope_scaling_type,
rope_scaling_type: i8::from(rope_scaling_type),
yarn_ext_factor,
yarn_attn_factor,
yarn_beta_fast,
Expand Down
1 change: 1 addition & 0 deletions llama-cpp-2/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ pub enum StringToTokenError {
/// let elapsed = end - start;
///
/// assert!(elapsed >= 10)
#[must_use]
pub fn ggml_time_us() -> i64 {
    // SAFETY: argument-less FFI call into the ggml timer via the raw
    // `llama_cpp_sys_2` binding. NOTE(review): assumes the function can be
    // called at any time with no prior initialization — confirm against
    // ggml's time API contract.
    unsafe { llama_cpp_sys_2::ggml_time_us() }
}

0 comments on commit 1e713ee

Please sign in to comment.