Commit 892368d
restricted threaded access, updated llama.cpp
pedro-devv committed Apr 3, 2024
1 parent 9dd2b22 commit 892368d
Showing 5 changed files with 62 additions and 38 deletions.
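The change that gives the commit its name is in `crates/llama_cpp/src/model/mod.rs` and `crates/llama_cpp/src/session/mod.rs`: the model and context wrappers stop claiming `Sync`, the model handle becomes `Arc<Mutex<LlamaModelInner>>`, and every FFI call now goes through a lock guard. A minimal, self-contained sketch of that pattern follows; `RawModel`, `Inner`, and `ffi_use` are illustrative stand-ins, not items from the crate.

use std::sync::{Arc, Mutex};
use std::thread;

/// Stand-in for the C-side object behind a raw pointer.
struct RawModel(u32);

/// Thin wrapper over the raw pointer, mirroring `LlamaModelInner`.
struct Inner {
    ptr: *mut RawModel,
}

// The pointer may move between threads, but without `Sync` it cannot be
// shared between threads directly; the `Mutex` below provides the only way in.
unsafe impl Send for Inner {}

impl Drop for Inner {
    fn drop(&mut self) {
        // Reclaim the allocation exactly once.
        unsafe { drop(Box::from_raw(self.ptr)) };
    }
}

fn ffi_use(ptr: *mut RawModel) -> u32 {
    // Placeholder for a call like `llama_tokenize(ptr, ...)`.
    unsafe { (*ptr).0 }
}

fn main() {
    let model = Arc::new(Mutex::new(Inner {
        ptr: Box::into_raw(Box::new(RawModel(42))),
    }));

    let handles: Vec<_> = (0..4)
        .map(|_| {
            let model = Arc::clone(&model);
            thread::spawn(move || {
                // Lock first, then touch the pointer: only one thread at a
                // time reaches the C side.
                let guard = model.lock().unwrap();
                ffi_use(guard.ptr)
            })
        })
        .collect();

    for h in handles {
        assert_eq!(h.join().unwrap(), 42);
    }
}

Dropping the `Sync` claim means the compiler only lets the pointer cross threads inside a synchronization primitive, so unsynchronized concurrent calls into llama.cpp can no longer compile.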
37 changes: 24 additions & 13 deletions crates/llama_cpp/src/model/mod.rs
@@ -67,7 +67,7 @@ pub enum LlamaTokenizationError {
 ///
 /// This is a thin wrapper over an `Arc<*mut llama_model>`, which is used to share the
 /// model across threads.
-#[derive(Clone, Deref, DerefMut)]
+#[derive(Deref, DerefMut)]
 struct LlamaModelInner {
     #[deref]
     #[deref_mut]
@@ -77,8 +77,6 @@ struct LlamaModelInner {
 
 unsafe impl Send for LlamaModelInner {}
 
-unsafe impl Sync for LlamaModelInner {}
-
 impl Drop for LlamaModelInner {
     fn drop(&mut self) {
         unsafe {
@@ -100,7 +98,7 @@ impl Drop for LlamaModelInner {
 #[derive(Clone)]
 pub struct LlamaModel {
     /// A handle to the inner model on the other side of the C FFI boundary.
-    model: Arc<LlamaModelInner>,
+    model: Arc<Mutex<LlamaModelInner>>,
 
     /// The size of this model's vocabulary, in tokens.
     vocabulary_size: usize,
@@ -230,10 +228,10 @@ impl LlamaModel {
             .unwrap_or(0);
 
         Ok(Self {
-            model: Arc::new(LlamaModelInner {
+            model: Arc::new(Mutex::new(LlamaModelInner {
                 model,
                 _backend_ref: backend_ref,
-            }),
+            })),
             vocabulary_size: vocabulary_size as usize,
             bos_token: Token(unsafe { llama_token_bos(model) }),
             eos_token: Token(unsafe { llama_token_eos(model) }),
@@ -293,14 +291,16 @@ impl LlamaModel {
         let mut out_buf = Vec::with_capacity(content.len() + 2);
 
         let n_written_tokens = unsafe {
+            let model_lock = self.model.lock().unwrap();
+
             // SAFETY: The pointer ranges specified here are always valid, and `n_written_tokens`
             // is always less than `content.len()`.
             //
             // `content.len()` always fits within an `i32`.
             //
             // `out_buf` is a `Vec<Token>`, and `Token` is `#[repr(transparent)]` over an `i32`.
             llama_tokenize(
-                **self.model,
+                **model_lock,
                 content.as_ptr() as *const i8,
                 content.len() as i32,
                 out_buf.as_mut_ptr() as *mut i32,
@@ -356,7 +356,11 @@ impl LlamaModel {
             token.0
         );
 
-        unsafe { CStr::from_ptr(llama_token_get_text(**self.model, token.0)) }.to_bytes()
+        unsafe {
+            let model_lock = self.model.lock().unwrap();
+            CStr::from_ptr(llama_token_get_text(**model_lock, token.0))
+        }
+        .to_bytes()
     }
 
     /// Converts the provided token into a `Vec<u8>` piece, using the model's vocabulary.
@@ -365,11 +369,12 @@ impl LlamaModel {
     pub fn token_to_byte_piece(&self, token: Token) -> Vec<u8> {
         let initial_size = 8u16;
         let mut buffer = vec![0u8; usize::from(initial_size)];
+        let model_lock = self.model.lock().unwrap();
         let size = unsafe {
             // SAFETY: Casting `*mut u8` to `*mut i8` is safe because `u8` and
             // `i8` have the same size and alignment.
             llama_token_to_piece(
-                **self.model,
+                **model_lock,
                 token.0,
                 buffer.as_mut_ptr() as *mut i8,
                 std::os::raw::c_int::from(initial_size),
@@ -383,7 +388,7 @@
             // and `i8` have the same size and alignment. The length of
             // buffer is accurate for this reason.
             llama_token_to_piece(
-                **self.model,
+                **model_lock,
                 token.0,
                 buffer.as_mut_ptr() as *mut i8,
                 std::os::raw::c_int::from(buffer.len() as i32),
@@ -421,11 +426,13 @@ impl LlamaModel {
             let token_buf = &mut buf[i..];
 
             let size = unsafe {
+                let model_lock = self.model.lock().unwrap();
+
                 // SAFETY: Casting `*mut u8` to `*mut i8` is safe because `u8` and
                 // `i8` have the same size and alignment. The length of token_buf is
                 // accurate for this reason.
                 llama_cpp_sys::llama_token_to_piece(
-                    **self.model,
+                    **model_lock,
                     t.0,
                     token_buf.as_mut_ptr() as *mut i8,
                     token_buf.len() as i32,
@@ -463,9 +470,11 @@ impl LlamaModel {
         let max_batch = params.n_batch;
 
         let ctx = unsafe {
+            let model_lock = self.model.lock().unwrap();
+
             // SAFETY: due to `_model` being declared in the `LlamaContext`, `self` must live
             // for at least the lifetime of `LlamaContext`.
-            llama_new_context_with_model(**self.model, params)
+            llama_new_context_with_model(**model_lock, params)
         };
         if ctx.is_null() {
             return Err(LlamaContextError::SessionFailed);
@@ -640,9 +649,11 @@ impl LlamaModel {
 
         let context_params = params.as_context_params(batch_capacity);
         let context = unsafe {
+            let model_lock = self.model.lock().unwrap();
+
            // SAFETY: due to `_model` being declared in the `LlamaContext`, `self` must live
             // for at least the lifetime of `LlamaContext`.
-            llama_new_context_with_model(**self.model, context_params)
+            llama_new_context_with_model(**model_lock, context_params)
         };
 
         if context.is_null() {
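Every call site above swaps `**self.model` for `**model_lock`. The double dereference works because the `MutexGuard` derefs to `LlamaModelInner`, which in turn derefs to the raw `*mut llama_model` through the `Deref`/`DerefMut` derives kept on the struct. A hand-rolled equivalent using `std::ops::Deref` and a placeholder `llama_model` type (a sketch, not the crate's code):

use std::ops::Deref;
use std::sync::Mutex;

// Placeholder for the opaque C type.
#[allow(non_camel_case_types)]
enum llama_model {}

struct LlamaModelInner {
    model: *mut llama_model,
}

unsafe impl Send for LlamaModelInner {}

impl Deref for LlamaModelInner {
    type Target = *mut llama_model;

    fn deref(&self) -> &Self::Target {
        &self.model
    }
}

fn takes_raw(_ptr: *mut llama_model) {}

fn main() {
    let inner = Mutex::new(LlamaModelInner {
        model: std::ptr::null_mut(),
    });

    let model_lock = inner.lock().unwrap();
    // The first `*` goes through the guard to `LlamaModelInner`,
    // the second `*` goes through the `Deref` impl to `*mut llama_model`.
    takes_raw(**model_lock);
}

The guard stays alive for the whole FFI call, so the raw pointer is never used outside the critical section.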
36 changes: 20 additions & 16 deletions crates/llama_cpp/src/session/mod.rs
@@ -1,5 +1,6 @@
 //! Functionality for the [`LlamaSession`] struct
+use derive_more::{Deref, DerefMut};
 use std::cmp::min;
 use std::ops::{Bound, RangeBounds};
 use std::sync::atomic::{AtomicUsize, Ordering};
@@ -30,15 +31,14 @@ pub use params::*;
 /// The inner part of a [`LlamaSession`].
 ///
 /// This is wrapped in an `Arc` for sharing across thread boundaries.
+#[derive(Deref, DerefMut)]
 pub(crate) struct LlamaContextInner {
     /// A pointer to the inner context.
     pub(crate) ptr: *mut llama_context,
 }
 
 unsafe impl Send for LlamaContextInner {}
 
-unsafe impl Sync for LlamaContextInner {}
-
 impl Drop for LlamaContextInner {
     fn drop(&mut self) {
         // SAFETY: `drop`ping more than once is unsound [1], so `self.model` cannot have been
@@ -173,9 +173,11 @@ impl LlamaSession {
         trace!("Starting LLaMA decode for batch");
 
         let err = unsafe {
+            let session_guard = self.inner.ctx.lock().unwrap();
+
             // SAFETY: `llama_decode` will not fail for a valid `batch`, which we correctly
             // initialized above.
-            llama_decode(self.inner.ctx.lock().unwrap().ptr, batch.handle())
+            llama_decode(**session_guard, batch.handle())
         };
         if err != 0 {
             return Err(LlamaContextError::DecodeFailed(err));
@@ -281,12 +283,12 @@ impl LlamaSession {
             if session.inner.last_batch_size.load(Ordering::SeqCst) == 0 {
                 // Remove last token
                 unsafe {
-                    llama_kv_cache_seq_rm(context.ptr, -1, token_buf.len() as i32 - 1, -1);
+                    llama_kv_cache_seq_rm(**context, -1, token_buf.len() as i32 - 1, -1);
                 }
 
                 // Decode last token
                 batch.add(*token_buf.last().unwrap(), current_pos, &[0], true);
-                let res = unsafe { llama_decode(context.ptr, batch.handle()) };
+                let res = unsafe { llama_decode(**context, batch.handle()) };
 
                 if res != 0 {
                     error!("Failed to decode context ({res})");
@@ -305,7 +307,7 @@
             // Get logit values from the model and store them in a `llama_token_data_array`
             let mut candidates: Vec<llama_token_data> = {
                 let i = session.inner.last_batch_size.load(Ordering::SeqCst);
-                let logits = unsafe { llama_get_logits_ith(context.ptr, (i - 1) as i32) };
+                let logits = unsafe { llama_get_logits_ith(**context, (i - 1) as i32) };
                 let logits = unsafe { std::slice::from_raw_parts(logits, vocab) };
 
                 logits
@@ -326,7 +328,7 @@
             };
 
             // Select the next token
-            let token = sampler.sample(context.ptr, &token_buf, candidates_p);
+            let token = sampler.sample(**context, &token_buf, candidates_p);
 
             // Send the token to the `CompletionHandle`, exiting on failure
             if let Err(e) = tx.send(token) {
@@ -342,7 +344,7 @@
 
             // Create a batch with the generated token and decode it
             batch.add(token, current_pos, &[0], true);
-            let res = unsafe { llama_decode(context.ptr, batch.handle()) };
+            let res = unsafe { llama_decode(**context, batch.handle()) };
 
             if res != 0 {
                 error!("Failed to decode context ({res})");
@@ -408,10 +410,12 @@ impl LlamaSession {
             Bound::Unbounded => -1,
         };
 
-        let context = self.inner.ctx.lock().unwrap();
-
         // -1 here to match all sequences
-        let success = unsafe { llama_kv_cache_seq_rm(context.ptr, -1, start_bound, end_bound) };
+        let success = unsafe {
+            let context = self.inner.ctx.lock().unwrap();
+
+            llama_kv_cache_seq_rm(**context, -1, start_bound, end_bound)
+        };
 
         if !success {
             return Err(LlamaContextError::InvalidRange);
@@ -511,18 +515,18 @@ impl LlamaSession {
         #[allow(unused_mut)]
         let mut copy = self.model().create_session(self.inner.params.clone())?;
 
-        let size = unsafe { llama_get_state_size(ctx.ptr) };
+        let size = unsafe { llama_get_state_size(**ctx) };
         let mut buf = vec![0; size];
 
         // SAFETY: `llama_copy_state_data` and `llama_set_state_data` should never write/read more than
         // `llama_get_state_size` bytes, so `buf` should be big enough.
         //
         // `copy` was created from the same model as `self` and with the same parameters.
         unsafe {
-            let copy_size = llama_copy_state_data(ctx.ptr, buf.as_mut_ptr());
+            let copy_size = llama_copy_state_data(**ctx, buf.as_mut_ptr());
             assert!(copy_size <= size);
-            let set_size =
-                llama_set_state_data(copy.inner.ctx.lock().unwrap().ptr, buf.as_mut_ptr());
+            let copy_guard = copy.inner.ctx.lock().unwrap();
+            let set_size = llama_set_state_data(**copy_guard, buf.as_mut_ptr());
             assert_eq!(copy_size, set_size);
         }
 
@@ -542,6 +546,6 @@ impl LlamaSession {
     /// Currently there is no way to check the amount of memory occupied in devices.
     pub fn memory_size(&self) -> usize {
         let ctx = self.inner.ctx.lock().unwrap();
-        unsafe { llama_get_state_size(ctx.ptr) }
+        unsafe { llama_get_state_size(**ctx) }
     }
 }
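In `deep_copy`, the session sizes a scratch buffer from the source context, copies its state out, and writes it into the freshly created copy, now holding a named guard for each context and reaching the raw pointers through `**`. A reduced sketch of that save/restore flow, with `get_state_size`, `copy_state_data`, and `set_state_data` as safe stand-ins for the corresponding `llama_*` FFI calls:

use std::sync::Mutex;

// Stand-in for the opaque `llama_context`; the real calls take a raw
// `*mut llama_context` pulled out of the guard via `**ctx`.
struct Context {
    state: Vec<u8>,
}

fn get_state_size(ctx: &Context) -> usize {
    ctx.state.len()
}

fn copy_state_data(ctx: &Context, dst: &mut [u8]) -> usize {
    dst[..ctx.state.len()].copy_from_slice(&ctx.state);
    ctx.state.len()
}

fn set_state_data(ctx: &mut Context, src: &[u8]) -> usize {
    let len = ctx.state.len();
    ctx.state.copy_from_slice(&src[..len]);
    len
}

fn main() {
    let src = Mutex::new(Context { state: vec![1, 2, 3, 4] });
    let copy = Mutex::new(Context { state: vec![0; 4] });

    // Mirror of the `deep_copy` hunk: size a scratch buffer, copy the state
    // out of the locked source context, then push it into the locked copy.
    let ctx = src.lock().unwrap();
    let size = get_state_size(&ctx);
    let mut buf = vec![0u8; size];

    let copy_size = copy_state_data(&ctx, &mut buf);
    assert!(copy_size <= size);

    let mut copy_guard = copy.lock().unwrap();
    let set_size = set_state_data(&mut copy_guard, &buf);
    assert_eq!(copy_size, set_size);

    assert_eq!(copy_guard.state, vec![1, 2, 3, 4]);
}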
21 changes: 15 additions & 6 deletions crates/llama_cpp_sys/build.rs
@@ -1,5 +1,5 @@
 use std::env;
-use std::fs::File;
+use std::fs::{read_dir, File};
 use std::io::Write;
 use std::path::{Path, PathBuf};
 use std::process::Command;
@@ -424,8 +424,8 @@ fn compile_cuda(cx: &mut Build, cxx: &mut Build, featless_cxx: Build) -> &'static
 
     // CUDA gets linked through the cudarc crate.
 
-    cx.define("GGML_USE_CUBLAS", None);
-    cxx.define("GGML_USE_CUBLAS", None);
+    cx.define("GGML_USE_CUDA", None);
+    cxx.define("GGML_USE_CUDA", None);
 
     let mut nvcc = featless_cxx;
     nvcc.cuda(true)
@@ -453,9 +453,17 @@ fn compile_cuda(cx: &mut Build, cxx: &mut Build, featless_cxx: Build) -> &'static
     }
 
     let lib_name = "ggml-cuda";
 
-    nvcc.file(LLAMA_PATH.join("ggml-cuda.cu"))
-        .include(LLAMA_PATH.join("ggml-cuda.h"))
+    let cuda_path = LLAMA_PATH.join("ggml-cuda");
+    let cuda_sources = read_dir(cuda_path.as_path())
+        .unwrap()
+        .map(|f| f.unwrap())
+        .filter(|entry| entry.file_name().to_string_lossy().ends_with(".cu"))
+        .map(|entry| entry.path());
+
+    nvcc.include(cuda_path.as_path())
         .include(LLAMA_PATH.as_path())
+        .files(cuda_sources)
+        .file(LLAMA_PATH.join("ggml-cuda.cu"))
         .compile(lib_name);
 
     lib_name
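The updated llama.cpp splits the CUDA backend across a `ggml-cuda/` directory in addition to the top-level `ggml-cuda.cu`, so the build script now scans that directory for `.cu` files and compiles them alongside the original entry point. The same scan, pulled out into a small helper for illustration (`sources_with_ext` is a hypothetical refactor, not part of `build.rs`):

use std::fs::read_dir;
use std::path::{Path, PathBuf};

/// Collect every file in `dir` whose name ends with `ext`.
fn sources_with_ext(dir: &Path, ext: &str) -> Vec<PathBuf> {
    read_dir(dir)
        .unwrap_or_else(|e| panic!("failed to read {}: {e}", dir.display()))
        .map(|entry| entry.unwrap().path())
        .filter(|path| {
            path.file_name()
                .map(|name| name.to_string_lossy().ends_with(ext))
                .unwrap_or(false)
        })
        .collect()
}

fn main() {
    // In the build script the directory would be `LLAMA_PATH.join("ggml-cuda")`
    // and the result handed to `nvcc.files(cuda_sources)`.
    let cuda_sources = sources_with_ext(Path::new("thirdparty/llama.cpp/ggml-cuda"), ".cu");
    println!("found {} CUDA source file(s)", cuda_sources.len());
}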
@@ -579,6 +587,7 @@ fn compile_llama(mut cxx: Build, _out_path: impl AsRef<Path>) {
     println!("Compiling Llama.cpp..");
     cxx.include(LLAMA_PATH.as_path())
         .file(LLAMA_PATH.join("unicode.cpp"))
+        .file(LLAMA_PATH.join("unicode-data.cpp"))
         .file(LLAMA_PATH.join("llama.cpp"))
         .compile("llama");
 }
4 changes: 2 additions & 2 deletions crates/llama_cpp_sys/include/build-info.h
@@ -13,7 +13,7 @@
 #ifndef BUILD_INFO_H
 #define BUILD_INFO_H
 
-#define BUILD_NUMBER 2465
-#define BUILD_COMMIT "d0d5de4"
+#define BUILD_NUMBER 2589
+#define BUILD_COMMIT "60cdf40"
 
 #endif // BUILD_INFO_H
2 changes: 1 addition & 1 deletion crates/llama_cpp_sys/thirdparty/llama.cpp
Submodule llama.cpp updated 199 files
