Expose functions llama_load_session_file and llama_save_session_file
zh217 committed Feb 27, 2024
1 parent 1c2306f commit b8fad09
Showing 2 changed files with 78 additions and 0 deletions.
1 change: 1 addition & 0 deletions llama-cpp-2/src/context.rs
@@ -15,6 +15,7 @@ use std::slice;
pub mod kv_cache;
pub mod params;
pub mod sample;
pub mod session;

/// Safe wrapper around `llama_context`.
#[allow(clippy::module_name_repetitions)]
77 changes: 77 additions & 0 deletions llama-cpp-2/src/context/session.rs
@@ -0,0 +1,77 @@
//! Utilities for working with session files.

use std::ffi::{CString, NulError};
use std::path::{Path, PathBuf};
use crate::context::LlamaContext;
use crate::token::LlamaToken;

/// An error that can occur when saving a session file.
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum SaveSessionError {
    /// llama.cpp failed to save the session file.
    #[error("Failed to save session file")]
    FailedToSave,

    /// The path contained a null byte and could not be converted to a [`CString`].
    #[error("null byte in string {0}")]
    NullError(#[from] NulError),

    /// The path could not be converted to a UTF-8 string.
    #[error("failed to convert path {0} to str")]
    PathToStrError(PathBuf),
}

/// An error that can occur when loading a session file.
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum LoadSessionError {
    /// llama.cpp failed to load the session file.
    #[error("Failed to load session file")]
    FailedToLoad,

    /// The path contained a null byte and could not be converted to a [`CString`].
    #[error("null byte in string {0}")]
    NullError(#[from] NulError),

    /// The path could not be converted to a UTF-8 string.
    #[error("failed to convert path {0} to str")]
    PathToStrError(PathBuf),
}

impl LlamaContext<'_> {
    /// Save the current context state, together with the given token sequence, to a session file.
    ///
    /// # Errors
    ///
    /// Fails if the path is not valid UTF-8, contains a null byte, or llama.cpp
    /// fails to write the session file.
    pub fn save_session_file(
        &self,
        path_session: impl AsRef<Path>,
        tokens: &[LlamaToken],
    ) -> Result<(), SaveSessionError> {
        let path = path_session.as_ref();
        let path = path
            .to_str()
            .ok_or(SaveSessionError::PathToStrError(path.to_path_buf()))?;

        let cstr = CString::new(path)?;

        // `LlamaToken` wraps the C `llama_token` (an `i32`), so the slice can be
        // passed straight through to the C API.
        if unsafe {
            llama_cpp_sys_2::llama_save_session_file(
                self.context.as_ptr(),
                cstr.as_ptr(),
                tokens.as_ptr() as *const i32,
                tokens.len(),
            )
        } {
            Ok(())
        } else {
            Err(SaveSessionError::FailedToSave)
        }
    }

    /// Load a session file into this context, returning the tokens stored in it.
    ///
    /// `max_tokens` is the maximum number of tokens to read from the file.
    ///
    /// # Errors
    ///
    /// Fails if the path is not valid UTF-8, contains a null byte, or llama.cpp
    /// fails to load the session file.
    pub fn load_session_file(
        &mut self,
        path_session: impl AsRef<Path>,
        max_tokens: usize,
    ) -> Result<Vec<LlamaToken>, LoadSessionError> {
        let path = path_session.as_ref();
        let path = path
            .to_str()
            .ok_or(LoadSessionError::PathToStrError(path.to_path_buf()))?;

        let cstr = CString::new(path)?;
        let mut tokens = Vec::with_capacity(max_tokens);
        let mut n_out = 0;

        unsafe {
            if llama_cpp_sys_2::llama_load_session_file(
                self.context.as_ptr(),
                cstr.as_ptr(),
                tokens.as_mut_ptr() as *mut i32,
                max_tokens,
                &mut n_out,
            ) {
                // llama.cpp wrote `n_out` tokens into the buffer; make them visible to the Vec.
                tokens.set_len(n_out);
                Ok(tokens)
            } else {
                Err(LoadSessionError::FailedToLoad)
            }
        }
    }
}
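
Taken together, the two methods let a prompt's evaluated state be written to disk and restored later without re-evaluating the prompt. A minimal sketch of how they might be used (the `cache_prompt` helper, the already-evaluated `ctx`, `prompt_tokens`, and the `prompt.session` file name are illustrative assumptions, not part of this commit):

// Hypothetical usage sketch: assumes `ctx` is a LlamaContext that has already
// evaluated `prompt_tokens`; none of these names come from this commit.
use llama_cpp_2::context::LlamaContext;
use llama_cpp_2::token::LlamaToken;

fn cache_prompt(
    ctx: &mut LlamaContext<'_>,
    prompt_tokens: &[LlamaToken],
) -> Result<(), Box<dyn std::error::Error>> {
    // Persist the context state together with the tokens it was built from.
    ctx.save_session_file("prompt.session", prompt_tokens)?;

    // Later (possibly in another run), restore the state instead of re-evaluating the prompt.
    let restored = ctx.load_session_file("prompt.session", prompt_tokens.len())?;
    assert_eq!(restored.len(), prompt_tokens.len());
    Ok(())
}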
