Using across threads #483

Open
thewh1teagle opened this issue Aug 30, 2024 · 11 comments

@thewh1teagle
Contributor

thewh1teagle commented Aug 30, 2024

I'm not able to use llama across threads. How can I wrap it in a mutex? Can we avoid using lifetimes in

pub struct LlamaContext<'a> {

The lifetime makes this harder.

Simple example I used: #482

@MarcusDunn
Contributor

MarcusDunn commented Aug 30, 2024

The lifetimes are an attempt to capture some of the invariants in the C++ code. Use-after-free of the model would be trivial without them (since in the C++ the context holds a pointer to the model), and creating more than one context per model would be impossible if LlamaContext owned the model.
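
To illustrate (a minimal sketch, not from this issue, using only calls that appear elsewhere in this thread: LlamaBackend::init, LlamaModel::load_from_file, new_context; the model path is a placeholder): dropping the model while a context is alive becomes a compile error, while several contexts can still borrow the same model.

use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::model::{params::LlamaModelParams, LlamaModel};

fn lifetime_sketch() {
    let backend = LlamaBackend::init().expect("backend init failed");
    let model = LlamaModel::load_from_file(&backend, "model.gguf", &LlamaModelParams::default())
        .expect("unable to load model");

    // More than one context may borrow the same model.
    let ctx_a = model
        .new_context(&backend, LlamaContextParams::default())
        .expect("unable to create context a");
    let ctx_b = model
        .new_context(&backend, LlamaContextParams::default())
        .expect("unable to create context b");

    // drop(model); // compile error: `model` is still borrowed by ctx_a / ctx_b,
    //              // i.e. exactly the use-after-free the lifetime is there to prevent.

    drop((ctx_a, ctx_b));
}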

Improvements are welcome but keep in mind this library is built to be about as close to raw bindings as you can get while being safe to use. I have not used it, but https://github.com/edgenai/llama_cpp-rs is a much higher-level set of bindings that may be more useful to you. (maintaining such bindings given llama.cpp's blistering development pace was not something I thought I could do)

@thewh1teagle
Contributor Author

thewh1teagle commented Aug 30, 2024

but https://github.com/edgenai/llama_cpp-rs is a much higher-level set of bindings that may be more useful to you

I actually like that this library is direct bindings.
I also use whisper-rs, but there I don't have the lifetimes, so it's easier to wrap it in a Mutex and use it across threads.
Having the lifetime for safety sounds good, but I don't know how to use it then.
Can you show me an example of how to wrap it in a static Mutex variable?
I want to use it across threads with a mutex, but I notice that I then need to declare the lifetime everywhere just to use it.

Maybe it's not that simple, but how about removing that lifetime? The user should know not to create multiple contexts, since as you said these are direct bindings.

@MarcusDunn
Contributor

MarcusDunn commented Aug 30, 2024

Here are some snippets from our async web server.

static LLAMA_BACKEND: tokio::sync::OnceCell<llama_cpp_2::llama_backend::LlamaBackend> =
    tokio::sync::OnceCell::const_new();
pub(crate) static LLAMA_MODEL: tokio::sync::OnceCell<llama_cpp_2::model::LlamaModel> =
    tokio::sync::OnceCell::const_new();

    // ... inside a startup function that returns Result<_, LlamaInitError> (rest omitted):
    let backend = llama_cpp_2::llama_backend::LlamaBackend::init()?;
    let model =
        llama_cpp_2::model::LlamaModel::load_from_file(&backend, gguf_path, llama_model_params)?;

    LLAMA_BACKEND
        .set(backend)
        .map_err(|_| LlamaInitError::BackendAlreadyInitialized)?;
    LLAMA_MODEL
        .set(model)
        .map_err(|_| LlamaInitError::ModelAlreadyInitialized)?;

The context is then created on a "worker" thread, and requests are passed to it via a channel of GenerateRequests.

Each message is basically:

pub struct GenerateRequest {
    pub prompt: Vec<llama_cpp_2::token::LlamaToken>,
    pub sender: tokio::sync::mpsc::Sender<llama_cpp_2::token::LlamaToken>,
}
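
For completeness, a rough sketch of that worker loop (not the actual server code; the decode loop itself is omitted, and the statics are the ones from the snippet above):

fn spawn_worker(
    mut rx: tokio::sync::mpsc::Receiver<GenerateRequest>,
) -> std::thread::JoinHandle<()> {
    std::thread::spawn(move || {
        // Assumes LLAMA_BACKEND / LLAMA_MODEL were set during startup.
        let backend = LLAMA_BACKEND.get().expect("backend not initialized");
        let model = LLAMA_MODEL.get().expect("model not initialized");
        let mut ctx = model
            .new_context(backend, llama_cpp_2::context::params::LlamaContextParams::default())
            .expect("unable to create the llama_context");

        while let Some(GenerateRequest { prompt, sender }) = rx.blocking_recv() {
            // ... the tokenized `prompt` would be decoded with `ctx` here, streaming each
            // sampled token back through `sender` (same loop as the examples below) ...
            let _ = (&mut ctx, prompt, sender);
        }
    })
}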

@MarcusDunn
Contributor

MarcusDunn commented Aug 30, 2024

You can also opt out of the !Send and !Sync using a newtype, and obtain a 'static lifetime using the above (tokio not required).

@MarcusDunn
Contributor

MarcusDunn commented Aug 30, 2024

The user should know not to create multiple contexts, since as you said these are direct bindings.

The user can create multiple contexts; that's no problem. They can share the same model as well.

@thewh1teagle
Contributor Author

thewh1teagle commented Aug 30, 2024

Here are some snippets from our async web server.

Thanks! This one worked for me.
I tried with a static context too... that's why it didn't work for me.

/*
wget https://huggingface.co/Qwen/Qwen2-1.5B-Instruct-GGUF/resolve/main/qwen2-1_5b-instruct-q4_0.gguf
cargo run qwen2-1_5b-instruct-q4_0.gguf
*/
use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::model::LlamaModel;
use llama_cpp_2::model::{AddBos, Special};
use llama_cpp_2::token::data_array::LlamaTokenDataArray;
use once_cell::sync::Lazy;
use std::sync::Mutex;

static MODEL: Lazy<Mutex<Option<LlamaModel>>> = Lazy::new(|| Mutex::new(None));
static BACKEND: Lazy<Mutex<Option<LlamaBackend>>> = Lazy::new(|| Mutex::new(None));

fn get_answer(prompt: String) -> String {
    let mut model = MODEL.lock().unwrap();
    let model = model.as_mut().unwrap();
    let mut backend = BACKEND.lock().unwrap();
    let backend = backend.as_mut().unwrap();
    let prompt = format!("<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", prompt);
    let ctx_params = LlamaContextParams::default();
    let mut ctx = model
        .new_context(&backend, ctx_params)
        .expect("unable to create the llama_context");
    let tokens_list = model
        .str_to_token(&prompt, AddBos::Always)
        .expect(&format!("failed to tokenize {prompt}"));
    let n_len = 64;

    // create a llama_batch with size 512
    // we use this object to submit token data for decoding
    let mut batch = LlamaBatch::new(512, 1);

    let last_index: i32 = (tokens_list.len() - 1) as i32;
    for (i, token) in (0_i32..).zip(tokens_list.into_iter()) {
        // llama_decode will output logits only for the last token of the prompt
        let is_last = i == last_index;
        batch.add(token, i, &[0], is_last).unwrap();
    }
    ctx.decode(&mut batch).expect("llama_decode() failed");

    let mut n_cur = batch.n_tokens();

    // The `Decoder`
    let mut decoder = encoding_rs::UTF_8.new_decoder();

    let mut text = String::new();
    while n_cur <= n_len {
        // sample the next token
        {
            let candidates = ctx.candidates_ith(batch.n_tokens() - 1);

            let candidates_p = LlamaTokenDataArray::from_iter(candidates, false);

            // sample the most likely token
            let new_token_id = ctx.sample_token_greedy(candidates_p);

            // is it an end of stream?
            if new_token_id == model.token_eos() {
                eprintln!();
                break;
            }

            let output_bytes = model.token_to_bytes(new_token_id, Special::Tokenize).unwrap();
            // use `Decoder.decode_to_string()` to avoid the intermediate buffer
            let mut output_string = String::with_capacity(32);
            let _decode_result = decoder.decode_to_string(&output_bytes, &mut output_string, false);
            text.push_str(&output_string);

            batch.clear();
            batch.add(new_token_id, n_cur, &[0], true).unwrap();
        }

        n_cur += 1;

        ctx.decode(&mut batch).expect("failed to eval");
    }
    text
}

fn create_model(path: String) -> (LlamaModel, LlamaBackend) {
    let backend = LlamaBackend::init().unwrap();
    let params = LlamaModelParams::default();
    let model =
        LlamaModel::load_from_file(&backend, path, &params).expect("unable to load model");
    (model, backend)

}

fn main() {
    let model_path = std::env::args().nth(1).expect("Please specify model path");
    let (model, backend) = create_model(model_path);
    *MODEL.lock().unwrap() = Some(model);
    *BACKEND.lock().unwrap() = Some(backend);

    println!("Answer: {}", get_answer("Hello! How are you?".into()));
    println!("Answer: {}", get_answer("What's the best food you know?".into()));
}

Is it efficient / normal to create LlamaContextParams and LlamaContext each time, instead of a static one?

@MarcusDunn
Contributor

It's certainly not efficient. LlamaContext is pretty large and takes some time to initialize. I would try to figure out how to share it as much as possible (a static + mutex would likely be fine), but remember that the most efficient way to run many queries would be using batching, so a more clever synchronization is required if you need max performance.
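
To make the batching point concrete, a rough sketch (not from this thread; it assumes the same imports as the examples above, that ctx and two already-tokenized prompts tokens_a / tokens_b exist, and that the second argument of LlamaBatch::new is the maximum number of sequences):

// Two independent prompts share one batch and one decode call,
// distinguished by sequence ids 0 and 1.
let mut batch = LlamaBatch::new(512, 2);
for (seq_id, tokens) in [&tokens_a, &tokens_b].into_iter().enumerate() {
    let last_index = tokens.len() - 1;
    for (pos, token) in tokens.iter().enumerate() {
        // request logits only for the last token of each prompt
        batch
            .add(*token, pos as i32, &[seq_id as i32], pos == last_index)
            .unwrap();
    }
}
ctx.decode(&mut batch).expect("llama_decode() failed");
// sampling then proceeds per sequence, using the batch index of each sequence's last token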

@MarcusDunn
Contributor

I'd take a look at llama.cpp's server implementation for a good idea of squeezing as much parallel performance out of the library as possible.

@thewh1teagle
Contributor Author

thewh1teagle commented Aug 31, 2024

I'd take a look at llama.cpp's server implementation for a good idea of squeezing as much parallel performance out of the library as possible.

Thanks. I think that as a starting point we simply need a way to share the context with a mutex in the example I attached. The best option in my opinion is to remove the lifetime, or somehow keep it behind a feature flag that is enabled by default, and then clone the static handle.

@MarcusDunn
Contributor

MarcusDunn commented Aug 31, 2024

Nothing is stopping you from putting the context in a static mutex via a newtype wrapper that implements Send (I'm unsure of the soundness of this). You're leaving some performance on the table, though, as mutual exclusion does not allow parallel decoding.

struct MyContext(LlamaContext<'static>);

unsafe impl Send for MyContext {}

static CONTEXT: OnceLock<Mutex<MyContext>> = ...
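
Expanding that into a self-contained sketch (soundness caveats as above; it assumes the model and backend live in std::sync::OnceLock statics, which is what makes the model reference &'static and therefore the context LlamaContext<'static>):

use std::sync::{Mutex, OnceLock};

use llama_cpp_2::context::{params::LlamaContextParams, LlamaContext};
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::model::{params::LlamaModelParams, LlamaModel};

static BACKEND: OnceLock<LlamaBackend> = OnceLock::new();
static MODEL: OnceLock<LlamaModel> = OnceLock::new();
static CONTEXT: OnceLock<Mutex<MyContext>> = OnceLock::new();

struct MyContext(LlamaContext<'static>);

// As noted above, whether this is actually sound has not been verified.
unsafe impl Send for MyContext {}

fn init(path: &str) {
    let backend = BACKEND.get_or_init(|| LlamaBackend::init().expect("backend init failed"));
    let model = MODEL.get_or_init(|| {
        LlamaModel::load_from_file(backend, path, &LlamaModelParams::default())
            .expect("unable to load model")
    });
    // `model` is a &'static LlamaModel, so the context borrowing it is LlamaContext<'static>.
    let ctx = model
        .new_context(backend, LlamaContextParams::default())
        .expect("unable to create the llama_context");
    let _ = CONTEXT.set(Mutex::new(MyContext(ctx)));
}

// From any thread afterwards:
// let mut guard = CONTEXT.get().unwrap().lock().unwrap();
// let ctx = &mut guard.0; // run the usual decode loop with this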

@thewh1teagle
Contributor Author

I tried to add the context to the struct here, but once I create the context it borrows the model.
Can you show me an example with it?
Let me know if you also want a PR eventually with a thread-safe example.

/*
wget https://huggingface.co/Qwen/Qwen2-1.5B-Instruct-GGUF/resolve/main/qwen2-1_5b-instruct-q4_0.gguf
cargo run qwen2-1_5b-instruct-q4_0.gguf
*/
use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::model::LlamaModel;
use llama_cpp_2::model::{AddBos, Special};
use llama_cpp_2::token::data_array::LlamaTokenDataArray;
use once_cell::sync::Lazy;
use std::sync::Mutex;

static LLAMA: Lazy<Mutex<Option<Llama>>> = Lazy::new(|| Mutex::new(None));

struct Llama {
    backend: LlamaBackend,
    model: LlamaModel,
}

impl Llama {
    pub fn new(path: &str) -> Self {
        let backend = LlamaBackend::init().unwrap();
        let params = LlamaModelParams::default();
        let model =
            LlamaModel::load_from_file(&backend, path, &params).expect("unable to load model");
        Llama { backend, model }
    }

    pub fn ask(&mut self, prompt: &str) -> String {
        let prompt = format!("<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", prompt);
        let ctx_params = LlamaContextParams::default();
        let mut ctx = self.model
            .new_context(&self.backend, ctx_params)
            .expect("unable to create the llama_context");
        let tokens_list = self.model
            .str_to_token(&prompt, AddBos::Always)
            .expect(&format!("failed to tokenize {prompt}"));
        let n_len = 64;

        // create a llama_batch with size 512
        // we use this object to submit token data for decoding
        let mut batch = LlamaBatch::new(512, 1);
    
        let last_index: i32 = (tokens_list.len() - 1) as i32;
        for (i, token) in (0_i32..).zip(tokens_list.into_iter()) {
            // llama_decode will output logits only for the last token of the prompt
            let is_last = i == last_index;
            batch.add(token, i, &[0], is_last).unwrap();
        }
        ctx.decode(&mut batch).expect("llama_decode() failed");

        let mut n_cur = batch.n_tokens();

        // The `Decoder`
        let mut decoder = encoding_rs::UTF_8.new_decoder();

        let mut text = String::new();
        while n_cur <= n_len {
            // sample the next token
            {
                let candidates = ctx.candidates_ith(batch.n_tokens() - 1);
    
                let candidates_p = LlamaTokenDataArray::from_iter(candidates, false);
    
                // sample the most likely token
                let new_token_id = ctx.sample_token_greedy(candidates_p);
    
                // is it an end of stream?
                if new_token_id == self.model.token_eos() {
                    eprintln!();
                    break;
                }
    
                let output_bytes = self.model.token_to_bytes(new_token_id, Special::Tokenize).unwrap();
                // use `Decoder.decode_to_string()` to avoid the intermediate buffer
                let mut output_string = String::with_capacity(32);
                let _decode_result = decoder.decode_to_string(&output_bytes, &mut output_string, false);
                text.push_str(&output_string);
    
                batch.clear();
                batch.add(new_token_id, n_cur, &[0], true).unwrap();
            }
    
            n_cur += 1;
    
            ctx.decode(&mut batch).expect("failed to eval");
        }
        text
    }
}

fn main() {
    let model_path = std::env::args().nth(1).expect("Please specify model path");
    let llama = Llama::new(&model_path);

    // Thread safe example
    *LLAMA.lock().unwrap() = Some(llama);

    // ... Take it back
    let mut llama = LLAMA.lock().unwrap();
    let llama = llama.as_mut().unwrap();
    let answer = llama.ask("How are you?");
    println!("answer: {}", answer);
}
