Merge pull request #105 from edgenai/feat/embeddings
Feat/embeddings
pedro-devv authored Mar 4, 2024
2 parents f1f1a17 + 3504577 commit 6e3b7f9
Showing 8 changed files with 351 additions and 6 deletions.
8 changes: 8 additions & 0 deletions crates/edgen_core/src/llm.rs
@@ -39,6 +39,8 @@ pub enum LLMEndpointError {
Load(String),
#[error("failed to create a new session: {0}")]
SessionCreationFailed(String),
#[error("failed to create embeddings: {0}")]
Embeddings(String), // Embeddings may involve session creation, advancing, and other things, so it should have its own error
}

#[derive(Debug, Clone)]
@@ -80,6 +82,12 @@ pub trait LLMEndpoint {
args: CompletionArgs,
) -> BoxedFuture<Result<Box<dyn Stream<Item = String> + Unpin + Send>, LLMEndpointError>>;

fn embeddings<'a>(
&'a self,
model_path: impl AsRef<Path> + Send + 'a,
inputs: Vec<String>,
) -> BoxedFuture<Result<Vec<Vec<f32>>, LLMEndpointError>>;

/// Unloads everything from memory.
fn reset(&self);
}
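The new `embeddings` method follows the same pattern as the completion methods above: the caller hands over a model path plus a batch of inputs and gets back one `f32` vector per input. A minimal sketch of driving it, assuming a concrete `LLMEndpoint` implementation is at hand and that `BoxedFuture` is directly awaitable, as elsewhere in the crate:

use edgen_core::llm::{LLMEndpoint, LLMEndpointError};

async fn embed_pair(endpoint: &impl LLMEndpoint) -> Result<(), LLMEndpointError> {
    let vectors = endpoint
        .embeddings(
            "/path/to/model.gguf", // hypothetical local model file
            vec!["first sentence".to_string(), "second sentence".to_string()],
        )
        .await?;
    // One embedding vector comes back per input, in order.
    assert_eq!(vectors.len(), 2);
    Ok(())
}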
17 changes: 15 additions & 2 deletions crates/edgen_core/src/settings.rs
@@ -183,6 +183,13 @@ pub struct SettingsParams {
/// The audio transcription repo that Edgen will use for downloads
pub audio_transcriptions_model_repo: String,

// TODO temporary, until the model parameter in incoming requests can be parsed into local paths
/// The directory where embeddings models are stored
pub embeddings_models_dir: String,
/// The embeddings model that Edgen will use when the user does not provide a model
pub embeddings_model_name: String,
/// The embeddings repo that Edgen will use for downloads
pub embeddings_model_repo: String,

/// The policy used to decide if models/sessions should be allocated and run on acceleration
/// hardware.
pub gpu_policy: DevicePolicy,
@@ -214,12 +221,14 @@ impl Default for SettingsParams {
let data_dir = PROJECT_DIRS.data_dir();
let chat_completions_dir = data_dir.join(Path::new("models/chat/completions"));
let audio_transcriptions_dir = data_dir.join(Path::new("models/audio/transcriptions"));
let embeddings_dir = data_dir.join(Path::new("models/embeddings"));

let chat_completions_str = chat_completions_dir.into_os_string().into_string().unwrap();
let audio_transcriptions_str = audio_transcriptions_dir
.into_os_string()
.into_string()
.unwrap();
let embeddings_str = embeddings_dir.into_os_string().into_string().unwrap();

let cpus = num_cpus::get_physical();
let threads = if cpus > 1 { cpus - 1 } else { 1 };
@@ -230,10 +239,13 @@
default_uri: "http://127.0.0.1:33322".to_string(),
chat_completions_model_name: "neural-chat-7b-v3-3.Q4_K_M.gguf".to_string(),
chat_completions_model_repo: "TheBloke/neural-chat-7B-v3-3-GGUF".to_string(),
chat_completions_models_dir: chat_completions_str,
audio_transcriptions_model_name: "ggml-distil-small.en.bin".to_string(),
audio_transcriptions_model_repo: "distil-whisper/distil-small.en".to_string(),
chat_completions_models_dir: chat_completions_str,
audio_transcriptions_models_dir: audio_transcriptions_str,
embeddings_model_name: "nomic-embed-text-v1.5.f16.gguf".to_string(),
embeddings_model_repo: "nomic-ai/nomic-embed-text-v1.5-GGUF".to_string(),
embeddings_models_dir: embeddings_str,
// TODO detect if the system has acceleration hardware to decide the default
gpu_policy: DevicePolicy::AlwaysDevice {
overflow_to_cpu: true,
@@ -325,7 +337,8 @@ impl DerefMut for SettingsInner {

pub struct Settings {
inner: Arc<RwLock<SettingsInner>>,
_watcher: PollWatcher, // we use a PollWatcher because it observes the path, not the inode
_watcher: PollWatcher,
// we use a PollWatcher because it observes the path, not the inode
handler: UpdateHandler,
}

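For orientation, a sketch of where the new defaults land on disk; the concrete path is OS-specific, and the Linux example in the comment is an assumption, not something taken from this commit:

// Mirrors the Default impl above: the embeddings directory sits next to the
// chat and audio model directories under the project data dir, e.g.
// ~/.local/share/edgen/models/embeddings on Linux. The default model file
// nomic-embed-text-v1.5.f16.gguf is downloaded into it from the
// nomic-ai/nomic-embed-text-v1.5-GGUF repo.
let embeddings_dir = PROJECT_DIRS.data_dir().join("models/embeddings");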
34 changes: 33 additions & 1 deletion crates/edgen_rt_llama_cpp/src/lib.rs
@@ -22,7 +22,8 @@ use futures::executor::block_on;
use futures::Stream;
use llama_cpp::standard_sampler::StandardSampler;
use llama_cpp::{
CompletionHandle, LlamaModel, LlamaParams, LlamaSession, SessionParams, TokensToStrings,
CompletionHandle, EmbeddingsParams, LlamaModel, LlamaParams, LlamaSession, SessionParams,
TokensToStrings,
};
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
use tokio::task::JoinHandle;
@@ -91,6 +92,15 @@ impl LlamaCppEndpoint {
let model = self.get(model_path).await;
model.stream_chat_completions(args).await
}

async fn async_embeddings(
&self,
model_path: impl AsRef<Path>,
inputs: Vec<String>,
) -> Result<Vec<Vec<f32>>, LLMEndpointError> {
let model = self.get(model_path).await;
model.embeddings(inputs).await
}
}

impl LLMEndpoint for LlamaCppEndpoint {
@@ -112,6 +122,15 @@ impl LLMEndpoint for LlamaCppEndpoint {
Box::new(pinned)
}

fn embeddings<'a>(
&'a self,
model_path: impl AsRef<Path> + Send + 'a,
inputs: Vec<String>,
) -> BoxedFuture<Result<Vec<Vec<f32>>, LLMEndpointError>> {
let pinned = Box::pin(self.async_embeddings(model_path, inputs));
Box::new(pinned)
}

fn reset(&self) {
self.models.clear();
}
@@ -317,6 +336,19 @@ impl UnloadingModel {
))
}
}

async fn embeddings(&self, inputs: Vec<String>) -> Result<Vec<Vec<f32>>, LLMEndpointError> {
let threads = SETTINGS.read().await.read().await.auto_threads(false);
let mut params = EmbeddingsParams::default();
params.n_threads = threads;
params.n_threads_batch = threads;

let (_model_signal, model_guard) = get_or_init_model(&self.model, &self.path).await?;
model_guard
.embeddings_async(&inputs, params)
.await
.map_err(move |e| LLMEndpointError::Embeddings(e.to_string()))
}
}

impl Drop for UnloadingModel {
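For context, a standalone sketch of the raw `llama_cpp` calls that `UnloadingModel::embeddings` wraps; `load_from_file_async` is assumed to be available in the crate version this runtime builds against:

use llama_cpp::{EmbeddingsParams, LlamaModel, LlamaParams};

async fn raw_embeddings(inputs: Vec<String>) -> Result<Vec<Vec<f32>>, Box<dyn std::error::Error>> {
    // Load a GGUF model, then embed the batch with explicit thread counts,
    // just as the endpoint does after resolving its settings.
    let model = LlamaModel::load_from_file_async("model.gguf", LlamaParams::default()).await?;
    let mut params = EmbeddingsParams::default();
    params.n_threads = 4;
    params.n_threads_batch = 4;
    Ok(model.embeddings_async(&inputs, params).await?)
}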
4 changes: 4 additions & 0 deletions crates/edgen_server/src/lib.rs
@@ -77,6 +77,10 @@ mod whisper;
openai_shim::FunctionStub,
openai_shim::AssistantFunctionStub,
openai_shim::AssistantToolCall,
openai_shim::CreateEmbeddingsRequest,
openai_shim::EmbeddingsResponse,
openai_shim::Embedding,
openai_shim::EmbeddingsUsage,
openai_shim::CreateTranscriptionRequest,
openai_shim::TranscriptionResponse,
openai_shim::TranscriptionError,
14 changes: 14 additions & 0 deletions crates/edgen_server/src/llm.rs
@@ -59,6 +59,20 @@ pub async fn chat_completion_stream(
))
}

pub async fn embeddings(
model: Model,
input: Vec<String>,
) -> Result<Vec<Vec<f32>>, LLMEndpointError> {
ENDPOINT
.embeddings(
model
.file_path()
.map_err(move |e| LLMEndpointError::Load(e.to_string()))?,
input,
)
.await
}

pub async fn reset_environment() {
ENDPOINT.reset()
}
173 changes: 170 additions & 3 deletions crates/edgen_server/src/openai_shim.rs
@@ -26,7 +26,6 @@ use axum::response::{IntoResponse, Response, Sse};
use axum::Json;
use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
use derive_more::{Deref, DerefMut, From};
use edgen_core::llm::{CompletionArgs, LLMEndpointError};
use either::Either;
use futures::{Stream, StreamExt, TryStream};
use serde_derive::{Deserialize, Serialize};
@@ -37,9 +36,11 @@ use tracing::error;
use utoipa::ToSchema;
use uuid::Uuid;

use edgen_core::llm::{CompletionArgs, LLMEndpointError};
use edgen_core::settings::SETTINGS;
use edgen_core::whisper::WhisperEndpointError;

use crate::llm::embeddings;
use crate::model::{Model, ModelError, ModelKind};

/// The plaintext or image content of a [`ChatMessage`] within a [`CreateChatCompletionRequest`].
@@ -699,6 +700,172 @@ pub async fn chat_completions(
Ok(response)
}

/// A request to generate embeddings for one or more pieces of text.
///
/// An `axum` handler, [`create_embeddings`][create_embeddings], is provided to handle this request.
///
/// See [the documentation for creating embeddings][openai] for more details.
///
/// [create_embeddings]: fn.create_embeddings.html
/// [openai]: https://platform.openai.com/docs/api-reference/embeddings/create
#[derive(Serialize, Deserialize, ToSchema)]
pub struct CreateEmbeddingsRequest<'a> {
/// The text input to embed as either a string or an array of strings.
#[serde(with = "either::serde_untagged")]
#[schema(value_type = String)]
pub input: Either<Cow<'a, str>, Vec<Cow<'a, str>>>,

/// ID of the model to use.
#[schema(value_type = String)]
pub model: Cow<'a, str>,

/// The format to return the embeddings in. Can be either `float` or `base64`.
#[schema(value_type = String)]
pub encoding_format: Option<Cow<'a, str>>,

/// The number of dimensions the resulting output embeddings should have. Only supported in some models.
pub dimensions: Option<usize>,
}
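Because `input` is deserialized with `either::serde_untagged`, both OpenAI request shapes map onto the same struct. A quick sketch (the field values are illustrative):

fn parse_examples() -> serde_json::Result<()> {
    // A single string...
    let _single: CreateEmbeddingsRequest = serde_json::from_str(
        r#"{"input": "The food was delicious.", "model": "default"}"#,
    )?;
    // ...or an array of strings; `encoding_format` and `dimensions` are optional.
    let _batch: CreateEmbeddingsRequest = serde_json::from_str(
        r#"{"input": ["first", "second"], "model": "default"}"#,
    )?;
    Ok(())
}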

/// The return type of [`create_embeddings`].
#[derive(Serialize, Deserialize, ToSchema)]
pub struct EmbeddingsResponse {
/// Always `"list"`.
pub object: String,

/// The generated embeddings.
pub embeddings: Vec<Embedding>,

/// The model used for generation.
pub model: String,

/// The usage statistics of the request.
pub usage: EmbeddingsUsage,
}

/// Represents an embedding vector returned by the embeddings endpoint.
///
/// See [the documentation for the embedding object][openai] for more details.
///
/// [openai]: https://platform.openai.com/docs/api-reference/embeddings/object
#[derive(Serialize, Deserialize, ToSchema)]
pub struct Embedding {
/// Always `"embedding"`.
pub object: String,

/// The embedding vector, which is a list of floats. The length of the vector depends on the model.
pub embedding: Vec<f32>,

/// The index of the embedding in the list of embeddings.
pub index: usize,
}

/// The usage statistics of the request.
#[derive(Serialize, Deserialize, ToSchema)]
pub struct EmbeddingsUsage {
/// The number of tokens in the input prompt(s).
pub prompt_tokens: usize,

/// The total number of tokens used by the request.
pub total_tokens: usize,
}

// TODO change to use a dedicated error type, or make a common error type
/// POST `/v1/embeddings`: generates embeddings for the provided text.
///
/// See [the original OpenAI API specification][openai], which this endpoint is compatible with.
///
/// [openai]: https://platform.openai.com/docs/api-reference/embeddings/create
///
/// On failure, may raise a `500 Internal Server Error` with a JSON-encoded [`ChatCompletionError`]
/// to the peer.
#[utoipa::path(
post,
path = "/embeddings",
request_body = CreateEmbeddingsRequest,
responses(
(status = 200, description = "OK", body = EmbeddingsResponse),
(status = 500, description = "unexpected internal server error", body = ChatCompletionError)
),
)]
pub async fn create_embeddings(
Json(req): Json<CreateEmbeddingsRequest<'_>>,
) -> Result<impl IntoResponse, ChatCompletionError> {
// For MVP1, the model string in the request is *always* ignored.
let model_name = SETTINGS
.read()
.await
.read()
.await
.embeddings_model_name
.trim()
.to_string();
let repo = SETTINGS
.read()
.await
.read()
.await
.embeddings_model_repo
.trim()
.to_string();
let dir = SETTINGS
.read()
.await
.read()
.await
.embeddings_models_dir
.trim()
.to_string();

if model_name.is_empty() {
return Err(ChatCompletionError::ProhibitedName {
model_name,
reason: Cow::Borrowed("Empty model name in config"),
});
}
if dir.is_empty() {
return Err(ChatCompletionError::ProhibitedName {
model_name: dir,
reason: Cow::Borrowed("Empty model directory in config"),
});
}

let mut model = Model::new(ModelKind::LLM, &model_name, &repo, &PathBuf::from(&dir));

model
.preload()
.await
.map_err(move |_| ChatCompletionError::NoSuchModel {
model_name: model_name.to_string(),
})?;

let input = req.input.either(
move |s| vec![s.to_string()],
move |v| v.iter().map(move |s| s.to_string()).collect(),
);
let mut res = embeddings(model, input).await?;

Ok(Json(EmbeddingsResponse {
object: "list".to_string(),
embeddings: res
.drain(..)
.enumerate()
.map(move |(index, embedding)| Embedding {
object: "embedding".to_string(),
embedding,
index,
})
.collect(),
model: req.model.to_string(),
usage: EmbeddingsUsage {
prompt_tokens: 0,
total_tokens: 0,
},
}))
}
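
Putting it together, a hedged sketch of exercising the new route against a locally running Edgen instance; `reqwest` and `tokio` are assumptions of this example, not dependencies of the commit, and the bind address is the default from settings.rs above:

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let response = reqwest::Client::new()
        .post("http://127.0.0.1:33322/v1/embeddings")
        // The `model` field is currently ignored; the configured embeddings
        // model from the settings is always used.
        .json(&serde_json::json!({
            "input": ["Hello, world!"],
            "model": "default"
        }))
        .send()
        .await?
        .error_for_status()?;
    // The body carries `object`, `embeddings`, `model`, and `usage`, with the
    // usage counts currently hardcoded to zero.
    println!("{}", response.text().await?);
    Ok(())
}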

/// A request to transcribe an audio file into text in either the specified language, or whichever
/// language is automatically detected, if none is specified.
///
Expand Down Expand Up @@ -761,7 +928,7 @@ pub struct TranscriptionResponse {

/// The [`Uuid`] of a newly created session, present only if `create_session` in
/// [`CreateTranscriptionRequest`] is set to `true`. This additional member is **not normative**
/// with OpenAI's specification, as it is intended for **Edgen** specific functinality.
/// with OpenAI's specification, as it is intended for **Edgen** specific functionality.
#[serde(skip_serializing_if = "Option::is_none")]
pub session: Option<Uuid>,
}
Expand All @@ -772,7 +939,7 @@ pub struct TranscriptionResponse {
///
/// [openai]: https://platform.openai.com/docs/api-reference/audio/createTranscription
///
/// On failure, may raise a `500 Internal Server Error` with a JSON-encoded [`WhisperEndpointError`]
/// On failure, may raise a `500 Internal Server Error` with a JSON-encoded [`TranscriptionError`]
/// to the peer.
#[utoipa::path(
post,
2 changes: 2 additions & 0 deletions crates/edgen_server/src/routes.rs
@@ -31,6 +31,8 @@ pub fn routes() -> Router {
// -- AI endpoints -----------------------------------------------------
// ---- Chat -----------------------------------------------------------
.route("/v1/chat/completions", post(openai_shim::chat_completions))
// ---- Embeddings -----------------------------------------------------
.route("/v1/embeddings", post(openai_shim::create_embeddings))
// ---- Audio ----------------------------------------------------------
.route(
"/v1/audio/transcriptions",
