diff --git a/CHANGELOG.md b/CHANGELOG.md
index cab889975..faba492f5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - `run_id` - ID of the AI API call
 - `sample_id` - ID of the sample in the batch if you requested multiple completions, otherwise `sample_id==nothing` (they will have the same `run_id`)
 - `finish_reason` - the reason why the AI stopped generating the sequence (eg, "stop", "length") to provide more visibility for the user
+- Support for Fireworks.ai and Together.ai providers for fast and easy access to open-source models. Requires environment variables `FIREWORKS_API_KEY` and `TOGETHER_API_KEY` to be set, respectively. See `?FireworksOpenAISchema` and `?TogetherOpenAISchema` for more information.
 
 ### Fixed
 
diff --git a/docs/src/examples/working_with_custom_apis.md b/docs/src/examples/working_with_custom_apis.md
index 2a083d778..d4f09fd1a 100644
--- a/docs/src/examples/working_with_custom_apis.md
+++ b/docs/src/examples/working_with_custom_apis.md
@@ -97,4 +97,80 @@ ai"Say hi to the llama!"dllama
 
 You can use `aiembed` as well.
 
-Find more information [here](https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html).
\ No newline at end of file
+Find more information [here](https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html).
+
+## Using Together.ai
+
+You can also use the Together.ai API with PromptingTools.jl.
+It requires you to set the environment variable `TOGETHER_API_KEY`.
+
+The corresponding schema is `TogetherOpenAISchema`, and we have already registered one model for you, so you can use it as usual.
+The alias "tmixtral" (T for Together.ai and mixtral for the model name) is already set for you.
+
+```julia
+msg = aigenerate("Say hi"; model="tmixtral")
+## [ Info: Tokens: 87 @ Cost: \$0.0001 in 5.1 seconds
+## AIMessage("Hello! I'm here to help you. Is there something specific you'd like to know or discuss? I can provide information on a wide range of topics, assist with tasks, and even engage in a friendly conversation. Let me know how I can best assist you today.")
+```
+
+To embed text, use `aiembed`:
+
+```julia
+aiembed(PT.TogetherOpenAISchema(), "embed me"; model="BAAI/bge-large-en-v1.5")
+```
+Note: You can register the model with `PT.register_model!` and then use it as usual, as shown below.
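+
+For instance, a minimal sketch of registering the embedding model above yourself (this uses only the `register_model!` keywords shown in the FAQ; costs and description are omitted here):
+
+```julia
+PT.register_model!(;
+    name = "BAAI/bge-large-en-v1.5",
+    schema = PT.TogetherOpenAISchema())
+```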
-> Food("bread", ["crunchy"]) +""" +struct Food + name::String + adjectives::Union{Nothing,Vector{String}} +end +prompt = "I just ate a delicious and juicy apple." +msg = aiextract(prompt; return_type=Food, model="firefunction") +msg.content +# Output: Food("apple", ["delicious", "juicy"]) +``` + +For embedding a text, use `aiembed`: + +```julia +aiembed(PT.FireworksOpenAISchema(), "embed me"; model="nomic-ai/nomic-embed-text-v1.5") +``` +Note: You can register the model with `PT.register_model!` and use it as usual. diff --git a/docs/src/frequently_asked_questions.md b/docs/src/frequently_asked_questions.md index 228d6a4c2..5f34172c3 100644 --- a/docs/src/frequently_asked_questions.md +++ b/docs/src/frequently_asked_questions.md @@ -166,7 +166,7 @@ There are three ways how you can customize your workflows (especially when you u 2) Register your model and its associated schema (`PT.register_model!(; name="123", schema=PT.OllamaSchema())`). You won't have to specify the schema anymore only the model name. See [Working with Ollama](#working-with-ollama) for more information. 3) Override your default model (`PT.MODEL_CHAT`) and schema (`PT.PROMPT_SCHEMA`). It can be done persistently with Preferences, eg, `PT.set_preferences!("PROMPT_SCHEMA" => "OllamaSchema", "MODEL_CHAT"=>"llama2")`. -## How to have a Multi-turn Conversations? +## How to have Multi-turn Conversations? Let's say you would like to respond back to a model's response. How to do it? diff --git a/src/llm_interface.jl b/src/llm_interface.jl index 639c188b4..5fd913389 100644 --- a/src/llm_interface.jl +++ b/src/llm_interface.jl @@ -146,6 +146,36 @@ Requires two environment variables to be set: """ struct DatabricksOpenAISchema <: AbstractOpenAISchema end +""" + FireworksOpenAISchema + +Schema to call the [Fireworks.ai](https://fireworks.ai/) API. + +Links: +- [Get your API key](https://fireworks.ai/api-keys) +- [API Reference](https://readme.fireworks.ai/reference/createchatcompletion) +- [Available models](https://fireworks.ai/models) + +Requires one environment variables to be set: +- `FIREWORKS_API_KEY`: Your API key +""" +struct FireworksOpenAISchema <: AbstractOpenAISchema end + +""" + TogetherOpenAISchema + +Schema to call the [Together.ai](https://www.together.ai/) API. + +Links: +- [Get your API key](https://api.together.xyz/settings/api-keys) +- [API Reference](https://docs.together.ai/docs/openai-api-compatibility) +- [Available models](https://docs.together.ai/docs/inference-models) + +Requires one environment variables to be set: +- `TOGETHER_API_KEY`: Your API key +""" +struct TogetherOpenAISchema <: AbstractOpenAISchema end + abstract type AbstractOllamaSchema <: AbstractPromptSchema end """ diff --git a/src/llm_openai.jl b/src/llm_openai.jl index 3b35bf48a..764babe0f 100644 --- a/src/llm_openai.jl +++ b/src/llm_openai.jl @@ -70,7 +70,8 @@ function OpenAI.build_url(provider::AbstractCustomProvider, api::AbstractString) string(provider.base_url, "/", api) end function OpenAI.auth_header(provider::AbstractCustomProvider, api_key::AbstractString) - OpenAI.auth_header(OpenAI.OpenAIProvider(provider.api_key, + OpenAI.auth_header( + OpenAI.OpenAIProvider(provider.api_key, provider.base_url, provider.api_version), api_key) @@ -165,6 +166,32 @@ function OpenAI.create_chat(schema::MistralOpenAISchema, base_url = url) OpenAI.create_chat(provider, model, conversation; kwargs...) 
diff --git a/docs/src/frequently_asked_questions.md b/docs/src/frequently_asked_questions.md
index 228d6a4c2..5f34172c3 100644
--- a/docs/src/frequently_asked_questions.md
+++ b/docs/src/frequently_asked_questions.md
@@ -166,7 +166,7 @@ There are three ways how you can customize your workflows (especially when you u
 2) Register your model and its associated schema (`PT.register_model!(; name="123", schema=PT.OllamaSchema())`). You won't have to specify the schema anymore only the model name. See [Working with Ollama](#working-with-ollama) for more information.
 3) Override your default model (`PT.MODEL_CHAT`) and schema (`PT.PROMPT_SCHEMA`). It can be done persistently with Preferences, eg, `PT.set_preferences!("PROMPT_SCHEMA" => "OllamaSchema", "MODEL_CHAT"=>"llama2")`.
 
-## How to have a Multi-turn Conversations?
+## How to have Multi-turn Conversations?
 
 Let's say you would like to respond back to a model's response. How to do it?
 
diff --git a/src/llm_interface.jl b/src/llm_interface.jl
index 639c188b4..5fd913389 100644
--- a/src/llm_interface.jl
+++ b/src/llm_interface.jl
@@ -146,6 +146,36 @@ Requires two environment variables to be set:
 """
 struct DatabricksOpenAISchema <: AbstractOpenAISchema end
 
+"""
+    FireworksOpenAISchema
+
+Schema to call the [Fireworks.ai](https://fireworks.ai/) API.
+
+Links:
+- [Get your API key](https://fireworks.ai/api-keys)
+- [API Reference](https://readme.fireworks.ai/reference/createchatcompletion)
+- [Available models](https://fireworks.ai/models)
+
+Requires one environment variable to be set:
+- `FIREWORKS_API_KEY`: Your API key
+"""
+struct FireworksOpenAISchema <: AbstractOpenAISchema end
+
+"""
+    TogetherOpenAISchema
+
+Schema to call the [Together.ai](https://www.together.ai/) API.
+
+Links:
+- [Get your API key](https://api.together.xyz/settings/api-keys)
+- [API Reference](https://docs.together.ai/docs/openai-api-compatibility)
+- [Available models](https://docs.together.ai/docs/inference-models)
+
+Requires one environment variable to be set:
+- `TOGETHER_API_KEY`: Your API key
+"""
+struct TogetherOpenAISchema <: AbstractOpenAISchema end
+
 abstract type AbstractOllamaSchema <: AbstractPromptSchema end
 
 """
diff --git a/src/llm_openai.jl b/src/llm_openai.jl
index 3b35bf48a..764babe0f 100644
--- a/src/llm_openai.jl
+++ b/src/llm_openai.jl
@@ -70,7 +70,8 @@ function OpenAI.build_url(provider::AbstractCustomProvider, api::AbstractString)
     string(provider.base_url, "/", api)
 end
 function OpenAI.auth_header(provider::AbstractCustomProvider, api_key::AbstractString)
-    OpenAI.auth_header(OpenAI.OpenAIProvider(provider.api_key,
+    OpenAI.auth_header(
+        OpenAI.OpenAIProvider(provider.api_key,
             provider.base_url,
             provider.api_version),
         api_key)
@@ -165,6 +166,32 @@ function OpenAI.create_chat(schema::MistralOpenAISchema,
         base_url = url)
     OpenAI.create_chat(provider, model, conversation; kwargs...)
 end
+function OpenAI.create_chat(schema::FireworksOpenAISchema,
+        api_key::AbstractString,
+        model::AbstractString,
+        conversation;
+        url::String = "https://api.fireworks.ai/inference/v1",
+        kwargs...)
+    # Build the corresponding provider object
+    # Override the provided api_key (which defaults to the OpenAI key) with FIREWORKS_API_KEY, if set
+    provider = CustomProvider(;
+        api_key = isempty(FIREWORKS_API_KEY) ? api_key : FIREWORKS_API_KEY,
+        base_url = url)
+    OpenAI.create_chat(provider, model, conversation; kwargs...)
+end
+function OpenAI.create_chat(schema::TogetherOpenAISchema,
+        api_key::AbstractString,
+        model::AbstractString,
+        conversation;
+        url::String = "https://api.together.xyz/v1",
+        kwargs...)
+    # Build the corresponding provider object
+    # Override the provided api_key (which defaults to the OpenAI key) with TOGETHER_API_KEY, if set
+    provider = CustomProvider(;
+        api_key = isempty(TOGETHER_API_KEY) ? api_key : TOGETHER_API_KEY,
+        base_url = url)
+    OpenAI.create_chat(provider, model, conversation; kwargs...)
+end
 function OpenAI.create_chat(schema::DatabricksOpenAISchema,
         api_key::AbstractString,
         model::AbstractString,
@@ -257,6 +284,28 @@ function OpenAI.create_embeddings(schema::DatabricksOpenAISchema,
         input = docs,
         kwargs...)
 end
+function OpenAI.create_embeddings(schema::TogetherOpenAISchema,
+        api_key::AbstractString,
+        docs,
+        model::AbstractString;
+        url::String = "https://api.together.xyz/v1",
+        kwargs...)
+    provider = CustomProvider(;
+        api_key = isempty(TOGETHER_API_KEY) ? api_key : TOGETHER_API_KEY,
+        base_url = url)
+    OpenAI.create_embeddings(provider, docs, model; kwargs...)
+end
+function OpenAI.create_embeddings(schema::FireworksOpenAISchema,
+        api_key::AbstractString,
+        docs,
+        model::AbstractString;
+        url::String = "https://api.fireworks.ai/inference/v1",
+        kwargs...)
+    provider = CustomProvider(;
+        api_key = isempty(FIREWORKS_API_KEY) ? api_key : FIREWORKS_API_KEY,
+        base_url = url)
+    OpenAI.create_embeddings(provider, docs, model; kwargs...)
+end
 
 ## Temporary fix -- it will be moved upstream
 function OpenAI.create_embeddings(provider::AbstractCustomProvider,
@@ -316,8 +365,8 @@ function response_to_message(schema::AbstractOpenAISchema,
         nothing
     end
     ## calculate cost
-    tokens_prompt = resp.response[:usage][:prompt_tokens]
-    tokens_completion = resp.response[:usage][:completion_tokens]
+    tokens_prompt = get(resp.response, :usage, Dict(:prompt_tokens => 0))[:prompt_tokens]
+    tokens_completion = get(resp.response, :usage, Dict(:completion_tokens => 0))[:completion_tokens]
     cost = call_cost(tokens_prompt, tokens_completion, model_id)
     ## build AIMessage object
     msg = MSG(;
@@ -434,7 +483,7 @@ function aigenerate(prompt_schema::AbstractOpenAISchema, prompt::ALLOWED_PROMPT_
     run_id = Int(rand(Int32)) # remember one run ID
     ## extract all message
     msgs = [response_to_message(prompt_schema, AIMessage, choice, r;
-        time, model_id, run_id, sample_id = i)
+                time, model_id, run_id, sample_id = i)
             for (i, choice) in enumerate(r.response[:choices])]
     ## Order by log probability if available
     ## bigger is better, keep it last
@@ -537,11 +586,12 @@ function aiembed(prompt_schema::AbstractOpenAISchema,
         model_id;
         http_kwargs,
         api_kwargs...)
+    tokens_prompt = get(r.response, :usage, Dict(:prompt_tokens => 0))[:prompt_tokens]
     msg = DataMessage(;
         content = mapreduce(x -> postprocess(x[:embedding]), hcat, r.response[:data]),
         status = Int(r.status),
-        cost = call_cost(r.response[:usage][:prompt_tokens], 0, model_id),
-        tokens = (r.response[:usage][:prompt_tokens], 0),
+        cost = call_cost(tokens_prompt, 0, model_id),
+        tokens = (tokens_prompt, 0),
         elapsed = time)
     ## Reporting
     verbose && @info _report_stats(msg, model_id)
@@ -773,7 +823,8 @@ aiclassify(:JudgeIsItTrue;
 function aiclassify(prompt_schema::AbstractOpenAISchema, prompt::ALLOWED_PROMPT_TYPE;
         choices::AbstractVector{T} = ["true", "false", "unknown"],
         api_kwargs::NamedTuple = NamedTuple(),
-        kwargs...) where {T <: Union{AbstractString, Tuple{<:AbstractString, <:AbstractString}}}
+        kwargs...) where {T <:
+                          Union{AbstractString, Tuple{<:AbstractString, <:AbstractString}}}
     ## Encode the choices and the corresponding prompt
     ## TODO: maybe check the model provided as well?
     choices_prompt, logit_bias, decode_ids = encode_choices(prompt_schema, choices)
@@ -808,8 +859,8 @@ function response_to_message(schema::AbstractOpenAISchema,
         nothing
     end
     ## calculate cost
-    tokens_prompt = resp.response[:usage][:prompt_tokens]
-    tokens_completion = resp.response[:usage][:completion_tokens]
+    tokens_prompt = get(resp.response, :usage, Dict(:prompt_tokens => 0))[:prompt_tokens]
+    tokens_completion = get(resp.response, :usage, Dict(:completion_tokens => 0))[:completion_tokens]
     cost = call_cost(tokens_prompt, tokens_completion, model_id)
     # "Safe" parsing of the response - it still fails if JSON is invalid
     content = try
@@ -987,7 +1038,7 @@ function aiextract(prompt_schema::AbstractOpenAISchema, prompt::ALLOWED_PROMPT_T
     run_id = Int(rand(Int32)) # remember one run ID
     ## extract all message
     msgs = [response_to_message(prompt_schema, DataMessage, choice, r;
-        return_type, time, model_id, run_id, sample_id = i)
+                return_type, time, model_id, run_id, sample_id = i)
            for (i, choice) in enumerate(r.response[:choices])]
     ## Order by log probability if available
     ## bigger is better, keep it last
@@ -1144,7 +1195,7 @@ function aiscan(prompt_schema::AbstractOpenAISchema, prompt::ALLOWED_PROMPT_TYPE
     run_id = Int(rand(Int32)) # remember one run ID
     ## extract all message
     msgs = [response_to_message(prompt_schema, AIMessage, choice, r;
-        time, model_id, run_id, sample_id = i)
+                time, model_id, run_id, sample_id = i)
            for (i, choice) in enumerate(r.response[:choices])]
     ## Order by log probability if available
     ## bigger is better, keep it last
diff --git a/src/user_preferences.jl b/src/user_preferences.jl
index c016b6392..2e104236d 100644
--- a/src/user_preferences.jl
+++ b/src/user_preferences.jl
@@ -149,6 +149,14 @@ _temp = get(ENV, "GOOGLE_API_KEY", "")
 const GOOGLE_API_KEY::String = @load_preference("GOOGLE_API_KEY",
     default=_temp);
 
+_temp = get(ENV, "TOGETHER_API_KEY", "")
+const TOGETHER_API_KEY::String = @load_preference("TOGETHER_API_KEY",
+    default=_temp);
+
+_temp = get(ENV, "FIREWORKS_API_KEY", "")
+const FIREWORKS_API_KEY::String = @load_preference("FIREWORKS_API_KEY",
+    default=_temp);
+
 _temp = get(ENV, "LOCAL_SERVER", "")
 ## Address of the local server
 const LOCAL_SERVER::String = @load_preference("LOCAL_SERVER",
@@ -267,7 +275,8 @@ end
 ### Model Aliases
 
 # global reference MODEL_ALIASES is defined below
-aliases = merge(Dict("gpt3" => "gpt-3.5-turbo",
+aliases = merge(
+    Dict("gpt3" => "gpt-3.5-turbo",
         "gpt4" => "gpt-4",
         "gpt4v" => "gpt-4-vision-preview", # 4v is for "4 vision"
         "gpt4t" => "gpt-4-turbo-preview", # 4t is for "4 turbo"
@@ -279,11 +288,17 @@ aliases = merge(Dict("gpt3" => "gpt-3.5-turbo",
         "oh25" => "openhermes2.5-mistral",
         "starling" => "starling-lm",
         "local" => "local-server",
-        "gemini" => "gemini-pro"),
+        "gemini" => "gemini-pro",
+        ## f-mixtral -> Fireworks.ai Mixtral
+        "fmixtral" => "accounts/fireworks/models/mixtral-8x7b-instruct",
+        "firefunction" => "accounts/fireworks/models/firefunction-v1",
+        ## t-mixtral -> Together.ai Mixtral
+        "tmixtral" => "mistralai/Mixtral-8x7B-Instruct-v0.1"),
    ## Load aliases from preferences as well
    @load_preference("MODEL_ALIASES", default=Dict{String, String}()))
 
-registry = Dict{String, ModelSpec}("gpt-3.5-turbo" => ModelSpec("gpt-3.5-turbo",
+registry = Dict{String, ModelSpec}(
+    "gpt-3.5-turbo" => ModelSpec("gpt-3.5-turbo",
     OpenAISchema(),
     0.5e-6,
     1.5e-6,
@@ -389,9 +404,10 @@ registry = Dict{String, ModelSpec}("gpt-3.5-turbo" => ModelSpec("gpt-3.5-turbo",
         "Mistral AI's hosted model for embeddings."),
     "echo" => ModelSpec("echo",
         TestEchoOpenAISchema(;
-            response = Dict(:choices => [
+            response = Dict(
+                :choices => [
                 Dict(:message => Dict(:content => "Hello!"),
-                    :finish_reason => "stop"),
+                    :finish_reason => "stop")
             ],
                 :usage => Dict(:total_tokens => 3,
                     :prompt_tokens => 2,
@@ -408,7 +424,25 @@ registry = Dict{String, ModelSpec}("gpt-3.5-turbo" => ModelSpec("gpt-3.5-turbo",
     GoogleSchema(),
     0.0, #unknown, expected 1.25e-7
     0.0, #unknown, expected 3.75e-7
-    "Gemini Pro is a LLM from Google. For more information, see [models](https://ai.google.dev/models/gemini)."))
+    "Gemini Pro is an LLM from Google. For more information, see [models](https://ai.google.dev/models/gemini)."),
+    "accounts/fireworks/models/mixtral-8x7b-instruct" => ModelSpec(
+        "accounts/fireworks/models/mixtral-8x7b-instruct",
+        FireworksOpenAISchema(),
+        4e-7, # cost per prompt token
+        1.6e-6, # cost per generated token
+        "Mixtral (8x7b) from Mistral, hosted by Fireworks.ai. For more information, see [models](https://fireworks.ai/models/fireworks/mixtral-8x7b-instruct)."),
+    "accounts/fireworks/models/firefunction-v1" => ModelSpec(
+        "accounts/fireworks/models/firefunction-v1",
+        FireworksOpenAISchema(),
+        0.0, #unknown, expected to be the same as Mixtral
+        0.0, #unknown, expected to be the same as Mixtral
+        "Fireworks' open-source function calling model (fine-tuned Mixtral). Useful for `aiextract` calls. For more information, see [models](https://fireworks.ai/models/fireworks/firefunction-v1)."),
+    "mistralai/Mixtral-8x7B-Instruct-v0.1" => ModelSpec(
+        "mistralai/Mixtral-8x7B-Instruct-v0.1",
+        TogetherOpenAISchema(),
+        6e-7, # cost per prompt token
+        6e-7, # cost per generated token
+        "Mixtral (8x7b) from Mistral, hosted by Together.ai. For more information, see [models](https://docs.together.ai/docs/inference-models)."))
 
 ### Model Registry Structure
 @kwdef mutable struct ModelRegistry