diff --git a/crates/edgen_core/src/llm.rs b/crates/edgen_core/src/llm.rs
index cac1044..3be07d4 100644
--- a/crates/edgen_core/src/llm.rs
+++ b/crates/edgen_core/src/llm.rs
@@ -39,6 +39,8 @@ pub enum LLMEndpointError {
     Load(String),
     #[error("failed to create a new session: {0}")]
     SessionCreationFailed(String),
+    #[error("failed to create embeddings: {0}")]
+    Embeddings(String), // Embeddings may involve session creation, advancing, and other things, so it should have its own error
 }
 
 #[derive(Debug, Clone)]
@@ -80,6 +82,12 @@ pub trait LLMEndpoint {
         args: CompletionArgs,
     ) -> BoxedFuture<Result<Box<dyn Stream<Item = String> + Unpin + Send>, LLMEndpointError>>;
 
+    fn embeddings<'a>(
+        &'a self,
+        model_path: impl AsRef<Path> + Send + 'a,
+        inputs: Vec<String>,
+    ) -> BoxedFuture<Result<Vec<Vec<f32>>, LLMEndpointError>>;
+
     /// Unloads everything from memory.
     fn reset(&self);
 }
diff --git a/crates/edgen_core/src/settings.rs b/crates/edgen_core/src/settings.rs
index e781450..0c9d7f0 100644
--- a/crates/edgen_core/src/settings.rs
+++ b/crates/edgen_core/src/settings.rs
@@ -183,6 +183,13 @@ pub struct SettingsParams {
     /// The audio transcription repo that Edgen will use for downloads
     pub audio_transcriptions_model_repo: String,
 
+    // TODO temporary, until the model parameter in incoming requests can be parsed into local paths
+    pub embeddings_models_dir: String,
+    /// The embeddings model that Edgen will use when the user does not provide a model
+    pub embeddings_model_name: String,
+    /// The embeddings repo that Edgen will use for downloads
+    pub embeddings_model_repo: String,
+
     /// The policy used to decided if models/session should be allocated and run on acceleration
     /// hardware.
     pub gpu_policy: DevicePolicy,
@@ -214,12 +221,14 @@ impl Default for SettingsParams {
         let data_dir = PROJECT_DIRS.data_dir();
         let chat_completions_dir = data_dir.join(Path::new("models/chat/completions"));
         let audio_transcriptions_dir = data_dir.join(Path::new("models/audio/transcriptions"));
+        let embeddings_dir = data_dir.join(Path::new("models/embeddings"));
         let chat_completions_str = chat_completions_dir.into_os_string().into_string().unwrap();
         let audio_transcriptions_str = audio_transcriptions_dir
             .into_os_string()
             .into_string()
             .unwrap();
+        let embeddings_str = embeddings_dir.into_os_string().into_string().unwrap();
 
         let cpus = num_cpus::get_physical();
         let threads = if cpus > 1 { cpus - 1 } else { 1 };
@@ -230,10 +239,13 @@ impl Default for SettingsParams {
             default_uri: "http://127.0.0.1:33322".to_string(),
             chat_completions_model_name: "neural-chat-7b-v3-3.Q4_K_M.gguf".to_string(),
             chat_completions_model_repo: "TheBloke/neural-chat-7B-v3-3-GGUF".to_string(),
+            chat_completions_models_dir: chat_completions_str,
             audio_transcriptions_model_name: "ggml-distil-small.en.bin".to_string(),
             audio_transcriptions_model_repo: "distil-whisper/distil-small.en".to_string(),
-            chat_completions_models_dir: chat_completions_str,
             audio_transcriptions_models_dir: audio_transcriptions_str,
+            embeddings_model_name: "nomic-embed-text-v1.5.f16.gguf".to_string(),
+            embeddings_model_repo: "nomic-ai/nomic-embed-text-v1.5-GGUF".to_string(),
+            embeddings_models_dir: embeddings_str,
             // TODO detect if the system has acceleration hardware to decide the default
             gpu_policy: DevicePolicy::AlwaysDevice {
                 overflow_to_cpu: true,
@@ -325,7 +337,8 @@ impl DerefMut for SettingsInner {
 
 pub struct Settings {
     inner: Arc<RwLock<SettingsInner>>,
-    _watcher: PollWatcher, // we use a PollWatcher because it observes the path, not the inode
+    _watcher: PollWatcher,
+    // we use a PollWatcher because it observes the path, not the inode
     handler: UpdateHandler,
 }
diff --git a/crates/edgen_rt_llama_cpp/src/lib.rs b/crates/edgen_rt_llama_cpp/src/lib.rs
index 56272a5..64effbd 100644
--- a/crates/edgen_rt_llama_cpp/src/lib.rs
+++ b/crates/edgen_rt_llama_cpp/src/lib.rs
@@ -22,7 +22,8 @@ use futures::executor::block_on;
 use futures::Stream;
 use llama_cpp::standard_sampler::StandardSampler;
 use llama_cpp::{
-    CompletionHandle, LlamaModel, LlamaParams, LlamaSession, SessionParams, TokensToStrings,
+    CompletionHandle, EmbeddingsParams, LlamaModel, LlamaParams, LlamaSession, SessionParams,
+    TokensToStrings,
 };
 use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
 use tokio::task::JoinHandle;
@@ -91,6 +92,15 @@ impl LlamaCppEndpoint {
         let model = self.get(model_path).await;
         model.stream_chat_completions(args).await
     }
+
+    async fn async_embeddings(
+        &self,
+        model_path: impl AsRef<Path>,
+        inputs: Vec<String>,
+    ) -> Result<Vec<Vec<f32>>, LLMEndpointError> {
+        let model = self.get(model_path).await;
+        model.embeddings(inputs).await
+    }
 }
 
 impl LLMEndpoint for LlamaCppEndpoint {
@@ -112,6 +122,15 @@ impl LLMEndpoint for LlamaCppEndpoint {
         Box::new(pinned)
     }
 
+    fn embeddings<'a>(
+        &'a self,
+        model_path: impl AsRef<Path> + Send + 'a,
+        inputs: Vec<String>,
+    ) -> BoxedFuture<Result<Vec<Vec<f32>>, LLMEndpointError>> {
+        let pinned = Box::pin(self.async_embeddings(model_path, inputs));
+        Box::new(pinned)
+    }
+
     fn reset(&self) {
         self.models.clear();
     }
@@ -317,6 +336,19 @@ impl UnloadingModel {
             ))
         }
     }
+
+    async fn embeddings(&self, inputs: Vec<String>) -> Result<Vec<Vec<f32>>, LLMEndpointError> {
+        let threads = SETTINGS.read().await.read().await.auto_threads(false);
+        let mut params = EmbeddingsParams::default();
+        params.n_threads = threads;
+        params.n_threads_batch = threads;
+
+        let (_model_signal, model_guard) = get_or_init_model(&self.model, &self.path).await?;
+        model_guard
+            .embeddings_async(&inputs, params)
+            .await
+            .map_err(move |e| LLMEndpointError::Embeddings(e.to_string()))
+    }
 }
 
 impl Drop for UnloadingModel {
diff --git a/crates/edgen_server/src/lib.rs b/crates/edgen_server/src/lib.rs
index f39ba50..7b1cde5 100644
--- a/crates/edgen_server/src/lib.rs
+++ b/crates/edgen_server/src/lib.rs
@@ -77,6 +77,10 @@ mod whisper;
         openai_shim::FunctionStub,
         openai_shim::AssistantFunctionStub,
         openai_shim::AssistantToolCall,
+        openai_shim::CreateEmbeddingsRequest,
+        openai_shim::EmbeddingsResponse,
+        openai_shim::Embedding,
+        openai_shim::EmbeddingsUsage,
         openai_shim::CreateTranscriptionRequest,
         openai_shim::TranscriptionResponse,
         openai_shim::TranscriptionError,
diff --git a/crates/edgen_server/src/llm.rs b/crates/edgen_server/src/llm.rs
index 740db1d..25c359e 100644
--- a/crates/edgen_server/src/llm.rs
+++ b/crates/edgen_server/src/llm.rs
@@ -59,6 +59,20 @@ pub async fn chat_completion_stream(
     ))
 }
 
+pub async fn embeddings(
+    model: Model,
+    input: Vec<String>,
+) -> Result<Vec<Vec<f32>>, LLMEndpointError> {
+    ENDPOINT
+        .embeddings(
+            model
+                .file_path()
+                .map_err(move |e| LLMEndpointError::Load(e.to_string()))?,
+            input,
+        )
+        .await
+}
+
 pub async fn reset_environment() {
     ENDPOINT.reset()
 }
diff --git a/crates/edgen_server/src/openai_shim.rs b/crates/edgen_server/src/openai_shim.rs
index b7110e1..c8a6ec7 100644
--- a/crates/edgen_server/src/openai_shim.rs
+++ b/crates/edgen_server/src/openai_shim.rs
@@ -26,7 +26,6 @@ use axum::response::{IntoResponse, Response, Sse};
 use axum::Json;
 use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
 use derive_more::{Deref, DerefMut, From};
-use edgen_core::llm::{CompletionArgs, LLMEndpointError};
 use either::Either;
 use futures::{Stream, StreamExt, TryStream};
 use serde_derive::{Deserialize, Serialize};
@@ -37,9 +36,11 @@ use tracing::error;
 use utoipa::ToSchema;
 use uuid::Uuid;
 
+use edgen_core::llm::{CompletionArgs, LLMEndpointError};
 use edgen_core::settings::SETTINGS;
 use edgen_core::whisper::WhisperEndpointError;
 
+use crate::llm::embeddings;
 use crate::model::{Model, ModelError, ModelKind};
 
 /// The plaintext or image content of a [`ChatMessage`] within a [`CreateChatCompletionRequest`].
@@ -699,6 +700,172 @@ pub async fn chat_completions(
     Ok(response)
 }
 
+/// A request to generate embeddings for one or more pieces of text.
+///
+/// An `axum` handler, [`create_embeddings`][create_embeddings], is provided to handle this
+/// request.
+///
+/// See [the documentation for creating embeddings][openai] for more details.
+///
+/// [create_embeddings]: fn.create_embeddings.html
+/// [openai]: https://platform.openai.com/docs/api-reference/embeddings/create
+#[derive(Serialize, Deserialize, ToSchema)]
+pub struct CreateEmbeddingsRequest<'a> {
+    /// The text input to embed, as either a single string or an array of strings.
+    #[serde(with = "either::serde_untagged")]
+    #[schema(value_type = String)]
+    pub input: Either<Cow<'a, str>, Vec<Cow<'a, str>>>,
+
+    /// ID of the model to use.
+    #[schema(value_type = String)]
+    pub model: Cow<'a, str>,
+
+    /// The format to return the embeddings in. Can be either `float` or `base64`.
+    #[schema(value_type = String)]
+    pub encoding_format: Option<Cow<'a, str>>,
+
+    /// The number of dimensions the resulting output embeddings should have. Only supported by
+    /// some models.
+    pub dimensions: Option<usize>,
+}
+
+/// The return type of [`create_embeddings`].
+#[derive(Serialize, Deserialize, ToSchema)]
+pub struct EmbeddingsResponse {
+    /// Always `"list"`.
+    pub object: String,
+
+    /// The generated embeddings.
+    pub embeddings: Vec<Embedding>,
+
+    /// The model used for generation.
+    pub model: String,
+
+    /// The usage statistics of the request.
+    pub usage: EmbeddingsUsage,
+}
+
+/// Represents an embedding vector returned by the embeddings endpoint.
+///
+/// See [the documentation for the embedding object][openai] for more details.
+///
+/// [openai]: https://platform.openai.com/docs/api-reference/embeddings/object
+#[derive(Serialize, Deserialize, ToSchema)]
+pub struct Embedding {
+    /// Always `"embedding"`.
+    pub object: String,
+
+    /// The embedding vector, which is a list of floats. The length of the vector depends on the
+    /// model.
+    pub embedding: Vec<f32>,
+
+    /// The index of the embedding in the list of embeddings.
+    pub index: usize,
+}
+
+/// The usage statistics of the request.
+#[derive(Serialize, Deserialize, ToSchema)]
+pub struct EmbeddingsUsage {
+    /// The number of tokens in the input text.
+    pub prompt_tokens: usize,
+
+    /// The total number of tokens consumed by the request.
+    pub total_tokens: usize,
+}
+
+// TODO change to use a dedicated error type, or make a common error type
+/// POST `/v1/embeddings`: generates embeddings for the provided text.
+///
+/// See [the original OpenAI API specification][openai], which this endpoint is compatible with.
+///
+/// [openai]: https://platform.openai.com/docs/api-reference/embeddings/create
+///
+/// On failure, may raise a `500 Internal Server Error` with a JSON-encoded [`ChatCompletionError`]
+/// to the peer.
+#[utoipa::path(
+post,
+path = "/embeddings",
+request_body = CreateEmbeddingsRequest,
+responses(
+(status = 200, description = "OK", body = EmbeddingsResponse),
+(status = 500, description = "unexpected internal server error", body = ChatCompletionError)
+),
+)]
+pub async fn create_embeddings(
+    Json(req): Json<CreateEmbeddingsRequest<'_>>,
+) -> Result<impl IntoResponse, ChatCompletionError> {
+    // For MVP1, the model string in the request is *always* ignored.
+    let model_name = SETTINGS
+        .read()
+        .await
+        .read()
+        .await
+        .embeddings_model_name
+        .trim()
+        .to_string();
+    let repo = SETTINGS
+        .read()
+        .await
+        .read()
+        .await
+        .embeddings_model_repo
+        .trim()
+        .to_string();
+    let dir = SETTINGS
+        .read()
+        .await
+        .read()
+        .await
+        .embeddings_models_dir
+        .trim()
+        .to_string();
+
+    if model_name.is_empty() {
+        return Err(ChatCompletionError::ProhibitedName {
+            model_name,
+            reason: Cow::Borrowed("Empty model name in config"),
+        });
+    }
+    if dir.is_empty() {
+        return Err(ChatCompletionError::ProhibitedName {
+            model_name: dir,
+            reason: Cow::Borrowed("Empty model directory in config"),
+        });
+    }
+
+    let mut model = Model::new(ModelKind::LLM, &model_name, &repo, &PathBuf::from(&dir));
+
+    model
+        .preload()
+        .await
+        .map_err(move |_| ChatCompletionError::NoSuchModel {
+            model_name: model_name.to_string(),
+        })?;
+
+    let input = req.input.either(
+        move |s| vec![s.to_string()],
+        move |v| v.iter().map(move |s| s.to_string()).collect(),
+    );
+    let mut res = embeddings(model, input).await?;
+
+    Ok(Json(EmbeddingsResponse {
+        object: "list".to_string(),
+        embeddings: res
+            .drain(..)
+            .enumerate()
+            .map(move |(index, embedding)| Embedding {
+                object: "embedding".to_string(),
+                embedding,
+                index,
+            })
+            .collect(),
+        model: req.model.to_string(),
+        usage: EmbeddingsUsage {
+            prompt_tokens: 0,
+            total_tokens: 0,
+        },
+    }))
+}
+
 /// A request to transcribe an audio file into text in either the specified language, or whichever
 /// language is automatically detected, if none is specified.
 ///
@@ -761,7 +928,7 @@ pub struct TranscriptionResponse {
 
     /// The [`Uuid`] of a newly created session, present only if `create_session` in
     /// [`CreateTranscriptionRequest`] is set to `true`. This additional member is **not normative**
-    /// with OpenAI's specification, as it is intended for **Edgen** specific functinality.
+    /// with OpenAI's specification, as it is intended for **Edgen** specific functionality.
     #[serde(skip_serializing_if = "Option::is_none")]
     pub session: Option<Uuid>,
 }
@@ -772,7 +939,7 @@ pub struct TranscriptionResponse {
 ///
 /// [openai]: https://platform.openai.com/docs/api-reference/audio/createTranscription
 ///
-/// On failure, may raise a `500 Internal Server Error` with a JSON-encoded [`WhisperEndpointError`]
+/// On failure, may raise a `500 Internal Server Error` with a JSON-encoded [`TranscriptionError`]
 /// to the peer.
 #[utoipa::path(
 post,
diff --git a/crates/edgen_server/src/routes.rs b/crates/edgen_server/src/routes.rs
index d2384a8..f9f1fe5 100644
--- a/crates/edgen_server/src/routes.rs
+++ b/crates/edgen_server/src/routes.rs
@@ -31,6 +31,8 @@ pub fn routes() -> Router {
         // -- AI endpoints -----------------------------------------------------
         // ---- Chat -----------------------------------------------------------
         .route("/v1/chat/completions", post(openai_shim::chat_completions))
+        // ---- Embeddings -------------------------------------------------------
+        .route("/v1/embeddings", post(openai_shim::create_embeddings))
         // ---- Audio ----------------------------------------------------------
         .route(
             "/v1/audio/transcriptions",
diff --git a/docs/src/app/api-reference/embeddings/page.mdx b/docs/src/app/api-reference/embeddings/page.mdx
new file mode 100644
index 0000000..6f4655c
--- /dev/null
+++ b/docs/src/app/api-reference/embeddings/page.mdx
@@ -0,0 +1,105 @@
+export const metadata = {
+  title: 'Embeddings',
+  description: 'Generate embeddings',
+}
+
+# Embeddings
+
+Generate embeddings from text. {{ className: 'lead' }}
+
+---
+
+## Create embeddings {{ tag: 'POST', label: 'http://localhost:33322/v1/embeddings' }}
+
+<Row>
+  <Col>
+
+    Given one or more pieces of text, generate embedding vectors for them.
+
+    ### Required attributes
+
+    <Properties>
+      <Property name="input" type="string or array">
+        One or multiple pieces of text from which embeddings will be generated. For each piece of text, one embedding is generated.
+      </Property>
+    </Properties>
+
+    <Properties>
+      <Property name="model" type="string">
+        The model used to generate the embeddings. **WARNING**: currently, this attribute is **ignored** and the **default model is used**.
+      </Property>
+    </Properties>
+
+  </Col>
+  <Col sticky>
+
+    <CodeGroup title="Request" tag="POST" label="/v1/embeddings">
+    ```bash {{ title: 'cURL' }}
+    curl http://localhost:33322/v1/embeddings \
+      -H "Content-Type: application/json" \
+      -H "Authorization: Bearer no-key-required" \
+      -d '{
+        "model": "default",
+        "input": "Hello World!"
+      }'
+    ```
+
+    ```python
+    from edgen import Edgen
+    client = Edgen()
+
+    embeddings = client.embeddings.create(
+      model="default",
+      input="Who is John Doe?"
+    )
+
+    for item in embeddings.data:
+      print(item.embedding)
+    ```
+
+    ```ts
+    import Edgen from "edgen";
+
+    const client = new Edgen();
+
+    async function main() {
+      const embeddings = await client.embeddings.create({
+        model: "default",
+        input: "Who is Foo Bar?"
+      });
+
+      for (const item of embeddings.data) {
+        console.log(item.embedding);
+      }
+    }
+
+    main();
+    ```
+
+    </CodeGroup>
+
+    ```json {{ title: 'Response' }}
+    {
+      "object": "list",
+      "data": [
+        {
+          "object": "embedding",
+          "embedding": [
+            0.0023064255,
+            -0.009327292,
+            ....
+            -0.0028842222,
+          ],
+          "index": 0
+        }
+      ],
+      "model": "default",
+      "usage": {
+        "prompt_tokens": 8,
+        "total_tokens": 8
+      }
+    }
+    ```
+
+  </Col>
+</Row>
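
To exercise the new route end to end, something like the following client sketch works. It is illustrative only and not part of the patch: it assumes a running Edgen server on the default `http://127.0.0.1:33322` URI from `SettingsParams`, and pulls in `tokio`, `reqwest` (with the `json` feature), `serde` (with `derive`), and `serde_json` as client-side dependencies. The response structs mirror the field names that `EmbeddingsResponse` and `Embedding` in `openai_shim.rs` serialize to.

```rust
// Illustrative client sketch (not part of the patch). Assumes a running Edgen
// server on the default URI, plus `tokio`, `reqwest` (json feature), `serde`
// (derive feature), and `serde_json` as dependencies.
use serde::Deserialize;

// Mirrors the fields serialized by `Embedding` in `openai_shim.rs`; serde
// ignores the remaining fields by default.
#[derive(Deserialize)]
struct Embedding {
    embedding: Vec<f32>,
    index: usize,
}

// Mirrors the `embeddings` field serialized by `EmbeddingsResponse`.
#[derive(Deserialize)]
struct EmbeddingsResponse {
    embeddings: Vec<Embedding>,
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let response: EmbeddingsResponse = reqwest::Client::new()
        .post("http://127.0.0.1:33322/v1/embeddings")
        // The `model` field is currently ignored; the configured default
        // embeddings model is used instead.
        .json(&serde_json::json!({
            "model": "default",
            "input": ["Hello World!", "Who is John Doe?"]
        }))
        .send()
        .await?
        .json()
        .await?;

    // One vector per input string, in the same order as the inputs.
    for e in &response.embeddings {
        println!("input #{}: {} dimensions", e.index, e.embedding.len());
    }
    Ok(())
}
```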