Add more API Providers (#80)
svilupp authored Feb 23, 2024
1 parent 463a830 commit c811407
Showing 6 changed files with 197 additions and 19 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `run_id` - ID of the AI API call
- `sample_id` - ID of the sample in the batch if you requested multiple completions, otherwise `sample_id==nothing` (they will have the same `run_id`)
- `finish_reason` - the reason why the AI stopped generating the sequence (eg, "stop", "length") to provide more visibility for the user
- Support for Fireworks.ai and Together.ai providers for fast and easy access to open-source models. Requires environment variables `FIREWORKS_API_KEY` and `TOGETHER_API_KEY` to be set, respectively. See the `?FireworksOpenAISchema` and `?TogetherOpenAISchema` for more information.

### Fixed

64 changes: 63 additions & 1 deletion docs/src/examples/working_with_custom_apis.md
@@ -97,4 +97,66 @@ ai"Say hi to the llama!"dllama

You can use `aiembed` as well.

Find more information [here](https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html).

## Using Together.ai

You can also use the Together.ai API with PromptingTools.jl.
It requires you to set the environment variable `TOGETHER_API_KEY`.
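For the current session, you can set it directly in Julia (the value below is a placeholder):

```julia
# Placeholder value -- substitute your actual Together.ai API key
ENV["TOGETHER_API_KEY"] = "<your-together-api-key>"
```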

The corresponding schema is `TogetherOpenAISchema`, but we have registered one model for you, so you can use it as usual.
Alias "tmixtral" (T for Together.ai and mixtral for the model name) is already set for you.

```julia
msg = aigenerate("Say hi"; model="tmixtral")
## [ Info: Tokens: 87 @ Cost: \$0.0001 in 5.1 seconds
## AIMessage("Hello! I'm here to help you. Is there something specific you'd like to know or discuss? I can provide information on a wide range of topics, assist with tasks, and even engage in a friendly conversation. Let me know how I can best assist you today.")
```

To embed text, use `aiembed`:

```julia
aiembed(PT.TogetherOpenAISchema(), "embed me"; model="BAAI/bge-large-en-v1.5")
```
Note: You can register the model with `PT.register_model!` and use it as usual.
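For example, a minimal sketch of registering the embedding model used above so the schema argument can be dropped (the cost figure is a placeholder, not official pricing):

```julia
using PromptingTools
const PT = PromptingTools

# Hypothetical registration; check Together.ai's pricing page for real token costs
PT.register_model!(;
    name = "BAAI/bge-large-en-v1.5",
    schema = PT.TogetherOpenAISchema(),
    cost_of_token_prompt = 1e-7, # placeholder, not the official price
    description = "BGE Large v1.5 embedding model, hosted by Together.ai.")

# Afterwards, the schema argument is no longer needed
aiembed("embed me"; model = "BAAI/bge-large-en-v1.5")
```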

## Using Fireworks.ai

You can also use the Fireworks.ai API with PromptingTools.jl.
It requires you to set the environment variable `FIREWORKS_API_KEY`.
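As with Together.ai, you can set it for the session (placeholder value; to persist it, you could, eg, add the line to `~/.julia/config/startup.jl`):

```julia
# Placeholder value -- substitute your actual Fireworks.ai API key
ENV["FIREWORKS_API_KEY"] = "<your-fireworks-api-key>"
```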

The corresponding schema is `FireworksOpenAISchema`, but we have registered one model for you, so you can use it as usual.
Alias "fmixtral" (F for Fireworks.ai and mixtral for the model name) is already set for you.

```julia
msg = aigenerate("Say hi"; model="fmixtral")
## [ Info: Tokens: 78 @ Cost: \$0.0001 in 0.9 seconds
## AIMessage("Hello! I'm glad you're here. I'm here to help answer any questions you have to the best of my ability. Is there something specific you'd like to know or discuss? I can assist with a wide range of topics, so feel free to ask me anything!")
```

In addition, at the time of writing (23rd Feb 2024), Fireworks provides access to its new _function calling_ model (a fine-tuned Mixtral) **for free**.

Try it with `aiextract` for structured extraction (model is aliased as `firefunction`):

```julia
"""
Extract the food from the sentence. Extract any provided adjectives for the food as well.
Example: "I am eating a crunchy bread." -> Food("bread", ["crunchy"])
"""
struct Food
    name::String
    adjectives::Union{Nothing, Vector{String}}
end
prompt = "I just ate a delicious and juicy apple."
msg = aiextract(prompt; return_type=Food, model="firefunction")
msg.content
# Output: Food("apple", ["delicious", "juicy"])
```

To embed text, use `aiembed`:

```julia
aiembed(PT.FireworksOpenAISchema(), "embed me"; model="nomic-ai/nomic-embed-text-v1.5")
```
Note: You can register the model with `PT.register_model!` and use it as usual.
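Likewise, a minimal sketch of registering the Fireworks embedding model used above (the description is illustrative; add token costs if you want cost tracking):

```julia
# Hypothetical registration; verify details against Fireworks.ai's docs
PT.register_model!(;
    name = "nomic-ai/nomic-embed-text-v1.5",
    schema = PT.FireworksOpenAISchema(),
    description = "Nomic Embed v1.5, hosted by Fireworks.ai.")

# Afterwards, the schema argument is no longer needed
aiembed("embed me"; model = "nomic-ai/nomic-embed-text-v1.5")
```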
2 changes: 1 addition & 1 deletion docs/src/frequently_asked_questions.md
@@ -166,7 +166,7 @@ There are three ways you can customize your workflows (especially when you u
2) Register your model and its associated schema (`PT.register_model!(; name="123", schema=PT.OllamaSchema())`). You won't have to specify the schema anymore, only the model name. See [Working with Ollama](#working-with-ollama) for more information.
3) Override your default model (`PT.MODEL_CHAT`) and schema (`PT.PROMPT_SCHEMA`). It can be done persistently with Preferences, eg, `PT.set_preferences!("PROMPT_SCHEMA" => "OllamaSchema", "MODEL_CHAT"=>"llama2")`. A minimal sketch of both approaches follows below.
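For instance (the model name is illustrative):

```julia
using PromptingTools
const PT = PromptingTools

# 2) Register once; afterwards the model name alone selects the right schema
PT.register_model!(; name = "llama2", schema = PT.OllamaSchema())

# 3) Or override the defaults persistently (a Julia restart may be needed)
PT.set_preferences!("PROMPT_SCHEMA" => "OllamaSchema", "MODEL_CHAT" => "llama2")
```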

## How to have Multi-turn Conversations?

Let's say you would like to respond to a model's answer. How do you do that?

30 changes: 30 additions & 0 deletions src/llm_interface.jl
@@ -146,6 +146,36 @@ Requires two environment variables to be set:
"""
struct DatabricksOpenAISchema <: AbstractOpenAISchema end

"""
FireworksOpenAISchema
Schema to call the [Fireworks.ai](https://fireworks.ai/) API.
Links:
- [Get your API key](https://fireworks.ai/api-keys)
- [API Reference](https://readme.fireworks.ai/reference/createchatcompletion)
- [Available models](https://fireworks.ai/models)
Requires one environment variables to be set:
- `FIREWORKS_API_KEY`: Your API key
"""
struct FireworksOpenAISchema <: AbstractOpenAISchema end

"""
TogetherOpenAISchema
Schema to call the [Together.ai](https://www.together.ai/) API.
Links:
- [Get your API key](https://api.together.xyz/settings/api-keys)
- [API Reference](https://docs.together.ai/docs/openai-api-compatibility)
- [Available models](https://docs.together.ai/docs/inference-models)
Requires one environment variables to be set:
- `TOGETHER_API_KEY`: Your API key
"""
struct TogetherOpenAISchema <: AbstractOpenAISchema end

abstract type AbstractOllamaSchema <: AbstractPromptSchema end

"""
73 changes: 62 additions & 11 deletions src/llm_openai.jl
@@ -70,7 +70,8 @@ function OpenAI.build_url(provider::AbstractCustomProvider, api::AbstractString)
string(provider.base_url, "/", api)
end
function OpenAI.auth_header(provider::AbstractCustomProvider, api_key::AbstractString)
OpenAI.auth_header(
OpenAI.OpenAIProvider(provider.api_key,
provider.base_url,
provider.api_version),
api_key)
@@ -165,6 +166,32 @@ function OpenAI.create_chat(schema::MistralOpenAISchema,
base_url = url)
OpenAI.create_chat(provider, model, conversation; kwargs...)
end
function OpenAI.create_chat(schema::FireworksOpenAISchema,
api_key::AbstractString,
model::AbstractString,
conversation;
url::String = "https://api.fireworks.ai/inference/v1",
kwargs...)
# Build the corresponding provider object
# Try to override the provided api_key, because the default is the OpenAI key
provider = CustomProvider(;
api_key = isempty(FIREWORKS_API_KEY) ? api_key : FIREWORKS_API_KEY,
base_url = url)
OpenAI.create_chat(provider, model, conversation; kwargs...)
end
function OpenAI.create_chat(schema::TogetherOpenAISchema,
api_key::AbstractString,
model::AbstractString,
conversation;
url::String = "https://api.together.xyz/v1",
kwargs...)
# Build the corresponding provider object
# Try to override the provided api_key, because the default is the OpenAI key
provider = CustomProvider(;
api_key = isempty(TOGETHER_API_KEY) ? api_key : TOGETHER_API_KEY,
base_url = url)
OpenAI.create_chat(provider, model, conversation; kwargs...)
end
function OpenAI.create_chat(schema::DatabricksOpenAISchema,
api_key::AbstractString,
model::AbstractString,
@@ -257,6 +284,28 @@ function OpenAI.create_embeddings(schema::DatabricksOpenAISchema,
input = docs,
kwargs...)
end
function OpenAI.create_embeddings(schema::TogetherOpenAISchema,
api_key::AbstractString,
docs,
model::AbstractString;
url::String = "https://api.together.xyz/v1",
kwargs...)
provider = CustomProvider(;
api_key = isempty(TOGETHER_API_KEY) ? api_key : TOGETHER_API_KEY,
base_url = url)
OpenAI.create_embeddings(provider, docs, model; kwargs...)
end
function OpenAI.create_embeddings(schema::FireworksOpenAISchema,
api_key::AbstractString,
docs,
model::AbstractString;
url::String = "https://api.fireworks.ai/inference/v1",
kwargs...)
provider = CustomProvider(;
api_key = isempty(FIREWORKS_API_KEY) ? api_key : FIREWORKS_API_KEY,
base_url = url)
OpenAI.create_embeddings(provider, docs, model; kwargs...)
end

## Temporary fix -- it will be moved upstream
function OpenAI.create_embeddings(provider::AbstractCustomProvider,
@@ -316,8 +365,8 @@ function response_to_message(schema::AbstractOpenAISchema,
nothing
end
## calculate cost
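## some providers omit the :usage field, so fall back to zero tokens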
tokens_prompt = get(resp.response, :usage, Dict(:prompt_tokens => 0))[:prompt_tokens]
tokens_completion = get(resp.response, :usage, Dict(:completion_tokens => 0))[:completion_tokens]
cost = call_cost(tokens_prompt, tokens_completion, model_id)
## build AIMessage object
msg = MSG(;
@@ -434,7 +483,7 @@ function aigenerate(prompt_schema::AbstractOpenAISchema, prompt::ALLOWED_PROMPT_
run_id = Int(rand(Int32)) # remember one run ID
## extract all message
msgs = [response_to_message(prompt_schema, AIMessage, choice, r;
time, model_id, run_id, sample_id = i)
for (i, choice) in enumerate(r.response[:choices])]
## Order by log probability if available
## bigger is better, keep it last
@@ -537,11 +586,12 @@ function aiembed(prompt_schema::AbstractOpenAISchema,
model_id;
http_kwargs,
api_kwargs...)
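## fall back to zero prompt tokens if the provider omits the :usage field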
tokens_prompt = get(r.response, :usage, Dict(:prompt_tokens => 0))[:prompt_tokens]
msg = DataMessage(;
content = mapreduce(x -> postprocess(x[:embedding]), hcat, r.response[:data]),
status = Int(r.status),
cost = call_cost(tokens_prompt, 0, model_id),
tokens = (tokens_prompt, 0),
elapsed = time)
## Reporting
verbose && @info _report_stats(msg, model_id)
@@ -773,7 +823,8 @@ aiclassify(:JudgeIsItTrue;
function aiclassify(prompt_schema::AbstractOpenAISchema, prompt::ALLOWED_PROMPT_TYPE;
choices::AbstractVector{T} = ["true", "false", "unknown"],
api_kwargs::NamedTuple = NamedTuple(),
kwargs...) where {T <:
Union{AbstractString, Tuple{<:AbstractString, <:AbstractString}}}
## Encode the choices and the corresponding prompt
## TODO: maybe check the model provided as well?
choices_prompt, logit_bias, decode_ids = encode_choices(prompt_schema, choices)
@@ -808,8 +859,8 @@ function response_to_message(schema::AbstractOpenAISchema,
nothing
end
## calculate cost
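## some providers omit the :usage field, so fall back to zero tokens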
tokens_prompt = get(resp.response, :usage, Dict(:prompt_tokens => 0))[:prompt_tokens]
tokens_completion = get(resp.response, :usage, Dict(:completion_tokens => 0))[:completion_tokens]
cost = call_cost(tokens_prompt, tokens_completion, model_id)
# "Safe" parsing of the response - it still fails if JSON is invalid
content = try
@@ -987,7 +1038,7 @@ function aiextract(prompt_schema::AbstractOpenAISchema, prompt::ALLOWED_PROMPT_T
run_id = Int(rand(Int32)) # remember one run ID
## extract all message
msgs = [response_to_message(prompt_schema, DataMessage, choice, r;
return_type, time, model_id, run_id, sample_id = i)
for (i, choice) in enumerate(r.response[:choices])]
## Order by log probability if available
## bigger is better, keep it last
@@ -1144,7 +1195,7 @@ function aiscan(prompt_schema::AbstractOpenAISchema, prompt::ALLOWED_PROMPT_TYPE
run_id = Int(rand(Int32)) # remember one run ID
## extract all message
msgs = [response_to_message(prompt_schema, AIMessage, choice, r;
time, model_id, run_id, sample_id = i)
for (i, choice) in enumerate(r.response[:choices])]
## Order by log probability if available
## bigger is better, keep it last
46 changes: 40 additions & 6 deletions src/user_preferences.jl
@@ -149,6 +149,14 @@ _temp = get(ENV, "GOOGLE_API_KEY", "")
const GOOGLE_API_KEY::String = @load_preference("GOOGLE_API_KEY",
default=_temp);

_temp = get(ENV, "TOGETHER_API_KEY", "")
const TOGETHER_API_KEY::String = @load_preference("TOGETHER_API_KEY",
default=_temp);

_temp = get(ENV, "FIREWORKS_API_KEY", "")
const FIREWORKS_API_KEY::String = @load_preference("FIREWORKS_API_KEY",
default=_temp);

_temp = get(ENV, "LOCAL_SERVER", "")
## Address of the local server
const LOCAL_SERVER::String = @load_preference("LOCAL_SERVER",
@@ -267,7 +275,8 @@ end
### Model Aliases

# global reference MODEL_ALIASES is defined below
aliases = merge(Dict("gpt3" => "gpt-3.5-turbo",
aliases = merge(
Dict("gpt3" => "gpt-3.5-turbo",
"gpt4" => "gpt-4",
"gpt4v" => "gpt-4-vision-preview", # 4v is for "4 vision"
"gpt4t" => "gpt-4-turbo-preview", # 4t is for "4 turbo"
@@ -279,11 +288,17 @@ aliases = merge(Dict("gpt3" => "gpt-3.5-turbo",
"oh25" => "openhermes2.5-mistral",
"starling" => "starling-lm",
"local" => "local-server",
"gemini" => "gemini-pro"),
"gemini" => "gemini-pro",
## f-mixtral -> Fireworks.ai Mixtral
"fmixtral" => "accounts/fireworks/models/mixtral-8x7b-instruct",
"firefunction" => "accounts/fireworks/models/firefunction-v1",
## t-mixtral -> Together.ai Mixtral
"tmixtral" => "mistralai/Mixtral-8x7B-Instruct-v0.1"),
## Load aliases from preferences as well
@load_preference("MODEL_ALIASES", default=Dict{String, String}()))

registry = Dict{String, ModelSpec}("gpt-3.5-turbo" => ModelSpec("gpt-3.5-turbo",
registry = Dict{String, ModelSpec}(
"gpt-3.5-turbo" => ModelSpec("gpt-3.5-turbo",
OpenAISchema(),
0.5e-6,
1.5e-6,
@@ -389,9 +404,10 @@ registry = Dict{String, ModelSpec}("gpt-3.5-turbo" => ModelSpec("gpt-3.5-turbo",
"Mistral AI's hosted model for embeddings."),
"echo" => ModelSpec("echo",
TestEchoOpenAISchema(;
response = Dict(
:choices => [
Dict(:message => Dict(:content => "Hello!"),
:finish_reason => "stop"),
:finish_reason => "stop")
],
:usage => Dict(:total_tokens => 3,
:prompt_tokens => 2,
@@ -408,7 +424,25 @@ registry = Dict{String, ModelSpec}("gpt-3.5-turbo" => ModelSpec("gpt-3.5-turbo",
GoogleSchema(),
0.0, #unknown, expected 1.25e-7
0.0, #unknown, expected 3.75e-7
"Gemini Pro is a LLM from Google. For more information, see [models](https://ai.google.dev/models/gemini)."))
"Gemini Pro is a LLM from Google. For more information, see [models](https://ai.google.dev/models/gemini)."),
"accounts/fireworks/models/mixtral-8x7b-instruct" => ModelSpec(
"accounts/fireworks/models/mixtral-8x7b-instruct",
FireworksOpenAISchema(),
4e-7, # cost per prompt token
1.6e-6, # cost per generated token
"Mixtral (8x7b) from Mistral, hosted by Fireworks.ai. For more information, see [models](https://fireworks.ai/models/fireworks/mixtral-8x7b-instruct)."),
"accounts/fireworks/models/firefunction-v1" => ModelSpec(
"accounts/fireworks/models/firefunction-v1",
FireworksOpenAISchema(),
0.0, #unknown, expected to be the same as Mixtral
0.0, #unknown, expected to be the same as Mixtral
"Fireworks' open-source function calling model (fine-tuned Mixtral). Useful for `aiextract` calls. For more information, see [models](https://fireworks.ai/models/fireworks/firefunction-v1)."),
"mistralai/Mixtral-8x7B-Instruct-v0.1" => ModelSpec(
"mistralai/Mixtral-8x7B-Instruct-v0.1",
TogetherOpenAISchema(),
6e-7,
6e-7,
"Mixtral (8x7b) from Mistral, hosted by Together.ai. For more information, see [models](https://docs.together.ai/docs/inference-models)."))

### Model Registry Structure
@kwdef mutable struct ModelRegistry
