Using across threads #483
Comments
The lifetimes are an attempt to capture some of the invariants in the C++ code. Use after free of the model would be trivial without them (as in the C++ the context holds a pointer to the model), and creating more than one context per model would be impossible if the context took ownership of the model instead.

Improvements are welcome, but keep in mind this library is built to be about as close to raw bindings as you can get while still being safe to use. I have not used it, but https://github.com/edgenai/llama_cpp-rs is a much higher-level set of bindings that may be more useful to you (maintaining such high-level bindings given llama.cpp's blistering development pace was not something I thought I could do).
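As an editorial illustration of that invariant, here is a minimal sketch (`context_for` is a made-up helper, not part of the crate): the context borrows the model, so the borrow checker forces the model to outlive every context created from it.

```rust
use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::context::LlamaContext;
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::model::LlamaModel;

// The returned context borrows the model for 'm, so the caller cannot drop
// the model while the context is still alive.
fn context_for<'m>(backend: &'m LlamaBackend, model: &'m LlamaModel) -> LlamaContext<'m> {
    model
        .new_context(backend, LlamaContextParams::default())
        .expect("unable to create the llama_context")
}
// A variant that loaded the model inside this function and returned only the
// context would be rejected with error[E0597]: `model` does not live long enough.
```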
I actually like that this library is direct bindings. Maybe it's not that simple, but how about removing that lifetime? Since these are direct bindings, as you said, the user should know not to create multiple contexts.
Here are some snippets from our async web server.

static LLAMA_BACKEND: tokio::sync::OnceCell<llama_cpp_2::llama_backend::LlamaBackend> =
    tokio::sync::OnceCell::const_new();
pub(crate) static LLAMA_MODEL: tokio::sync::OnceCell<llama_cpp_2::model::LlamaModel> =
    tokio::sync::OnceCell::const_new();

let backend = llama_cpp_2::llama_backend::LlamaBackend::init()?;
let model = llama_cpp_2::model::LlamaModel::load_from_file(&backend, gguf_path, llama_model_params)?;

LLAMA_BACKEND
    .set(backend)
    .map_err(|_| LlamaInitError::BackendAlreadyInitialized)?;
LLAMA_MODEL
    .set(model)
    .map_err(|_| LlamaInitError::ModelAlreadyInitialized)?;

The context is then created on a "worker" thread and requests are passed to it via a channel; each message is basically

pub struct GenerateRequest {
    pub prompt: Vec<llama_cpp_2::token::LlamaToken>,
    pub sender: tokio::sync::mpsc::Sender<llama_cpp_2::token::LlamaToken>,
}
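For readers following along, a rough sketch of the worker side, assuming the statics and the `GenerateRequest` type from the snippet above; `worker_loop`, the batch size, and the single-token reply are illustrative choices, not part of the crate's API.

```rust
fn worker_loop(mut rx: tokio::sync::mpsc::Receiver<GenerateRequest>) {
    let backend = LLAMA_BACKEND.get().expect("backend not initialized");
    let model = LLAMA_MODEL.get().expect("model not initialized");
    // One long-lived context owned by this thread; it borrows the 'static model,
    // so the !Send context never has to cross a thread boundary.
    let mut ctx = model
        .new_context(backend, llama_cpp_2::context::params::LlamaContextParams::default())
        .expect("unable to create the llama_context");
    while let Some(req) = rx.blocking_recv() {
        let mut batch = llama_cpp_2::llama_batch::LlamaBatch::new(512, 1);
        let last = req.prompt.len() as i32 - 1;
        for (i, token) in (0_i32..).zip(req.prompt.into_iter()) {
            // Request logits only for the last prompt token.
            batch.add(token, i, &[0], i == last).unwrap();
        }
        ctx.decode(&mut batch).expect("llama_decode() failed");
        // Greedy-sample a single next token and stream it back to the caller;
        // a real worker would keep generating until EOS or a length limit,
        // and reset the context's KV cache between requests.
        let candidates = llama_cpp_2::token::data_array::LlamaTokenDataArray::from_iter(
            ctx.candidates_ith(batch.n_tokens() - 1),
            false,
        );
        let next = ctx.sample_token_greedy(candidates);
        let _ = req.sender.blocking_send(next);
    }
}
```

Keeping the context on one dedicated thread sidesteps the !Send / !Sync question entirely: only tokens and channel handles cross thread boundaries.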
You can also opt out of !Send and !Sync using a newtype, and obtain a 'static lifetime using the statics above (tokio not required).
The user can create multiple contexts; that's no problem. They can share the same model as well.
Thanks! This one worked for me:

/*
wget https://huggingface.co/Qwen/Qwen2-1.5B-Instruct-GGUF/resolve/main/qwen2-1_5b-instruct-q4_0.gguf
cargo run qwen2-1_5b-instruct-q4_0.gguf
*/
use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::model::LlamaModel;
use llama_cpp_2::model::{AddBos, Special};
use llama_cpp_2::token::data_array::LlamaTokenDataArray;
use once_cell::sync::Lazy;
use std::sync::Mutex;
static MODEL: Lazy<Mutex<Option<LlamaModel>>> = Lazy::new(|| Mutex::new(None));
static BACKEND: Lazy<Mutex<Option<LlamaBackend>>> = Lazy::new(|| Mutex::new(None));
fn get_answer(prompt: String) -> String {
let mut model = MODEL.lock().unwrap();
let model = model.as_mut().unwrap();
let mut backend = BACKEND.lock().unwrap();
let backend = backend.as_mut().unwrap();
let prompt = format!("<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", prompt);
let ctx_params = LlamaContextParams::default();
let mut ctx = model
.new_context(&backend, ctx_params)
.expect("unable to create the llama_context");
let tokens_list = model
.str_to_token(&prompt, AddBos::Always)
.expect(&format!("failed to tokenize {prompt}"));
let n_len = 64;
// create a llama_batch with size 512
// we use this object to submit token data for decoding
let mut batch = LlamaBatch::new(512, 1);
let last_index: i32 = (tokens_list.len() - 1) as i32;
for (i, token) in (0_i32..).zip(tokens_list.into_iter()) {
// llama_decode will output logits only for the last token of the prompt
let is_last = i == last_index;
batch.add(token, i, &[0], is_last).unwrap();
}
ctx.decode(&mut batch).expect("llama_decode() failed");
let mut n_cur = batch.n_tokens();
// The `Decoder`
let mut decoder = encoding_rs::UTF_8.new_decoder();
let mut text = String::new();
while n_cur <= n_len {
// sample the next token
{
let candidates = ctx.candidates_ith(batch.n_tokens() - 1);
let candidates_p = LlamaTokenDataArray::from_iter(candidates, false);
// sample the most likely token
let new_token_id = ctx.sample_token_greedy(candidates_p);
// is it an end of stream?
if new_token_id == model.token_eos() {
eprintln!();
break;
}
let output_bytes = model.token_to_bytes(new_token_id, Special::Tokenize).unwrap();
// use `Decoder.decode_to_string()` to avoid the intermediate buffer
let mut output_string = String::with_capacity(32);
let _decode_result = decoder.decode_to_string(&output_bytes, &mut output_string, false);
text.push_str(&output_string);
batch.clear();
batch.add(new_token_id, n_cur, &[0], true).unwrap();
}
n_cur += 1;
ctx.decode(&mut batch).expect("failed to eval");
}
text
}
fn create_model(path: String) -> (LlamaModel, LlamaBackend) {
let backend = LlamaBackend::init().unwrap();
let params = LlamaModelParams::default();
LlamaContextParams::default();
let model =
LlamaModel::load_from_file(&backend, path, &params).expect("unable to load model");
(model, backend)
}
fn main() {
let model_path = std::env::args().nth(1).expect("Please specify model path");
let (model, backend) = create_model(model_path);
*MODEL.lock().unwrap() = Some(model);
*BACKEND.lock().unwrap() = Some(backend);
println!("Answer: {}", get_answer("Hello! How are you?".into()));
println!("Answer: {}", get_answer("What's the best food you know?".into()));
}

Is it efficient / normal to create a new context for every question like this?
It's certainly not efficient. LlamaContext is pretty large and takes some time to initialize. I would try to figure out how to share it as much as possible (a static + mutex would likely be fine), but remember that the most efficient way to run many queries would be batching, so more clever synchronization is required if you need maximum performance.
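As a hedged sketch of what that batching idea looks like, the snippet below decodes two prompts in one batch using distinct sequence ids. `decode_two` is a made-up helper that only reuses calls already shown in the example above, and it assumes the second argument to `LlamaBatch::new` is the maximum number of sequences per batch.

```rust
use llama_cpp_2::context::LlamaContext;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::{AddBos, LlamaModel};
use llama_cpp_2::token::data_array::LlamaTokenDataArray;
use llama_cpp_2::token::LlamaToken;

fn decode_two(
    ctx: &mut LlamaContext<'_>,
    model: &LlamaModel,
    prompt_a: &str,
    prompt_b: &str,
) -> (LlamaToken, LlamaToken) {
    // Assumed: 512 tokens total, up to 2 independent sequences in one batch.
    let mut batch = LlamaBatch::new(512, 2);
    let mut last_indices = Vec::new();
    for (seq_id, prompt) in [(0i32, prompt_a), (1i32, prompt_b)] {
        let tokens = model.str_to_token(prompt, AddBos::Always).unwrap();
        let last = tokens.len() - 1;
        for (pos, token) in tokens.into_iter().enumerate() {
            // Positions restart at 0 for each sequence; logits only for the
            // last token of each prompt.
            batch.add(token, pos as i32, &[seq_id], pos == last).unwrap();
            if pos == last {
                last_indices.push(batch.n_tokens() - 1);
            }
        }
    }
    // One decode call evaluates both sequences.
    ctx.decode(&mut batch).expect("llama_decode() failed");
    // Greedy-sample the next token for each sequence from its own logits.
    let candidates_a = LlamaTokenDataArray::from_iter(ctx.candidates_ith(last_indices[0]), false);
    let next_a = ctx.sample_token_greedy(candidates_a);
    let candidates_b = LlamaTokenDataArray::from_iter(ctx.candidates_ith(last_indices[1]), false);
    let next_b = ctx.sample_token_greedy(candidates_b);
    (next_a, next_b)
}
```

This is the same idea llama.cpp's own server uses: many in-flight requests share one context, with each request assigned its own sequence id in the KV cache.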
I'd take a look at llama.cpp's server implementation for a good idea of squeezing as much parallel performance out of the library as possible.
Thanks. I think that, as a starting point, we simply need a way to share the context with a mutex in the example I attached. The best option, in my opinion, is to remove the lifetime, or somehow keep it behind a feature flag that is enabled by default, and then clone the static handle.
Nothing is stopping you from putting the context in a static mutex with a newtype wrapper:

struct MyContext(LlamaContext<'static>);

unsafe impl Send for MyContext {}

static CONTEXT: OnceLock<Mutex<MyContext>> = OnceLock::new();
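Fleshed out, that suggestion could look roughly like the sketch below. The `Box::leak` trick to obtain 'static borrows, the `CONTEXT` / `init_context` names, and the soundness caveat are editorial assumptions, not crate API.

```rust
use std::sync::{Mutex, OnceLock};

use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::context::LlamaContext;
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::model::LlamaModel;

// Newtype so we can assert Send ourselves; this is only defensible because the
// Mutex guarantees the context is used by one thread at a time.
struct MyContext(LlamaContext<'static>);
unsafe impl Send for MyContext {}

static CONTEXT: OnceLock<Mutex<MyContext>> = OnceLock::new();

fn init_context(gguf_path: &str) {
    // Leak the backend and model so the context can borrow them for 'static.
    let backend: &'static LlamaBackend = Box::leak(Box::new(LlamaBackend::init().unwrap()));
    let model: &'static LlamaModel = Box::leak(Box::new(
        LlamaModel::load_from_file(backend, gguf_path, &LlamaModelParams::default())
            .expect("unable to load model"),
    ));
    let ctx = model
        .new_context(backend, LlamaContextParams::default())
        .expect("unable to create the llama_context");
    CONTEXT
        .set(Mutex::new(MyContext(ctx)))
        .unwrap_or_else(|_| panic!("context already initialized"));
}
```

A request handler can then lock `CONTEXT.get().unwrap()` and drive generation through the guard's `.0` field; the mutex is what keeps the `unsafe impl Send` sound, since the context is never used concurrently.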
I tried to add the context to the struct here, but once I create the context it borrows the model.

/*
wget https://huggingface.co/Qwen/Qwen2-1.5B-Instruct-GGUF/resolve/main/qwen2-1_5b-instruct-q4_0.gguf
cargo run qwen2-1_5b-instruct-q4_0.gguf
*/
use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::model::LlamaModel;
use llama_cpp_2::model::{AddBos, Special};
use llama_cpp_2::token::data_array::LlamaTokenDataArray;
use once_cell::sync::Lazy;
use std::sync::Mutex;
static LLAMA: Lazy<Mutex<Option<Llama>>> = Lazy::new(|| Mutex::new(None));
struct Llama {
backend: LlamaBackend,
model: LlamaModel,
}
impl Llama {
pub fn new(path: &str) -> Self {
let backend = LlamaBackend::init().unwrap();
let params = LlamaModelParams::default();
LlamaContextParams::default();
let model =
LlamaModel::load_from_file(&backend, path, &params).expect("unable to load model");
Llama{backend, model}
}
pub fn ask(&mut self, prompt: &str) -> String {
let prompt = format!("<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", prompt);
let ctx_params = LlamaContextParams::default();
let mut ctx = self.model
.new_context(&self.backend, ctx_params)
.expect("unable to create the llama_context");
let tokens_list = self.model
.str_to_token(&prompt, AddBos::Always)
.expect(&format!("failed to tokenize {prompt}"));
let n_len = 64;
// create a llama_batch with size 512
// we use this object to submit token data for decoding
let mut batch = LlamaBatch::new(512, 1);
let last_index: i32 = (tokens_list.len() - 1) as i32;
for (i, token) in (0_i32..).zip(tokens_list.into_iter()) {
// llama_decode will output logits only for the last token of the prompt
let is_last = i == last_index;
batch.add(token, i, &[0], is_last).unwrap();
}
ctx.decode(&mut batch).expect("llama_decode() failed");
let mut n_cur = batch.n_tokens();
// The `Decoder`
let mut decoder = encoding_rs::UTF_8.new_decoder();
let mut text = String::new();
while n_cur <= n_len {
// sample the next token
{
let candidates = ctx.candidates_ith(batch.n_tokens() - 1);
let candidates_p = LlamaTokenDataArray::from_iter(candidates, false);
// sample the most likely token
let new_token_id = ctx.sample_token_greedy(candidates_p);
// is it an end of stream?
if new_token_id == self.model.token_eos() {
eprintln!();
break;
}
let output_bytes = self.model.token_to_bytes(new_token_id, Special::Tokenize).unwrap();
// use `Decoder.decode_to_string()` to avoid the intermediate buffer
let mut output_string = String::with_capacity(32);
let _decode_result = decoder.decode_to_string(&output_bytes, &mut output_string, false);
text.push_str(&output_string);
batch.clear();
batch.add(new_token_id, n_cur, &[0], true).unwrap();
}
n_cur += 1;
ctx.decode(&mut batch).expect("failed to eval");
}
text
}
}
fn main() {
let model_path = std::env::args().nth(1).expect("Please specify model path");
let llama = Llama::new(&model_path);
// Thread safe example
*LLAMA.lock().unwrap() = Some(llama);
// ... Take it back
let mut llama = LLAMA.lock().unwrap();
let llama = llama.as_mut().unwrap();
let answer = llama.ask("How are you?");
println!("answer: {}", answer);
}
I'm not able to use llama across threads. How can I wrap it in a mutex? Can we avoid using the lifetime in
llama-cpp-rs/llama-cpp-2/src/context.rs (Line 25 in 071598c)?
It makes things harder.
Simple example I used: #482