Commit

fixed up LlamaContextParams with new CB
MarcusDunn committed Jan 21, 2024
1 parent 40d4b04 commit ce6eb1b
Showing 4 changed files with 105 additions and 145 deletions.
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default.

8 changes: 3 additions & 5 deletions llama-cpp-2/examples/simple.rs
@@ -56,11 +56,9 @@ fn main() -> Result<()> {
.with_context(|| "unable to load model")?;

// initialize the context
let ctx_params = LlamaContextParams {
seed: 1234,
n_ctx: NonZeroU32::new(2048),
..LlamaContextParams::default()
};
let ctx_params = LlamaContextParams::default()
.with_n_ctx(NonZeroU32::new(2048))
.with_seed(1234);

let mut ctx = model.new_context(&backend, ctx_params)
.with_context(|| "unable to create the llama_context")?;
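For reference, the change above swaps the struct-literal construction of LlamaContextParams for the new builder-style setters. A minimal, self-contained version of the new call site (the values match the example; everything else from simple.rs is omitted):

use std::num::NonZeroU32;
use llama_cpp_2::context::params::LlamaContextParams;

fn main() {
    // Builder-style construction introduced by this commit.
    let ctx_params = LlamaContextParams::default()
        .with_n_ctx(NonZeroU32::new(2048)) // context of 2048 tokens
        .with_seed(1234);                  // fixed seed

    // The accessors added in params.rs read the values back.
    assert_eq!(ctx_params.seed(), 1234);
    assert_eq!(ctx_params.n_ctx(), NonZeroU32::new(2048));
}
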
227 changes: 95 additions & 132 deletions llama-cpp-2/src/context/params.rs
@@ -1,5 +1,5 @@
//! A safe wrapper around `llama_context_params`.
use llama_cpp_sys_2::{ggml_type, llama_context_params};
use llama_cpp_sys_2;
use std::fmt::Debug;
use std::num::NonZeroU32;

@@ -43,152 +43,115 @@ impl From<RopeScalingType> for i8 {
}

/// A safe wrapper around `llama_context_params`.
#[derive(Debug, PartialEq)]
///
/// Generally this should be created with [`Default::default()`] and then modified with `with_*` methods.
///
/// # Examples
///
/// ```rust
/// # use std::num::NonZeroU32;
/// use llama_cpp_2::context::params::LlamaContextParams;
///
/// let ctx_params = LlamaContextParams::default()
/// .with_n_ctx(NonZeroU32::new(2048))
/// .with_seed(1234);
///
/// assert_eq!(ctx_params.seed(), 1234);
/// assert_eq!(ctx_params.n_ctx(), NonZeroU32::new(2048));
/// ```
#[derive(Debug, Clone)]
#[allow(
missing_docs,
clippy::struct_excessive_bools,
clippy::module_name_repetitions
)]
pub struct LlamaContextParams {
/// The random seed
pub seed: u32,
/// the number of tokens in the context - [`None`] if defined by the model.
pub n_ctx: Option<NonZeroU32>,
pub n_batch: u32,
pub n_threads: u32,
pub n_threads_batch: u32,
pub rope_scaling_type: RopeScalingType,
pub rope_freq_base: f32,
pub rope_freq_scale: f32,
pub yarn_ext_factor: f32,
pub yarn_attn_factor: f32,
pub yarn_beta_fast: f32,
pub yarn_beta_slow: f32,
pub yarn_orig_ctx: u32,
pub type_k: ggml_type,
pub type_v: ggml_type,
pub mul_mat_q: bool,
pub logits_all: bool,
pub embedding: bool,
pub offload_kqv: bool,
pub cb_eval: llama_cpp_sys_2::ggml_backend_sched_eval_callback,
pub cb_eval_user_data: *mut std::ffi::c_void,
pub(crate) context_params: llama_cpp_sys_2::llama_context_params,
}

impl LlamaContextParams {
/// Set the seed of the context
///
/// # Examples
///
/// ```rust
/// use llama_cpp_2::context::params::LlamaContextParams;
/// let params = LlamaContextParams::default();
/// let params = params.with_seed(1234);
/// assert_eq!(params.seed(), 1234);
/// ```
pub fn with_seed(mut self, seed: u32) -> Self {
self.context_params.seed = seed;
self
}

/// Get the seed of the context
///
/// # Examples
///
/// ```rust
/// use llama_cpp_2::context::params::LlamaContextParams;
/// let params = LlamaContextParams::default()
/// .with_seed(1234);
/// assert_eq!(params.seed(), 1234);
/// ```
pub fn seed(&self) -> u32 {
self.context_params.seed
}

/// Set the size of the context
///
/// # Examples
///
/// ```rust
/// # use std::num::NonZeroU32;
/// use llama_cpp_2::context::params::LlamaContextParams;
/// let params = LlamaContextParams::default();
/// let params = params.with_n_ctx(NonZeroU32::new(2048));
/// assert_eq!(params.n_ctx(), NonZeroU32::new(2048));
/// ```
pub fn with_n_ctx(mut self, n_ctx: Option<NonZeroU32>) -> Self {
self.context_params.n_ctx = n_ctx.map_or(0, |n_ctx| n_ctx.get());
self
}

/// Get the size of the context.
///
/// [`None`] if the context size is specified by the model and not the context.
///
/// # Examples
///
/// ```rust
/// let params = llama_cpp_2::context::params::LlamaContextParams::default();
/// assert_eq!(params.n_ctx(), std::num::NonZeroU32::new(512));
/// ```
pub fn n_ctx(&self) -> Option<NonZeroU32> {
NonZeroU32::new(self.context_params.n_ctx)
}

/// Get the type of rope scaling.
///
/// # Examples
///
/// ```rust
/// let params = llama_cpp_2::context::params::LlamaContextParams::default();
/// assert_eq!(params.rope_scaling_type(), llama_cpp_2::context::params::RopeScalingType::Unspecified);
/// ```
pub fn rope_scaling_type(&self) -> RopeScalingType {
RopeScalingType::from(self.context_params.rope_scaling_type)
}
}

/// Default parameters for `LlamaContext`. (as defined in llama.cpp by `llama_context_default_params`)
/// ```
/// # use std::num::NonZeroU32;
/// use llama_cpp_2::context::params::{LlamaContextParams, RopeScalingType};
/// let params = LlamaContextParams::default();
/// assert_eq!(params.n_ctx, NonZeroU32::new(512), "n_ctx should be 512");
/// assert_eq!(params.rope_scaling_type, RopeScalingType::Unspecified);
/// assert_eq!(params.n_ctx(), NonZeroU32::new(512), "n_ctx should be 512");
/// assert_eq!(params.rope_scaling_type(), RopeScalingType::Unspecified);
/// ```
impl Default for LlamaContextParams {
fn default() -> Self {
Self::from(unsafe { llama_cpp_sys_2::llama_context_default_params() })
}
}

impl From<llama_context_params> for LlamaContextParams {
fn from(
llama_context_params {
seed,
n_ctx,
n_batch,
n_threads,
n_threads_batch,
rope_freq_base,
rope_freq_scale,
cb_eval,
cb_eval_user_data,
type_k,
type_v,
mul_mat_q,
logits_all,
embedding,
rope_scaling_type,
yarn_ext_factor,
yarn_attn_factor,
yarn_beta_fast,
yarn_beta_slow,
yarn_orig_ctx,
offload_kqv,
}: llama_context_params,
) -> Self {
Self {
seed,
n_ctx: NonZeroU32::new(n_ctx),
n_batch,
n_threads,
n_threads_batch,
rope_freq_base,
rope_freq_scale,
type_k,
type_v,
mul_mat_q,
logits_all,
embedding,
rope_scaling_type: RopeScalingType::from(rope_scaling_type),
yarn_ext_factor,
yarn_attn_factor,
yarn_beta_fast,
yarn_beta_slow,
yarn_orig_ctx,
offload_kqv,
cb_eval,
cb_eval_user_data,
}
let context_params = unsafe { llama_cpp_sys_2::llama_context_default_params() };
Self { context_params, }
}
}

impl From<LlamaContextParams> for llama_context_params {
fn from(
LlamaContextParams {
seed,
n_ctx,
n_batch,
n_threads,
n_threads_batch,
rope_freq_base,
rope_freq_scale,
type_k,
type_v,
mul_mat_q,
logits_all,
embedding,
rope_scaling_type,
yarn_ext_factor,
yarn_attn_factor,
yarn_beta_fast,
yarn_beta_slow,
yarn_orig_ctx,
offload_kqv,
cb_eval,
cb_eval_user_data,
}: LlamaContextParams,
) -> Self {
llama_context_params {
seed,
n_ctx: n_ctx.map_or(0, NonZeroU32::get),
n_batch,
n_threads,
n_threads_batch,
rope_freq_base,
rope_freq_scale,
type_k,
type_v,
mul_mat_q,
logits_all,
embedding,
rope_scaling_type: i8::from(rope_scaling_type),
yarn_ext_factor,
yarn_attn_factor,
yarn_beta_fast,
yarn_beta_slow,
yarn_orig_ctx,
offload_kqv,
cb_eval,
cb_eval_user_data,
}
}
}
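The rewrite above replaces the field-by-field mirror of llama_context_params (and both From conversions) with a single pub(crate) wrapped struct exposed through with_* setters and getters. A sketch of how a further field could be exposed by following the same pattern; with_n_batch and n_batch are hypothetical here, not part of this diff, and would have to live in params.rs because context_params is pub(crate):

impl LlamaContextParams {
    /// Set the batch size used when evaluating prompts (hypothetical setter,
    /// shown only to illustrate the pattern used by with_seed / with_n_ctx).
    pub fn with_n_batch(mut self, n_batch: u32) -> Self {
        self.context_params.n_batch = n_batch;
        self
    }

    /// Get the batch size (hypothetical getter, same pattern as seed()).
    pub fn n_batch(&self) -> u32 {
        self.context_params.n_batch
    }
}
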
11 changes: 5 additions & 6 deletions llama-cpp-2/src/model.rs
@@ -6,7 +6,6 @@ use crate::model::params::LlamaModelParams;
use crate::token::LlamaToken;
use crate::token_type::LlamaTokenType;
use crate::{LlamaContextLoadError, LlamaModelLoadError, StringToTokenError, TokenToStringError};
use llama_cpp_sys_2::{llama_context_params, llama_token_get_type, llama_vocab_type};
use std::ffi::CString;
use std::os::raw::c_int;
use std::path::Path;
@@ -184,7 +183,7 @@ impl LlamaModel {
/// If the token type is not known to this library.
#[must_use]
pub fn token_type(&self, LlamaToken(id): LlamaToken) -> LlamaTokenType {
let token_type = unsafe { llama_token_get_type(self.model.as_ptr(), id) };
let token_type = unsafe { llama_cpp_sys_2::llama_token_get_type(self.model.as_ptr(), id) };
LlamaTokenType::try_from(token_type).expect("token type is valid")
}

@@ -314,7 +313,7 @@ impl LlamaModel {
_: &LlamaBackend,
params: LlamaContextParams,
) -> Result<LlamaContext, LlamaContextLoadError> {
let context_params = llama_context_params::from(params);
let context_params = params.context_params;
let context = unsafe {
llama_cpp_sys_2::llama_new_context_with_model(self.model.as_ptr(), context_params)
};
@@ -345,13 +344,13 @@ pub enum VocabType {
pub enum LlamaTokenTypeFromIntError {
/// The value is not a valid `llama_token_type`. Contains the int value that was invalid.
#[error("Unknown Value {0}")]
UnknownValue(llama_vocab_type),
UnknownValue(llama_cpp_sys_2::llama_vocab_type),
}

impl TryFrom<llama_vocab_type> for VocabType {
impl TryFrom<llama_cpp_sys_2::llama_vocab_type> for VocabType {
type Error = LlamaTokenTypeFromIntError;

fn try_from(value: llama_vocab_type) -> Result<Self, Self::Error> {
fn try_from(value: llama_cpp_sys_2::llama_vocab_type) -> Result<Self, Self::Error> {
match value {
llama_cpp_sys_2::LLAMA_VOCAB_TYPE_BPE => Ok(VocabType::BPE),
llama_cpp_sys_2::LLAMA_VOCAB_TYPE_SPM => Ok(VocabType::SPM),
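With the raw struct now held inside LlamaContextParams, new_context simply forwards params.context_params to llama_new_context_with_model. A rough sketch of the calling side, mirroring examples/simple.rs; the llama_backend module path and the use of anyhow for error context are assumptions about the surrounding crate, not shown in this diff:

use std::num::NonZeroU32;
use anyhow::{Context, Result};
use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::llama_backend::LlamaBackend; // assumed module path
use llama_cpp_2::model::LlamaModel;

fn build_context(backend: &LlamaBackend, model: &LlamaModel) -> Result<()> {
    // Builder-style parameters, as in the updated simple example.
    let ctx_params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(2048));

    // new_context passes the wrapped llama_context_params through to llama.cpp.
    let _ctx = model
        .new_context(backend, ctx_params)
        .with_context(|| "unable to create the llama_context")?;
    Ok(())
}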
