Example to use grammar sampler? #604
Here is how I do it, if anyone is interested:

```rust
pub struct LlmSampler {
    grammar_sampler: Option<LlamaSampler>,
    chain: LlamaSampler,
    cur_p: Option<LlamaTokenDataArray>,
}

impl LlmSampler {
    pub fn new(model: &LlamaModel, grammar: &str) -> Self {
        // TODO: Implement a real sampler based on opts
        let samplers = vec![LlamaSampler::dist(0)];
        let sampler = LlamaSampler::chain(samplers, false);
        let grammar_sampler = if !grammar.is_empty() {
            Some(LlamaSampler::grammar(model, grammar, "root"))
        } else {
            None
        };
        Self { grammar_sampler, chain: sampler, cur_p: None }
    }

    pub fn accept(&mut self, token: LlamaToken, with_grammar: bool) {
        if with_grammar {
            if let Some(s) = self.grammar_sampler.as_mut() {
                s.accept(token);
            }
        }
        self.chain.accept(token);
    }

    fn sample(&mut self, ctx: &LlamaContext, idx: i32) -> LlamaToken {
        self.set_logits(ctx, idx);
        {
            let cur_p = self.cur_p.as_mut().unwrap();
            self.chain.apply(cur_p);
            let id = cur_p.data[cur_p.selected.unwrap()].id();
            // Check whether the sampled token fits the grammar.
            if let Some(grammar_sampler) = self.grammar_sampler.as_mut() {
                let single_token_data = LlamaTokenData::new(id, 1.0, 0.0);
                let mut single_token_data_array =
                    LlamaTokenDataArray::new(vec![single_token_data], false);
                grammar_sampler.apply(&mut single_token_data_array);
                // The grammar sampler masks invalid tokens to -inf.
                let logit = single_token_data_array.data[0].logit();
                let is_valid = !(logit.is_infinite() && logit.is_sign_negative());
                if is_valid {
                    return id;
                }
            }
        }
        // Resampling: if the token is not valid, sample again, but this time
        // apply the grammar sampler first and then the sampling chain.
        self.set_logits(ctx, idx);
        let cur_p = self.cur_p.as_mut().unwrap();
        if let Some(grammar_sampler) = self.grammar_sampler.as_mut() {
            grammar_sampler.apply(cur_p);
        }
        self.chain.apply(cur_p);
        cur_p.data[cur_p.selected.unwrap()].id()
    }

    fn set_logits(&mut self, ctx: &LlamaContext, i: i32) {
        let logits = ctx.get_logits_ith(i);
        let n_vocab = ctx.model.n_vocab();
        let mut cur = Vec::with_capacity(n_vocab as usize);
        for i in 0..n_vocab {
            let token = LlamaToken(i);
            cur.push(LlamaTokenData::new(token, logits[i as usize], 0.0));
        }
        self.cur_p = Some(LlamaTokenDataArray::new(cur, false));
    }
}
```
This should be possible with only:

```rust
let mut sampler = LlamaSampler::chain_simple([
    LlamaSampler::grammar(grammar_str, "root"),
    LlamaSampler::dist(seed),
]);
```

As long as you make sure to use one of the
For me that always results in a coredump.
I tested it with 0.1.84 and the main (9af1286) branch.
That is a bug in the simple example itself. Both
Is there any point to using the grammar sampler in llama-cpp-rs?
In the original llama.cpp server example, the grammar sampler and the chain sampler are used separately.
Do I have to use them that way?