
updated llama.cpp #68

Merged (3 commits) on Feb 5, 2024

4 changes: 2 additions & 2 deletions llama-cpp-2/benches/grammar_bias.rs
@@ -30,9 +30,9 @@ fn criterion_benchmark(c: &mut Criterion) {
.unwrap();
let backend = LlamaBackend::init().unwrap();
let model_params = LlamaModelParams::default();
let model = LlamaModel::load_from_file(&backend, &file, &model_params).unwrap();
let model = LlamaModel::load_from_file(&backend, file, &model_params).unwrap();
let mut ctx = model
.new_context(&backend, &LlamaContextParams::default())
.new_context(&backend, LlamaContextParams::default())
.unwrap();
let grammar = LlamaGrammar::from_str(include_str!("../src/grammar/json.gbnf")).unwrap();

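Note on the bench change: `LlamaModel::load_from_file` now takes the path by value and `new_context` takes `LlamaContextParams` by value rather than by reference. A minimal sketch of the updated call pattern, reusing only names and calls that appear in the diff above (`file` is the bench's `PathBuf` to the GGUF model):

```rust
// Updated by-value signatures exercised by this bench.
let backend = LlamaBackend::init().unwrap();
let model_params = LlamaModelParams::default();
// `file` (a PathBuf) is moved into the call instead of being passed as `&file`.
let model = LlamaModel::load_from_file(&backend, file, &model_params).unwrap();
// The context params are likewise consumed by value.
let mut ctx = model
    .new_context(&backend, LlamaContextParams::default())
    .unwrap();
```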
45 changes: 25 additions & 20 deletions llama-cpp-2/examples/simple.rs
@@ -1,21 +1,20 @@
//! This is an translation of simple.cpp in llama.cpp using llama-cpp-2.
#![allow(clippy::cast_possible_wrap, clippy::cast_possible_truncation)]
#![allow(clippy::cast_possible_wrap, clippy::cast_possible_truncation, clippy::cast_precision_loss, clippy::cast_sign_loss)]

use std::io::Write;
use std::num::NonZeroU32;
use std::path::PathBuf;
use std::time::Duration;
use anyhow::{bail, Context, Result};
use clap::Parser;
use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::model::LlamaModel;
use llama_cpp_2::model::params::LlamaModelParams;
use anyhow::{bail, Context, Result};
use llama_cpp_2::ggml_time_us;
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::token::data_array::LlamaTokenDataArray;
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::model::AddBos;

use llama_cpp_2::model::LlamaModel;
use llama_cpp_2::token::data_array::LlamaTokenDataArray;
use std::io::Write;
use std::num::NonZeroU32;
use std::path::PathBuf;
use std::time::Duration;

#[derive(clap::Parser)]
struct Args {
@@ -30,7 +29,6 @@ struct Args {
disable_gpu: bool,
}


fn main() -> Result<()> {
let params = Args::parse();

@@ -60,12 +58,14 @@ fn main() -> Result<()> {
.with_n_ctx(NonZeroU32::new(2048))
.with_seed(1234);

let mut ctx = model.new_context(&backend, ctx_params)
let mut ctx = model
.new_context(&backend, ctx_params)
.with_context(|| "unable to create the llama_context")?;

// tokenize the prompt

let tokens_list = model.str_to_token(&params.prompt, AddBos::Always)
let tokens_list = model
.str_to_token(&params.prompt, AddBos::Always)
.with_context(|| format!("failed to tokenize {}", params.prompt))?;

let n_cxt = ctx.n_ctx() as i32;
@@ -75,8 +75,10 @@ fn main() -> Result<()> {

// make sure the KV cache is big enough to hold all the prompt and generated tokens
if n_kv_req > n_cxt {
bail!("n_kv_req > n_ctx, the required kv cache size is not big enough
either reduce n_len or increase n_ctx")
bail!(
"n_kv_req > n_ctx, the required kv cache size is not big enough
either reduce n_len or increase n_ctx"
)
}

// print the prompt token-by-token
@@ -137,7 +139,6 @@ either reduce n_len or increase n_ctx")
ctx.decode(&mut batch).with_context(|| "failed to eval")?;

n_decode += 1;

}

eprintln!("\n");
@@ -146,10 +147,14 @@ either reduce n_len or increase n_ctx")

let duration = Duration::from_micros((t_main_end - t_main_start) as u64);

eprintln!("decoded {} tokens in {:.2} s, speed {:.2} t/s\n", n_decode, duration.as_secs_f32(), n_decode as f32 / duration.as_secs_f32());
eprintln!(
"decoded {} tokens in {:.2} s, speed {:.2} t/s\n",
n_decode,
duration.as_secs_f32(),
n_decode as f32 / duration.as_secs_f32()
);

println!("{}", ctx.timings());

Ok(())

}
}
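Pulled together, the example's flow after this refactor is sketched below. It is an outline, not a drop-in replacement for `simple.rs`: only calls visible in this diff are used, and the batch construction plus the sampling loop are elided because they live in collapsed hunks.

```rust
use anyhow::{Context, Result};
use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::model::{AddBos, LlamaModel};
use std::num::NonZeroU32;
use std::path::PathBuf;

fn run(model_path: PathBuf, prompt: String) -> Result<()> {
    // Backend and model setup.
    let backend = LlamaBackend::init()?;
    let model_params = LlamaModelParams::default();
    let model = LlamaModel::load_from_file(&backend, model_path, &model_params)
        .with_context(|| "unable to load model")?;

    // Builder-style context params, consumed by value by `new_context`.
    let ctx_params = LlamaContextParams::default()
        .with_n_ctx(NonZeroU32::new(2048))
        .with_seed(1234);
    let mut ctx = model
        .new_context(&backend, ctx_params)
        .with_context(|| "unable to create the llama_context")?;

    // Tokenize the prompt, prepending BOS.
    let tokens_list = model
        .str_to_token(&prompt, AddBos::Always)
        .with_context(|| format!("failed to tokenize {}", prompt))?;

    // Elided: the KV-cache size check against ctx.n_ctx(), building a LlamaBatch
    // from `tokens_list`, and the decode/sample loop that repeatedly calls
    // ctx.decode(&mut batch).with_context(|| "failed to eval")?;
    let _ = tokens_list;

    println!("{}", ctx.timings());
    Ok(())
}
```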
42 changes: 21 additions & 21 deletions llama-cpp-2/src/context/params.rs
@@ -19,8 +19,8 @@ pub enum RopeScalingType {

/// Create a `RopeScalingType` from a `c_int` - returns `RopeScalingType::ScalingUnspecified` if
/// the value is not recognized.
impl From<i8> for RopeScalingType {
fn from(value: i8) -> Self {
impl From<i32> for RopeScalingType {
fn from(value: i32) -> Self {
match value {
0 => Self::None,
1 => Self::Linear,
@@ -31,7 +31,7 @@ impl From<i8> for RopeScalingType {
}

/// Create a `c_int` from a `RopeScalingType`.
impl From<RopeScalingType> for i8 {
impl From<RopeScalingType> for i32 {
fn from(value: RopeScalingType) -> Self {
match value {
RopeScalingType::None => 0,
@@ -84,7 +84,7 @@ impl LlamaContextParams {
/// let params = params.with_seed(1234);
/// assert_eq!(params.seed(), 1234);
/// ```
pub fn with_seed(mut self, seed: u32) -> Self {
#[must_use] pub fn with_seed(mut self, seed: u32) -> Self {
self.context_params.seed = seed;
self
}
@@ -99,7 +99,7 @@ impl LlamaContextParams {
/// .with_seed(1234);
/// assert_eq!(params.seed(), 1234);
/// ```
pub fn seed(&self) -> u32 {
#[must_use] pub fn seed(&self) -> u32 {
self.context_params.seed
}

@@ -114,8 +114,8 @@ impl LlamaContextParams {
/// let params = params.with_n_ctx(NonZeroU32::new(2048));
/// assert_eq!(params.n_ctx(), NonZeroU32::new(2048));
/// ```
pub fn with_n_ctx(mut self, n_ctx: Option<NonZeroU32>) -> Self {
self.context_params.n_ctx = n_ctx.map_or(0, |n_ctx| n_ctx.get());
#[must_use] pub fn with_n_ctx(mut self, n_ctx: Option<NonZeroU32>) -> Self {
self.context_params.n_ctx = n_ctx.map_or(0, std::num::NonZeroU32::get);
self
}

@@ -128,11 +128,11 @@ impl LlamaContextParams {
/// ```rust
/// let params = llama_cpp_2::context::params::LlamaContextParams::default();
/// assert_eq!(params.n_ctx(), std::num::NonZeroU32::new(512));
pub fn n_ctx(&self) -> Option<NonZeroU32> {
#[must_use] pub fn n_ctx(&self) -> Option<NonZeroU32> {
NonZeroU32::new(self.context_params.n_ctx)
}

/// Set the n_batch
/// Set the `n_batch`
///
/// # Examples
///
@@ -143,12 +143,12 @@ impl LlamaContextParams {
/// .with_n_batch(2048);
/// assert_eq!(params.n_batch(), 2048);
/// ```
pub fn with_n_batch(mut self, n_batch: u32) -> Self {
#[must_use] pub fn with_n_batch(mut self, n_batch: u32) -> Self {
self.context_params.n_batch = n_batch;
self
}

/// Get the n_batch
/// Get the `n_batch`
///
/// # Examples
///
@@ -157,7 +157,7 @@ impl LlamaContextParams {
/// let params = LlamaContextParams::default();
/// assert_eq!(params.n_batch(), 512);
/// ```
pub fn n_batch(&self) -> u32 {
#[must_use] pub fn n_batch(&self) -> u32 {
self.context_params.n_batch
}

@@ -171,8 +171,8 @@ impl LlamaContextParams {
/// .with_rope_scaling_type(RopeScalingType::Linear);
/// assert_eq!(params.rope_scaling_type(), RopeScalingType::Linear);
/// ```
pub fn with_rope_scaling_type(mut self, rope_scaling_type: RopeScalingType) -> Self {
self.context_params.rope_scaling_type = i8::from(rope_scaling_type);
#[must_use] pub fn with_rope_scaling_type(mut self, rope_scaling_type: RopeScalingType) -> Self {
self.context_params.rope_scaling_type = i32::from(rope_scaling_type);
self
}

@@ -184,7 +184,7 @@ impl LlamaContextParams {
/// let params = llama_cpp_2::context::params::LlamaContextParams::default();
/// assert_eq!(params.rope_scaling_type(), llama_cpp_2::context::params::RopeScalingType::Unspecified);
/// ```
pub fn rope_scaling_type(&self) -> RopeScalingType {
#[must_use] pub fn rope_scaling_type(&self) -> RopeScalingType {
RopeScalingType::from(self.context_params.rope_scaling_type)
}

@@ -198,7 +198,7 @@ impl LlamaContextParams {
/// .with_rope_freq_base(0.5);
/// assert_eq!(params.rope_freq_base(), 0.5);
/// ```
pub fn with_rope_freq_base(mut self, rope_freq_base: f32) -> Self {
#[must_use] pub fn with_rope_freq_base(mut self, rope_freq_base: f32) -> Self {
self.context_params.rope_freq_base = rope_freq_base;
self
}
@@ -211,7 +211,7 @@ impl LlamaContextParams {
/// let params = llama_cpp_2::context::params::LlamaContextParams::default();
/// assert_eq!(params.rope_freq_base(), 0.0);
/// ```
pub fn rope_freq_base(&self) -> f32 {
#[must_use] pub fn rope_freq_base(&self) -> f32 {
self.context_params.rope_freq_base
}

@@ -225,7 +225,7 @@ impl LlamaContextParams {
/// .with_rope_freq_scale(0.5);
/// assert_eq!(params.rope_freq_scale(), 0.5);
/// ```
pub fn with_rope_freq_scale(mut self, rope_freq_scale: f32) -> Self {
#[must_use] pub fn with_rope_freq_scale(mut self, rope_freq_scale: f32) -> Self {
self.context_params.rope_freq_scale = rope_freq_scale;
self
}
@@ -238,7 +238,7 @@ impl LlamaContextParams {
/// let params = llama_cpp_2::context::params::LlamaContextParams::default();
/// assert_eq!(params.rope_freq_scale(), 0.0);
/// ```
pub fn rope_freq_scale(&self) -> f32 {
#[must_use] pub fn rope_freq_scale(&self) -> f32 {
self.context_params.rope_freq_scale
}

@@ -250,7 +250,7 @@ impl LlamaContextParams {
/// let params = llama_cpp_2::context::params::LlamaContextParams::default();
/// assert_eq!(params.n_threads(), 4);
/// ```
pub fn n_threads(&self) -> u32 {
#[must_use] pub fn n_threads(&self) -> u32 {
self.context_params.n_threads
}

@@ -264,7 +264,7 @@ impl LlamaContextParams {
/// .with_n_threads(8);
/// assert_eq!(params.n_threads(), 8);
/// ```
pub fn with_n_threads(mut self, n_threads: u32) -> Self {
#[must_use] pub fn with_n_threads(mut self, n_threads: u32) -> Self {
self.context_params.n_threads = n_threads;
self
}
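Two things happen in this file: the `RopeScalingType` conversions move from `i8` to `i32` (tracking the widened `rope_scaling_type` field in the updated `llama.cpp` headers, as far as this diff shows), and every getter and builder method gains `#[must_use]`. The doc tests above already demonstrate the intended chaining; condensed into one runnable snippet:

```rust
use llama_cpp_2::context::params::{LlamaContextParams, RopeScalingType};
use std::num::NonZeroU32;

fn main() {
    // Each `with_*` consumes the params and returns the updated value, so the
    // whole chain has to be bound (or passed on) to satisfy #[must_use].
    let params = LlamaContextParams::default()
        .with_seed(1234)
        .with_n_ctx(NonZeroU32::new(2048))
        .with_n_batch(2048)
        .with_n_threads(8)
        .with_rope_scaling_type(RopeScalingType::Linear);

    assert_eq!(params.seed(), 1234);
    assert_eq!(params.n_ctx(), NonZeroU32::new(2048));
    assert_eq!(params.n_batch(), 2048);
    assert_eq!(params.n_threads(), 8);
    assert_eq!(params.rope_scaling_type(), RopeScalingType::Linear);
}
```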
21 changes: 16 additions & 5 deletions llama-cpp-2/src/llama_batch.rs
@@ -37,15 +37,21 @@ impl LlamaBatch {
///
/// - [`self.llama_batch.n_tokens`] does not fit into a usize
/// - [`seq_ids.len()`] does not fit into a [`llama_seq_id`]
///
/// # Errors
///
/// returns a error if there is insufficient space in the buffer
pub fn add(
&mut self,
LlamaToken(id): LlamaToken,
pos: llama_pos,
seq_ids: &[i32],
logits: bool,
) -> Result<(), BatchAddError> {
if self.allocated < usize::try_from(self.n_tokens() + 1).expect("cannot fit n_tokens into a usize") {
return Err(BatchAddError::InsufficientSpace(self.allocated))
if self.allocated
< usize::try_from(self.n_tokens() + 1).expect("cannot fit n_tokens into a usize")
{
return Err(BatchAddError::InsufficientSpace(self.allocated));
}
let offset = self.llama_batch.n_tokens;
let offset_usize = usize::try_from(offset).expect("cannot fit n_tokens into a usize");
@@ -55,8 +61,10 @@ impl LlamaBatch {
// batch.pos [batch.n_tokens] = pos,
self.llama_batch.pos.add(offset_usize).write(pos);
// batch.n_seq_id[batch.n_tokens] = seq_ids.size();
self.llama_batch.n_seq_id.add(offset_usize).write(llama_seq_id::try_from(seq_ids.len())
.expect("cannot fit seq_ids.len() into a llama_seq_id"));
self.llama_batch.n_seq_id.add(offset_usize).write(
llama_seq_id::try_from(seq_ids.len())
.expect("cannot fit seq_ids.len() into a llama_seq_id"),
);
// for (size_t i = 0; i < seq_ids.size(); ++i) {
// batch.seq_id[batch.n_tokens][i] = seq_ids[i];
// }
@@ -65,7 +73,10 @@ impl LlamaBatch {
tmp.add(i).write(*seq_id);
}
// batch.logits [batch.n_tokens] = logits;
self.llama_batch.logits.add(offset_usize).write(i8::from(logits));
self.llama_batch
.logits
.add(offset_usize)
.write(i8::from(logits));
}

if logits {
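`LlamaBatch::add` now documents its error path: a full batch returns `BatchAddError::InsufficientSpace` instead of only being a programmer error. A usage sketch under stated assumptions: `batch` is an already-constructed `LlamaBatch` with its capacity configured elsewhere, and `tokens` is the `Vec<LlamaToken>` produced by `str_to_token` (neither constructor appears in this diff).

```rust
// Feed prompt tokens into an existing batch, requesting logits only for the
// last token; `i` doubles as the position, and sequence id 0 is used throughout.
let last_index = tokens.len() - 1;
for (i, token) in tokens.into_iter().enumerate() {
    let request_logits = i == last_index;
    // A full batch now surfaces as an Err instead of a panic.
    if let Err(err) = batch.add(token, i as i32, &[0], request_logits) {
        eprintln!("could not add token at position {i}: {err:?}");
        break;
    }
}
```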
4 changes: 1 addition & 3 deletions llama-cpp-2/src/model.rs
@@ -126,7 +126,7 @@ impl LlamaModel {
) -> Result<Vec<LlamaToken>, StringToTokenError> {
let add_bos = match add_bos {
AddBos::Always => true,
AddBos::Never => false
AddBos::Never => false,
};

let tokens_estimation = std::cmp::max(8, (str.len() / 2) + usize::from(add_bos));
@@ -136,8 +136,6 @@ impl LlamaModel {
let buffer_capacity =
c_int::try_from(buffer.capacity()).expect("buffer capacity should fit into a c_int");



let size = unsafe {
llama_cpp_sys_2::llama_tokenize(
self.model.as_ptr(),
2 changes: 1 addition & 1 deletion llama-cpp-2/src/token.rs
@@ -10,7 +10,7 @@ pub mod data_array;
#[repr(transparent)]
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
#[allow(clippy::module_name_repetitions)]
pub struct LlamaToken( pub llama_cpp_sys_2::llama_token);
pub struct LlamaToken(pub llama_cpp_sys_2::llama_token);

impl Display for LlamaToken {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
17 changes: 11 additions & 6 deletions llama-cpp-sys-2/build.rs
@@ -1,26 +1,32 @@
use std::env;
use std::path::PathBuf;
use std::path::Path;
use std::path::PathBuf;

fn main() {
println!("cargo:rerun-if-changed=llama.cpp");

let cublas_enabled = env::var("CARGO_FEATURE_CUBLAS").is_ok();

if !Path::new("llama.cpp/ggml.c").exists() {
panic!("llama.cpp seems to not be populated, try running `git submodule update --init --recursive` to init.")
panic!("llama.cpp seems to not be populated, try running `git submodule update --init --recursive` to init.")
}

let mut ggml = cc::Build::new();
let mut ggml_cuda = if cublas_enabled { Some(cc::Build::new()) } else { None };
let mut ggml_cuda = if cublas_enabled {
Some(cc::Build::new())
} else {
None
};
let mut llama_cpp = cc::Build::new();

ggml.cpp(false);
llama_cpp.cpp(true);

// https://github.com/ggerganov/llama.cpp/blob/a836c8f534ab789b02da149fbdaf7735500bff74/Makefile#L364-L368
if let Some(ggml_cuda) = &mut ggml_cuda {
for lib in ["cuda", "cublas", "culibos", "cudart", "cublasLt", "pthread", "dl", "rt"] {
for lib in [
"cuda", "cublas", "culibos", "cudart", "cublasLt", "pthread", "dl", "rt",
] {
println!("cargo:rustc-link-lib={}", lib);
}

@@ -66,8 +72,7 @@ fn main() {
ggml.define("_GNU_SOURCE", None);
}

ggml
.std("c17")
ggml.std("c17")
.file("llama.cpp/ggml.c")
.file("llama.cpp/ggml-alloc.c")
.file("llama.cpp/ggml-backend.c")
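Most of the build-script changes are formatting, but the structure is worth restating: a `cublas` Cargo feature (inferred from the `CARGO_FEATURE_CUBLAS` check) switches on a second `cc::Build` for the CUDA sources plus the matching `cargo:rustc-link-lib` directives, while the CPU objects are always compiled. A stripped-down sketch of that pattern; see the full `build.rs` for the complete source list and the CUDA compile flags:

```rust
// build.rs sketch: feature-gated GPU support on top of a plain C/C++ build.
use std::env;
use std::path::Path;

fn main() {
    // Recompile when the vendored submodule changes.
    println!("cargo:rerun-if-changed=llama.cpp");

    // Cargo exposes every enabled feature as CARGO_FEATURE_<NAME>, so building
    // with `--features cublas` makes this variable appear.
    let cublas_enabled = env::var("CARGO_FEATURE_CUBLAS").is_ok();

    if !Path::new("llama.cpp/ggml.c").exists() {
        panic!("llama.cpp seems to not be populated, try running `git submodule update --init --recursive` to init.")
    }

    if cublas_enabled {
        // Link the CUDA runtime libraries only when the feature is on. The real
        // script also configures a second cc::Build for the ggml-cuda sources here.
        for lib in ["cuda", "cublas", "culibos", "cudart", "cublasLt", "pthread", "dl", "rt"] {
            println!("cargo:rustc-link-lib={lib}");
        }
    }

    // CPU objects are always built; the real script also adds ggml-alloc.c,
    // ggml-backend.c, and the llama.cpp C++ sources.
    let mut ggml = cc::Build::new();
    ggml.cpp(false)
        .std("c17")
        .file("llama.cpp/ggml.c")
        .compile("ggml");
}
```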