fixed up LlamaContextParams with new CB #36

Merged · 1 commit · Jan 21, 2024
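This PR replaces the field-by-field `LlamaContextParams` struct with a thin wrapper that owns the raw `llama_cpp_sys_2::llama_context_params` and exposes builder-style `with_*` setters plus matching getters, so the Rust side no longer has to mirror every upstream field (such as the `cb_eval` callback) by hand. The sketch below is distilled from the `simple.rs` and doc-comment changes in this diff and only illustrates the migration; the surrounding model/backend setup from `simple.rs` is assumed.

```rust
use std::num::NonZeroU32;
use llama_cpp_2::context::params::LlamaContextParams;

fn main() {
    // Before this PR: construct the params struct field by field.
    // let ctx_params = LlamaContextParams {
    //     seed: 1234,
    //     n_ctx: NonZeroU32::new(2048),
    //     ..LlamaContextParams::default()
    // };

    // After this PR: start from the C-side defaults and chain builder methods.
    let ctx_params = LlamaContextParams::default()
        .with_n_ctx(NonZeroU32::new(2048)) // `None` means "use the model's context size"
        .with_seed(1234);

    // Getters read straight out of the wrapped sys struct.
    assert_eq!(ctx_params.seed(), 1234);
    assert_eq!(ctx_params.n_ctx(), NonZeroU32::new(2048));
}
```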
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default.

8 changes: 3 additions & 5 deletions llama-cpp-2/examples/simple.rs
@@ -56,11 +56,9 @@ fn main() -> Result<()> {
.with_context(|| "unable to load model")?;

// initialize the context
let ctx_params = LlamaContextParams {
seed: 1234,
n_ctx: NonZeroU32::new(2048),
..LlamaContextParams::default()
};
let ctx_params = LlamaContextParams::default()
.with_n_ctx(NonZeroU32::new(2048))
.with_seed(1234);

let mut ctx = model.new_context(&backend, ctx_params)
.with_context(|| "unable to create the llama_context")?;
227 changes: 95 additions & 132 deletions llama-cpp-2/src/context/params.rs
@@ -1,5 +1,5 @@
//! A safe wrapper around `llama_context_params`.
use llama_cpp_sys_2::{ggml_type, llama_context_params};
use llama_cpp_sys_2;
use std::fmt::Debug;
use std::num::NonZeroU32;

@@ -43,152 +43,115 @@ impl From<RopeScalingType> for i8 {
}

/// A safe wrapper around `llama_context_params`.
#[derive(Debug, PartialEq)]
///
/// Generally this should be created with [`Default::default()`] and then modified with `with_*` methods.
///
/// # Examples
///
/// ```rust
/// # use std::num::NonZeroU32;
/// use llama_cpp_2::context::params::LlamaContextParams;
///
/// let ctx_params = LlamaContextParams::default()
/// .with_n_ctx(NonZeroU32::new(2048))
/// .with_seed(1234);
///
/// assert_eq!(ctx_params.seed(), 1234);
/// assert_eq!(ctx_params.n_ctx(), NonZeroU32::new(2048));
/// ```
#[derive(Debug, Clone)]
#[allow(
missing_docs,
clippy::struct_excessive_bools,
clippy::module_name_repetitions
)]
pub struct LlamaContextParams {
/// The random seed
pub seed: u32,
/// the number of tokens in the context - [`None`] if defined by the model.
pub n_ctx: Option<NonZeroU32>,
pub n_batch: u32,
pub n_threads: u32,
pub n_threads_batch: u32,
pub rope_scaling_type: RopeScalingType,
pub rope_freq_base: f32,
pub rope_freq_scale: f32,
pub yarn_ext_factor: f32,
pub yarn_attn_factor: f32,
pub yarn_beta_fast: f32,
pub yarn_beta_slow: f32,
pub yarn_orig_ctx: u32,
pub type_k: ggml_type,
pub type_v: ggml_type,
pub mul_mat_q: bool,
pub logits_all: bool,
pub embedding: bool,
pub offload_kqv: bool,
pub cb_eval: llama_cpp_sys_2::ggml_backend_sched_eval_callback,
pub cb_eval_user_data: *mut std::ffi::c_void,
pub(crate) context_params: llama_cpp_sys_2::llama_context_params,
}

impl LlamaContextParams {
/// Set the seed of the context
///
/// # Examples
///
/// ```rust
/// use llama_cpp_2::context::params::LlamaContextParams;
/// let params = LlamaContextParams::default();
/// let params = params.with_seed(1234);
/// assert_eq!(params.seed(), 1234);
/// ```
pub fn with_seed(mut self, seed: u32) -> Self {
self.context_params.seed = seed;
self
}

/// Get the seed of the context
///
/// # Examples
///
/// ```rust
/// use llama_cpp_2::context::params::LlamaContextParams;
/// let params = LlamaContextParams::default()
/// .with_seed(1234);
/// assert_eq!(params.seed(), 1234);
/// ```
pub fn seed(&self) -> u32 {
self.context_params.seed
}

/// Set the size of the context
///
/// # Examples
///
/// ```rust
/// # use std::num::NonZeroU32;
/// use llama_cpp_2::context::params::LlamaContextParams;
/// let params = LlamaContextParams::default();
/// let params = params.with_n_ctx(NonZeroU32::new(2048));
/// assert_eq!(params.n_ctx(), NonZeroU32::new(2048));
/// ```
pub fn with_n_ctx(mut self, n_ctx: Option<NonZeroU32>) -> Self {
self.context_params.n_ctx = n_ctx.map_or(0, |n_ctx| n_ctx.get());
self
}

/// Get the size of the context.
///
/// [`None`] if the context size is specified by the model and not the context.
///
/// # Examples
///
/// ```rust
/// let params = llama_cpp_2::context::params::LlamaContextParams::default();
/// assert_eq!(params.n_ctx(), std::num::NonZeroU32::new(512));
/// ```
pub fn n_ctx(&self) -> Option<NonZeroU32> {
NonZeroU32::new(self.context_params.n_ctx)
}

/// Get the type of rope scaling.
///
/// # Examples
///
/// ```rust
/// let params = llama_cpp_2::context::params::LlamaContextParams::default();
/// assert_eq!(params.rope_scaling_type(), llama_cpp_2::context::params::RopeScalingType::Unspecified);
/// ```
pub fn rope_scaling_type(&self) -> RopeScalingType {
RopeScalingType::from(self.context_params.rope_scaling_type)
}
}

/// Default parameters for `LlamaContext`. (as defined in llama.cpp by `llama_context_default_params`)
/// ```
/// # use std::num::NonZeroU32;
/// use llama_cpp_2::context::params::{LlamaContextParams, RopeScalingType};
/// let params = LlamaContextParams::default();
/// assert_eq!(params.n_ctx, NonZeroU32::new(512), "n_ctx should be 512");
/// assert_eq!(params.rope_scaling_type, RopeScalingType::Unspecified);
/// assert_eq!(params.n_ctx(), NonZeroU32::new(512), "n_ctx should be 512");
/// assert_eq!(params.rope_scaling_type(), RopeScalingType::Unspecified);
/// ```
impl Default for LlamaContextParams {
fn default() -> Self {
Self::from(unsafe { llama_cpp_sys_2::llama_context_default_params() })
}
}

impl From<llama_context_params> for LlamaContextParams {
fn from(
llama_context_params {
seed,
n_ctx,
n_batch,
n_threads,
n_threads_batch,
rope_freq_base,
rope_freq_scale,
cb_eval,
cb_eval_user_data,
type_k,
type_v,
mul_mat_q,
logits_all,
embedding,
rope_scaling_type,
yarn_ext_factor,
yarn_attn_factor,
yarn_beta_fast,
yarn_beta_slow,
yarn_orig_ctx,
offload_kqv,
}: llama_context_params,
) -> Self {
Self {
seed,
n_ctx: NonZeroU32::new(n_ctx),
n_batch,
n_threads,
n_threads_batch,
rope_freq_base,
rope_freq_scale,
type_k,
type_v,
mul_mat_q,
logits_all,
embedding,
rope_scaling_type: RopeScalingType::from(rope_scaling_type),
yarn_ext_factor,
yarn_attn_factor,
yarn_beta_fast,
yarn_beta_slow,
yarn_orig_ctx,
offload_kqv,
cb_eval,
cb_eval_user_data,
}
let context_params = unsafe { llama_cpp_sys_2::llama_context_default_params() };
Self { context_params, }
}
}

impl From<LlamaContextParams> for llama_context_params {
fn from(
LlamaContextParams {
seed,
n_ctx,
n_batch,
n_threads,
n_threads_batch,
rope_freq_base,
rope_freq_scale,
type_k,
type_v,
mul_mat_q,
logits_all,
embedding,
rope_scaling_type,
yarn_ext_factor,
yarn_attn_factor,
yarn_beta_fast,
yarn_beta_slow,
yarn_orig_ctx,
offload_kqv,
cb_eval,
cb_eval_user_data,
}: LlamaContextParams,
) -> Self {
llama_context_params {
seed,
n_ctx: n_ctx.map_or(0, NonZeroU32::get),
n_batch,
n_threads,
n_threads_batch,
rope_freq_base,
rope_freq_scale,
type_k,
type_v,
mul_mat_q,
logits_all,
embedding,
rope_scaling_type: i8::from(rope_scaling_type),
yarn_ext_factor,
yarn_attn_factor,
yarn_beta_fast,
yarn_beta_slow,
yarn_orig_ctx,
offload_kqv,
cb_eval,
cb_eval_user_data,
}
}
}
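Because the new struct simply stores the sys-level `llama_context_params` as a `pub(crate)` field, further setters and getters only need to read or write that inner struct. A hypothetical sketch of what an additional accessor pair could look like under this pattern follows; `with_n_batch`/`n_batch` are illustrative and not part of this diff, though `n_batch` is one of the `u32` fields visible in the removed `From` impls.

```rust
// Sketch only: this would live in llama-cpp-2/src/context/params.rs next to
// the accessors above, since `context_params` is a pub(crate) field.
impl LlamaContextParams {
    /// Set the batch size (illustrative; not added by this PR).
    pub fn with_n_batch(mut self, n_batch: u32) -> Self {
        self.context_params.n_batch = n_batch;
        self
    }

    /// Get the batch size.
    pub fn n_batch(&self) -> u32 {
        self.context_params.n_batch
    }
}
```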
11 changes: 5 additions & 6 deletions llama-cpp-2/src/model.rs
@@ -6,7 +6,6 @@ use crate::model::params::LlamaModelParams;
use crate::token::LlamaToken;
use crate::token_type::LlamaTokenType;
use crate::{LlamaContextLoadError, LlamaModelLoadError, StringToTokenError, TokenToStringError};
use llama_cpp_sys_2::{llama_context_params, llama_token_get_type, llama_vocab_type};
use std::ffi::CString;
use std::os::raw::c_int;
use std::path::Path;
@@ -184,7 +183,7 @@ impl LlamaModel {
/// If the token type is not known to this library.
#[must_use]
pub fn token_type(&self, LlamaToken(id): LlamaToken) -> LlamaTokenType {
let token_type = unsafe { llama_token_get_type(self.model.as_ptr(), id) };
let token_type = unsafe { llama_cpp_sys_2::llama_token_get_type(self.model.as_ptr(), id) };
LlamaTokenType::try_from(token_type).expect("token type is valid")
}

@@ -314,7 +313,7 @@ impl LlamaModel {
_: &LlamaBackend,
params: LlamaContextParams,
) -> Result<LlamaContext, LlamaContextLoadError> {
let context_params = llama_context_params::from(params);
let context_params = params.context_params;
let context = unsafe {
llama_cpp_sys_2::llama_new_context_with_model(self.model.as_ptr(), context_params)
};
@@ -345,13 +344,13 @@ pub enum VocabType {
pub enum LlamaTokenTypeFromIntError {
/// The value is not a valid `llama_token_type`. Contains the int value that was invalid.
#[error("Unknown Value {0}")]
UnknownValue(llama_vocab_type),
UnknownValue(llama_cpp_sys_2::llama_vocab_type),
}

impl TryFrom<llama_vocab_type> for VocabType {
impl TryFrom<llama_cpp_sys_2::llama_vocab_type> for VocabType {
type Error = LlamaTokenTypeFromIntError;

fn try_from(value: llama_vocab_type) -> Result<Self, Self::Error> {
fn try_from(value: llama_cpp_sys_2::llama_vocab_type) -> Result<Self, Self::Error> {
match value {
llama_cpp_sys_2::LLAMA_VOCAB_TYPE_BPE => Ok(VocabType::BPE),
llama_cpp_sys_2::LLAMA_VOCAB_TYPE_SPM => Ok(VocabType::SPM),