Add MistralAI API support
add MistralAI support
svilupp authored Dec 13, 2023
2 parents 7004044 + 29a299e commit 2f8e7b4
Showing 10 changed files with 358 additions and 38 deletions.
10 changes: 8 additions & 2 deletions CHANGELOG.md
@@ -7,11 +7,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]

### Added
- Improved AICode parsing and error handling (eg, capture more REPL prompts, detect parsing errors earlier, parse more code fence types), including the option to remove unsafe code (eg, `Pkg.add("SomePkg")`) with `AICode(msg; skip_unsafe=true, verbose=true)`
- Added new prompt templates: `JuliaRecapTask`, `JuliaRecapCoTTask`, `JuliaExpertTestCode` and updated `JuliaExpertCoTTask` to be more robust against early stopping for smaller OSS models

### Fixed

## [0.4.0]

### Added
- Improved AICode parsing and error handling (eg, capture more REPL prompts, detect parsing errors earlier, parse more code fence types), including the option to remove unsafe code (eg, `Pkg.add("SomePkg")`) with `AICode(msg; skip_unsafe=true, verbose=true)`
- Added new prompt templates: `JuliaRecapTask`, `JuliaRecapCoTTask`, `JuliaExpertTestCode` and updated `JuliaExpertCoTTask` to be more robust against early stopping for smaller OSS models
- Added support for the MistralAI API via `MistralOpenAISchema()`. All their standard models have been registered, so you should be able to just use `model="mistral-tiny"` in your `aigenerate` calls without any further changes. Remember to either provide `api_kwargs.api_key` or ensure you have the ENV variable `MISTRALAI_API_KEY` set.
- Added support for any OpenAI-compatible API via `schema=CustomOpenAISchema()`. All you have to do is to provide your `api_key` and `url` (base URL of the API) in the `api_kwargs` keyword argument. This option is useful if you use [Perplexity.ai](https://docs.perplexity.ai/), [Fireworks.ai](https://app.fireworks.ai/), or any other similar services.

## [0.3.0]

### Added
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "PromptingTools"
uuid = "670122d1-24a8-4d70-bfce-740807c42192"
authors = ["J S @svilupp and contributors"]
version = "0.4.0-DEV"
version = "0.4.0"

[deps]
Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
35 changes: 35 additions & 0 deletions README.md
@@ -79,6 +79,7 @@ For more practical examples, see the `examples/` folder and the [Advanced Exampl
- [Data Extraction](#data-extraction)
- [OCR and Image Comprehension](#ocr-and-image-comprehension)
- [Using Ollama models](#using-ollama-models)
- [Using MistralAI API and other OpenAI-compatible APIs](#using-mistralai-api-and-other-openai-compatible-apis)
- [More Examples](#more-examples)
- [Package Interface](#package-interface)
- [Frequently Asked Questions](#frequently-asked-questions)
@@ -395,6 +396,38 @@ msg.content # 4096×2 Matrix{Float64}:

If you're getting errors, check that Ollama is running - see the [Setup Guide for Ollama](#setup-guide-for-ollama) section below.

### Using MistralAI API and other OpenAI-compatible APIs

Mistral models have long been dominating the open-source space. They are now available via their API, so you can use them with PromptingTools.jl!

```julia
msg = aigenerate("Say hi!"; model="mistral-tiny")
```

It all just works, because we have registered the models in the `PromptingTools.MODEL_REGISTRY`! There are currently 4 models available: `mistral-tiny`, `mistral-small`, `mistral-medium`, `mistral-embed`.

Under the hood, we use a dedicated schema `MistralOpenAISchema` that leverages most of the OpenAI-specific code base, so you can always provide that explicitly as the first argument:

```julia
const PT = PromptingTools
msg = aigenerate(PT.MistralOpenAISchema(), "Say Hi!"; model="mistral-tiny", api_key=ENV["MISTRALAI_API_KEY"])
```
As you can see, we can load your API key either from the ENV or via the Preferences.jl mechanism (see `?PREFERENCES` for more information).

But MistralAI is not the only provider! There are many other exciting ones, eg, [Perplexity.ai](https://docs.perplexity.ai/), [Fireworks.ai](https://app.fireworks.ai/).
As long as they are compatible with the OpenAI API (eg, sending `messages` with `role` and `content` keys), you can use them with PromptingTools.jl by using `schema = CustomOpenAISchema()`:

```julia
# Set your API key and the necessary base URL for the API
api_key = "..."
prompt = "Say hi!"
msg = aigenerate(PT.CustomOpenAISchema(), prompt; model="my_model", api_key, api_kwargs=(; url="http://localhost:8081"))
```

As you can see, it also works for any local models that you might have running on your computer!

Note: At the moment, we only support the `aigenerate` and `aiembed` functions for MistralAI and other OpenAI-compatible APIs. We plan to extend support in the future.
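For example, embeddings with Mistral's dedicated model should work the same way as chat — a minimal sketch (assumes the `MISTRALAI_API_KEY` ENV variable is set; the output dimension is per Mistral's own docs):

```julia
using PromptingTools

# Embed a single document with Mistral's embedding model.
# The model is registered in MODEL_REGISTRY, so no schema needs to be passed;
# assumes ENV["MISTRALAI_API_KEY"] is set.
msg = aiembed("Say hi!"; model = "mistral-embed")
msg.content # embedding vector (1024 floats for mistral-embed, per Mistral's docs)
```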

### More Examples

TBU...
@@ -529,6 +562,8 @@ Resources:

### Configuring the Environment Variable for API Key

This is a guide for OpenAI's API key, but it works for any other API key you might need (eg, `MISTRALAI_API_KEY` for MistralAI API).

To use the OpenAI API with PromptingTools.jl, set your API key as an environment variable:

```julia
36 changes: 35 additions & 1 deletion docs/src/examples/readme_examples.md
@@ -286,4 +286,38 @@ msg = aiembed(schema, ["Embed me", "Embed me"]; model="openhermes2.5-mistral")
msg.content # 4096×2 Matrix{Float64}:
```

If you're getting errors, check that Ollama is running - see the [Setup Guide for Ollama](#setup-guide-for-ollama) section below.

## Using MistralAI API and other OpenAI-compatible APIs

Mistral models have long been dominating the open-source space. They are now available via their API, so you can use them with PromptingTools.jl!

```julia
msg = aigenerate("Say hi!"; model="mistral-tiny")
# [ Info: Tokens: 114 @ Cost: $0.0 in 0.9 seconds
# AIMessage("Hello there! I'm here to help answer any questions you might have, or assist you with tasks to the best of my abilities. How can I be of service to you today? If you have a specific question, feel free to ask and I'll do my best to provide accurate and helpful information. If you're looking for general assistance, I can help you find resources or information on a variety of topics. Let me know how I can help.")
```

It all just works, because we have registered the models in the `PromptingTools.MODEL_REGISTRY`! There are currently 4 models available: `mistral-tiny`, `mistral-small`, `mistral-medium`, `mistral-embed`.

Under the hood, we use a dedicated schema `MistralOpenAISchema` that leverages most of the OpenAI-specific code base, so you can always provide that explicitly as the first argument:

```julia
const PT = PromptingTools
msg = aigenerate(PT.MistralOpenAISchema(), "Say Hi!"; model="mistral-tiny", api_key=ENV["MISTRALAI_API_KEY"])
```
As you can see, we can load your API key either from the ENV or via the Preferences.jl mechanism (see `?PREFERENCES` for more information).

But MistralAI is not the only provider! There are many other exciting ones, eg, [Perplexity.ai](https://docs.perplexity.ai/), [Fireworks.ai](https://app.fireworks.ai/).
As long as they are compatible with the OpenAI API (eg, sending `messages` with `role` and `content` keys), you can use them with PromptingTools.jl by using `schema = CustomOpenAISchema()`:

```julia
# Set your API key and the necessary base URL for the API
api_key = "..."
prompt = "Say hi!"
msg = aigenerate(PT.CustomOpenAISchema(), prompt; model="my_model", api_key, api_kwargs=(; url="http://localhost:8081"))
```

As you can see, it also works for any local models that you might have running on your computer!

Note: At the moment, we only support the `aigenerate` and `aiembed` functions for MistralAI and other OpenAI-compatible APIs. We plan to extend support in the future.
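Since this change also wires up `create_embeddings` for the custom schemas, an analogous `aiembed` call should work against any OpenAI-compatible server — a hedged sketch (the model name and URL are placeholders for your local setup):

```julia
using PromptingTools
const PT = PromptingTools

api_key = "..." # many local servers ignore the key, but the argument is still required
msg = aiembed(PT.CustomOpenAISchema(),
    "Embed me";
    model = "my_model",                              # placeholder: whatever your server serves
    api_key,
    api_kwargs = (; url = "http://localhost:8081"))  # base URL; requests go to <url>/embeddings
msg.content
```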
2 changes: 2 additions & 0 deletions docs/src/frequently_asked_questions.md
@@ -70,6 +70,8 @@ Resources:

## Configuring the Environment Variable for API Key

This is a guide for OpenAI's API key, but it works for any other API key you might need (eg, `MISTRALAI_API_KEY` for MistralAI API).

To use the OpenAI API with PromptingTools.jl, set your API key as an environment variable:

```julia
42 changes: 42 additions & 0 deletions src/llm_interface.jl
@@ -44,6 +44,48 @@ struct OpenAISchema <: AbstractOpenAISchema end
inputs::Any = nothing
end

"""
CustomOpenAISchema
CustomOpenAISchema() allows the user to call any OpenAI-compatible API.
All you need to do is pass this schema as the first argument and provide the base URL of the API to call (`api_kwargs.url`).
# Example
Assumes that we have a local server running at `http://localhost:8081`:
```julia
api_key = "..."
prompt = "Say hi!"
msg = aigenerate(CustomOpenAISchema(), prompt; model="my_model", api_key, api_kwargs=(; url="http://localhost:8081"))
```
"""
struct CustomOpenAISchema <: AbstractOpenAISchema end

"""
MistralOpenAISchema
MistralOpenAISchema() allows the user to call the MistralAI API, known for the Mistral and Mixtral models.
It's a flavor of CustomOpenAISchema() with the url preset to `https://api.mistral.ai`.
Most models have been registered, so you don't even have to specify the schema.
# Example
Let's call `mistral-tiny` model:
```julia
api_key = "..." # can be set via ENV["MISTRALAI_API_KEY"] or via our preference system
msg = aigenerate("Say hi!"; model="mistral-tiny", api_key)
```
See `?PREFERENCES` for more details on how to set your API key permanently.
"""
struct MistralOpenAISchema <: AbstractOpenAISchema end

abstract type AbstractChatMLSchema <: AbstractPromptSchema end
"""
ChatMLSchema is used by many open-source chatbots, by OpenAI models (under the hood) and by several models and interfaces (eg, Ollama, vLLM)
175 changes: 145 additions & 30 deletions src/llm_openai.jl
@@ -56,6 +56,151 @@ function render(schema::AbstractOpenAISchema,
return conversation
end

## OpenAI.jl back-end
## Types
# "Providers" are a way to use other APIs that are compatible with OpenAI API specs, eg, Azure and many more
# Define our sub-type to distinguish it from other OpenAI.jl providers
abstract type AbstractCustomProvider <: OpenAI.AbstractOpenAIProvider end
Base.@kwdef struct CustomProvider <: AbstractCustomProvider
api_key::String = ""
base_url::String = "http://localhost:8080"
api_version::String = ""
end
function OpenAI.build_url(provider::AbstractCustomProvider, api::AbstractString)
string(provider.base_url, "/", api)
end
function OpenAI.auth_header(provider::AbstractCustomProvider, api_key::AbstractString)
OpenAI.auth_header(OpenAI.OpenAIProvider(provider.api_key,
provider.base_url,
provider.api_version),
api_key)
end
## Extend OpenAI create_chat to allow for testing/debugging
# Default passthrough
function OpenAI.create_chat(schema::AbstractOpenAISchema,
api_key::AbstractString,
model::AbstractString,
conversation;
kwargs...)
OpenAI.create_chat(api_key, model, conversation; kwargs...)
end

# Overload for testing/debugging
function OpenAI.create_chat(schema::TestEchoOpenAISchema, api_key::AbstractString,
model::AbstractString,
conversation; kwargs...)
schema.model_id = model
schema.inputs = conversation
return schema
end

"""
OpenAI.create_chat(schema::CustomOpenAISchema,
api_key::AbstractString,
model::AbstractString,
conversation;
url::String="http://localhost:8080",
kwargs...)
Dispatches to the OpenAI.create_chat function for any OpenAI-compatible API.
It expects a `url` keyword argument. Provide it to the `aigenerate` function via `api_kwargs=(; url="my-url")`.
It will forward your query to the "chat/completions" endpoint of the base URL that you provided (=`url`).
"""
function OpenAI.create_chat(schema::CustomOpenAISchema,
api_key::AbstractString,
model::AbstractString,
conversation;
url::String = "http://localhost:8080",
kwargs...)
# Build the corresponding provider object
# Create chat will automatically pass our data to endpoint `/chat/completions`
provider = CustomProvider(; api_key, base_url = url)
OpenAI.create_chat(provider, model, conversation; kwargs...)
end

"""
OpenAI.create_chat(schema::MistralOpenAISchema,
api_key::AbstractString,
model::AbstractString,
conversation;
url::String="https://api.mistral.ai/v1",
kwargs...)
Dispatches to the OpenAI.create_chat function, but with the MistralAI API parameters.
It tries to access the `MISTRALAI_API_KEY` ENV variable, but you can also provide it via the `api_key` keyword argument.
"""
function OpenAI.create_chat(schema::MistralOpenAISchema,
api_key::AbstractString,
model::AbstractString,
conversation;
url::String = "https://api.mistral.ai/v1",
kwargs...)
# Build the corresponding provider object
# try to override provided api_key because the default is OpenAI key
provider = CustomProvider(;
api_key = isempty(MISTRALAI_API_KEY) ? api_key : MISTRALAI_API_KEY,
base_url = url)
OpenAI.create_chat(provider, model, conversation; kwargs...)
end

# Extend OpenAI create_embeddings to allow for testing
function OpenAI.create_embeddings(schema::AbstractOpenAISchema,
api_key::AbstractString,
docs,
model::AbstractString;
kwargs...)
OpenAI.create_embeddings(api_key, docs, model; kwargs...)
end
function OpenAI.create_embeddings(schema::TestEchoOpenAISchema, api_key::AbstractString,
docs,
model::AbstractString; kwargs...)
schema.model_id = model
schema.inputs = docs
return schema
end
function OpenAI.create_embeddings(schema::CustomOpenAISchema,
api_key::AbstractString,
docs,
model::AbstractString;
url::String = "http://localhost:8080",
kwargs...)
# Build the corresponding provider object
# Create chat will automatically pass our data to endpoint `/embeddings`
provider = CustomProvider(; api_key, base_url = url)
OpenAI.create_embeddings(provider, docs, model; kwargs...)
end
function OpenAI.create_embeddings(schema::MistralOpenAISchema,
api_key::AbstractString,
docs,
model::AbstractString;
url::String = "https://api.mistral.ai/v1",
kwargs...)
# Build the corresponding provider object
# try to override provided api_key because the default is OpenAI key
provider = CustomProvider(;
api_key = isempty(MISTRALAI_API_KEY) ? api_key : MISTRALAI_API_KEY,
base_url = url)
OpenAI.create_embeddings(provider, docs, model; kwargs...)
end

## Temporary fix -- it will be moved upstream
function OpenAI.create_embeddings(provider::AbstractCustomProvider,
input,
model_id::String = OpenAI.DEFAULT_EMBEDDING_MODEL_ID;
http_kwargs::NamedTuple = NamedTuple(),
kwargs...)
return OpenAI.openai_request("embeddings",
provider;
method = "POST",
http_kwargs = http_kwargs,
model = model_id,
input,
kwargs...)
end

## User-Facing API
"""
aigenerate(prompt_schema::AbstractOpenAISchema, prompt::ALLOWED_PROMPT_TYPE;
@@ -170,21 +315,6 @@ function aigenerate(prompt_schema::AbstractOpenAISchema, prompt::ALLOWED_PROMPT_

return output
end
# Extend OpenAI create_chat to allow for testing/debugging
function OpenAI.create_chat(schema::AbstractOpenAISchema,
api_key::AbstractString,
model::AbstractString,
conversation;
kwargs...)
OpenAI.create_chat(api_key, model, conversation; kwargs...)
end
function OpenAI.create_chat(schema::TestEchoOpenAISchema, api_key::AbstractString,
model::AbstractString,
conversation; kwargs...)
schema.model_id = model
schema.inputs = conversation
return schema
end

"""
aiembed(prompt_schema::AbstractOpenAISchema,
@@ -268,21 +398,6 @@ function aiembed(prompt_schema::AbstractOpenAISchema,

return msg
end
# Extend OpenAI create_embeddings to allow for testing
function OpenAI.create_embeddings(schema::AbstractOpenAISchema,
api_key::AbstractString,
docs,
model::AbstractString;
kwargs...)
OpenAI.create_embeddings(api_key, docs, model; kwargs...)
end
function OpenAI.create_embeddings(schema::TestEchoOpenAISchema, api_key::AbstractString,
docs,
model::AbstractString; kwargs...)
schema.model_id = model
schema.inputs = docs
return schema
end

"""
aiclassify(prompt_schema::AbstractOpenAISchema, prompt::ALLOWED_PROMPT_TYPE;