Skip to content

Commit

Permalink
added more sampling options
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcusDunn committed Feb 26, 2024
1 parent 996ce2c commit a25fdc1
Showing 1 changed file with 55 additions and 0 deletions.
55 changes: 55 additions & 0 deletions llama-cpp-2/src/context/sample.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,61 @@ impl LlamaContext<'_> {
LlamaToken(token)
}

/// Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
pub fn sample_tail_free(&self, token_data: &mut LlamaTokenDataArray, z: f32, min_keep: usize) {
let ctx = self.context.as_ptr();
unsafe {
token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
llama_cpp_sys_2::llama_sample_tail_free(ctx, c_llama_token_data_array, z, min_keep);
});
}
}

/// Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
pub fn sample_typical(&self, token_data: &mut LlamaTokenDataArray, p: f32, min_keep: usize) {
let ctx = self.context.as_ptr();
unsafe {
token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
llama_cpp_sys_2::llama_sample_typical(ctx, c_llama_token_data_array, p, min_keep);
});
}
}

/// Nucleus sampling described in academic paper \"The Curious Case of Neural Text Degeneration\" https://arxiv.org/abs/1904.09751"
pub fn sample_top_p(&self, token_data: &mut LlamaTokenDataArray, p: f32, min_keep: usize) {
let ctx = self.context.as_ptr();
unsafe {
token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
llama_cpp_sys_2::llama_sample_top_p(ctx, c_llama_token_data_array, p, min_keep);
});
}
}

/// Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841"
pub fn sample_min_p(
&self,
llama_token_data: &mut LlamaTokenDataArray,
p: f32,
min_keep: usize,
) {
let ctx = self.context.as_ptr();
unsafe {
llama_token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
llama_cpp_sys_2::llama_sample_min_p(ctx, c_llama_token_data_array, p, min_keep);
});
}
}

/// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
pub fn sample_top_k(&self, token_data: &mut LlamaTokenDataArray, k: i32, min_keep: usize) {
let ctx = self.context.as_ptr();
unsafe {
token_data.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
llama_cpp_sys_2::llama_sample_top_k(ctx, c_llama_token_data_array, k, min_keep);
});
}
}

/// Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
pub fn sample_token_softmax(&self, token_data: &mut LlamaTokenDataArray) {
let ctx = self.context.as_ptr();
Expand Down

0 comments on commit a25fdc1

Please sign in to comment.