From 463a830a518c134b296f2b6aa031a948ed585951 Mon Sep 17 00:00:00 2001 From: J S <49557684+svilupp@users.noreply.github.com> Date: Thu, 22 Feb 2024 21:06:24 +0000 Subject: [PATCH] Multiple competions (`n`) (#79) --- CHANGELOG.md | 9 +- docs/src/frequently_asked_questions.md | 184 ++++++++++++++- src/Experimental/AgentTools/lazy_types.jl | 2 + src/Experimental/RAGTools/types.jl | 13 ++ src/llm_interface.jl | 13 ++ src/llm_ollama.jl | 18 +- src/llm_ollama_managed.jl | 8 +- src/llm_openai.jl | 223 +++++++++++++++--- src/llm_shared.jl | 13 +- src/messages.jl | 50 ++++- src/precompilation.jl | 5 +- src/user_preferences.jl | 5 +- src/utils.jl | 61 +++-- test/Experimental/RAGTools/evaluation.jl | 36 +-- test/Experimental/RAGTools/generation.jl | 19 +- test/Experimental/RAGTools/preparation.jl | 13 +- test/llm_interface.jl | 29 ++- test/llm_openai.jl | 261 +++++++++++++++++++++- test/llm_shared.jl | 28 +++ test/messages.jl | 4 +- test/utils.jl | 13 +- 21 files changed, 897 insertions(+), 110 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0467d9899..cab889975 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Added initial support for Google Gemini models for `aigenerate` (requires environment variable `GOOGLE_API_KEY` and package [GoogleGenAI.jl](https://github.com/tylerjthomas9/GoogleGenAI.jl) to be loaded). - Added a utility to compare any two string sequences (and other iterators)`length_longest_common_subsequence`. It can be used to fuzzy match strings (eg, detecting context/sources in an AI-generated response or fuzzy matching AI response to some preset categories). See the docstring for more information `?length_longest_common_subsequence`. -- Rewrite of `aiclassify` to classify into an arbitrary list of categories (including with descriptions). It's a quick and easy option for "routing" and similar use cases, as it exploits the logit bias trick and outputs only 1 token. Currently only `OpenAISchema` is supported. See `?aiclassify` for more information. +- Rewrite of `aiclassify` to classify into an arbitrary list of categories (including with descriptions). It's a quick and easy option for "routing" and similar use cases, as it exploits the logit bias trick and outputs only 1 token. Currently, only `OpenAISchema` is supported. See `?aiclassify` for more information. +- Initial support for multiple completions in one request for OpenAI-compatible API servers. Set via API kwarg `n=5` and it will request 5 completions in one request, saving the network communication time and paying the prompt tokens only once. It's useful for majority voting, diversity, or challenging agentic workflows. +- Added new fields to `AIMessage` and `DataMessage` types to simplify tracking in complex applications. 
Added fields: + - `cost` - the cost of the query (summary per call, so count only once if you requested multiple completions in one call) + - `log_prob` - summary log probability of the generated sequence, set API kwarg `logprobs=true` to receive it + - `run_id` - ID of the AI API call + - `sample_id` - ID of the sample in the batch if you requested multiple completions, otherwise `sample_id==nothing` (they will have the same `run_id`) + - `finish_reason` - the reason why the AI stopped generating the sequence (eg, "stop", "length") to provide more visibility for the user ### Fixed diff --git a/docs/src/frequently_asked_questions.md b/docs/src/frequently_asked_questions.md index 3ad8ec9d5..228d6a4c2 100644 --- a/docs/src/frequently_asked_questions.md +++ b/docs/src/frequently_asked_questions.md @@ -39,6 +39,21 @@ Resources: Pro tip: Always set the spending limits! +## Getting an error "ArgumentError: api_key cannot be empty" despite having set `OPENAI_API_KEY`? + +Quick fix: just provide kwarg `api_key` with your key to the `aigenerate` function (and other `ai*` functions). + +This error is thrown when the OpenAI API key is not available in 1) local preferences or 2) environment variables (`ENV["OPENAI_API_KEY"]`). + +First, check if you can access the key by running `ENV["OPENAI_API_KEY"]` in the Julia REPL. If it returns `nothing`, the key is not set. + +If the key is set, but you still get the error, there was a rare bug in earlier versions where if you first precompiled PromptingTools without the API key, it would remember it and "compile away" the `get(ENV,...)` function call. If you're experiencing this bug on the latest version of PromptingTools, please open an issue on GitHub. + +The solution is to force a new precompilation, so you can do any of the below: +1) Force precompilation (run `Pkg.precompile()` in the Julia REPL) +2) Update the PromptingTools package (runs precompilation automatically) +3) Delete your compiled cache in `.julia` DEPOT (usually `.julia/compiled/v1.10/PromptingTools`). You can do it manually in the file explorer or via Julia REPL: `rm("~/.julia/compiled/v1.10/PromptingTools", recursive=true, force=true)` + ## Setting OpenAI Spending Limits OpenAI allows you to set spending limits directly on your account dashboard to prevent unexpected costs. @@ -149,4 +164,171 @@ There are three ways how you can customize your workflows (especially when you u 1) Import the functions/types you need explicitly at the top (eg, `using PromptingTools: OllamaSchema`) 2) Register your model and its associated schema (`PT.register_model!(; name="123", schema=PT.OllamaSchema())`). You won't have to specify the schema anymore only the model name. See [Working with Ollama](#working-with-ollama) for more information. -3) Override your default model (`PT.MODEL_CHAT`) and schema (`PT.PROMPT_SCHEMA`). It can be done persistently with Preferences, eg, `PT.set_preferences!("PROMPT_SCHEMA" => "OllamaSchema", "MODEL_CHAT"=>"llama2")`. \ No newline at end of file +3) Override your default model (`PT.MODEL_CHAT`) and schema (`PT.PROMPT_SCHEMA`). It can be done persistently with Preferences, eg, `PT.set_preferences!("PROMPT_SCHEMA" => "OllamaSchema", "MODEL_CHAT"=>"llama2")`. + +## How to have a Multi-turn Conversations? + +Let's say you would like to respond back to a model's response. How to do it? + +1) With `ai""` macro +The simplest way if you used `ai""` macro, is to send a reply with the `ai!""` macro. It will use the last response as the conversation. +```julia +ai"Hi! 
I'm John" + +ai!"What's my name?" +# Return: "Your name is John." +``` + +2) With `aigenerate` function +You can use the `conversation` keyword argument to pass the previous conversation (in all `ai*` functions). It will prepend the past `conversation` before sending the new request to the model. + +To get the conversation, set `return_all=true` and store the whole conversation thread (not just the last message) in a variable. Then, use it as a keyword argument in the next call. + +```julia +conversation = aigenerate("Hi! I'm John"; return_all=true) +@info last(conversation) # display the response + +# follow-up (notice that we provide past messages as conversation kwarg +conversation = aigenerate("What's my name?"; return_all=true, conversation) + +## [ Info: Tokens: 50 @ Cost: $0.0 in 1.0 seconds +## 5-element Vector{PromptingTools.AbstractMessage}: +## PromptingTools.SystemMessage("Act as a helpful AI assistant") +## PromptingTools.UserMessage("Hi! I'm John") +## AIMessage("Hello John! How can I assist you today?") +## PromptingTools.UserMessage("What's my name?") +## AIMessage("Your name is John.") +``` +Notice that the last message is the response to the second request, but with `return_all=true` we can see the whole conversation from the beginning. + +## Explain What Happens Under the Hood + +4 Key Concepts/Objects: +- Schemas -> object of type `AbstractPromptSchema` that determines which methods are called and, hence, what providers/APIs are used +- Prompts -> the information you want to convey to the AI model +- Messages -> the basic unit of communication between the user and the AI model (eg, `UserMessage` vs `AIMessage`) +- Prompt Templates -> re-usable "prompts" with placeholders that you can replace with your inputs at the time of making the request + +When you call `aigenerate`, roughly the following happens: `render` -> `UserMessage`(s) -> `render` -> `OpenAI.create_chat` -> ... -> `AIMessage`. + +We'll deep dive into an example in the end. + +### Schemas + +For your "message" to reach an AI model, it needs to be formatted and sent to the right place. + +We leverage the multiple dispatch around the "schemas" to pick the right logic. +All schemas are subtypes of `AbstractPromptSchema` and there are many subtypes, eg, `OpenAISchema <: AbstractOpenAISchema <:AbstractPromptSchema`. + +For example, if you provide `schema = OpenAISchema()`, the system knows that: +- it will have to format any user inputs to OpenAI's "message specification" (a vector of dictionaries, see their API documentation). Function `render(OpenAISchema(),...)` will take care of the rendering. +- it will have to send the message to OpenAI's API. We will use the amazing `OpenAI.jl` package to handle the communication. + +### Prompts + +Prompt is loosely the information you want to convey to the AI model. It can be a question, a statement, or a command. It can have instructions or some context, eg, previous conversation. + +You need to remember that Large Language Models (LLMs) are **stateless**. They don't remember the previous conversation/request, so you need to provide the whole history/context every time (similar to how REST APIs work). + +Prompts that we send to the LLMs are effectively a sequence of messages (`<:AbstractMessage`). + +### Messages + +Messages are the basic unit of communication between the user and the AI model. + +There are 5 main types of messages (`<:AbstractMessage`): + +- `SystemMessage` - this contains information about the "system", eg, how it should behave, format its output, etc. 
(eg, `You're a world-class Julia programmer. You write brief and concise code.) +- `UserMessage` - the information "from the user", ie, your question/statement/task +- `UserMessageWithImages` - the same as `UserMessage`, but with images (URLs or Base64-encoded images) +- `AIMessage` - the response from the AI model, when the "output" is text +- `DataMessage` - the response from the AI model, when the "output" is data, eg, embeddings with `aiembed` or user-defined structs with `aiextract` + +### Prompt Templates + +We want to have re-usable "prompts", so we provide you with a system to retrieve pre-defined prompts with placeholders (eg, `{{name}}`) that you can replace with your inputs at the time of making the request. + +"AI Templates" as we call them (`AITemplate`) are usually a vector of `SystemMessage` and a `UserMessage` with specific purpose/task. + +For example, the template `:AssistantAsk` is defined loosely as: + +```julia + template = [SystemMessage("You are a world-class AI assistant. Your communication is brief and concise. You're precise and answer only when you're confident in the high quality of your answer."), + UserMessage("# Question\n\n{{ask}}")] +``` + +Notice that we have a placeholder `ask` (`{{ask}}`) that you can replace with your question without having to re-write the generic system instructions. + +When you provide a Symbol (eg, `:AssistantAsk`) to ai* functions, thanks to the multiple dispatch, it recognizes that it's an `AITemplate(:AssistantAsk)` and looks it up. + +You can discover all available templates with `aitemplates("some keyword")` or just see the details of some template `aitemplates(:AssistantAsk)`. + +### Walkthrough Example + +```julia +using PromptingTools +const PT = PromptingTools + +# Let's say this is our ask +msg = aigenerate(:AssistantAsk; ask="What is the capital of France?") + +# it is effectively the same as: +msg = aigenerate(PT.OpenAISchema(), PT.AITemplate(:AssistantAsk); ask="What is the capital of France?", model="gpt3t") +``` + +There is no `model` provided, so we use the default `PT.MODEL_CHAT` (effectively GPT3.5-Turbo). Then we look it up in `PT.MDOEL_REGISTRY` and use the associated schema for it (`OpenAISchema` in this case). + +The next step is to render the template, replace the placeholders and render it for the OpenAI model. + +```julia +# Let's remember out schema +schema = PT.OpenAISchema() +ask = "What is the capital of France?" +``` + +First, we obtain the template (no placeholder replacement yet) and "expand it" +```julia +template_rendered = PT.render(schema, AITemplate(:AssistantAsk); ask) +``` + +```plaintext +2-element Vector{PromptingTools.AbstractChatMessage}: + PromptingTools.SystemMessage("You are a world-class AI assistant. Your communication is brief and concise. You're precise and answer only when you're confident in the high quality of your answer.") + PromptingTools.UserMessage{String}("# Question\n\n{{ask}}", [:ask], :usermessage) +``` + +Second, we replace the placeholders +```julia +rendered_for_api = PT.render(schema, template_rendered; ask) +``` + +```plaintext +2-element Vector{Dict{String, Any}}: + Dict("role" => "system", "content" => "You are a world-class AI assistant. Your communication is brief and concise. You're precise and answer only when you're confident in the high quality of your answer.") + Dict("role" => "user", "content" => "# Question\n\nWhat is the capital of France?") +``` + +Notice that the placeholders are only replaced in the second step. 
The final output here is a vector of messages with "role" and "content" keys, which is the format required by the OpenAI API. + +As a side note, under the hood, the second step is done in two steps: + +- replace the placeholders `messages_rendered = PT.render(PT.NoSchema(), template_rendered; ask)` -> returns a vector of Messages! +- then, we convert the messages to the format required by the provider/schema `PT.render(schema, messages_rendered)` -> returns the OpenAI formatted messages + + +Next, we send the above `rendered_for_api` to the OpenAI API and get the response back. + +```julia +using OpenAI +OpenAI.create_chat(api_key, model, rendered_for_api) +``` + +The last step is to take the JSON response from the API and convert it to the `AIMessage` object. + +```julia +# simplification for educational purposes +msg = AIMessage(; content = r.response[:choices][1][:message][:content]) +``` +In practice, there are more fields we extract, so we define a utility for it: `PT.response_to_message`. Especially, since with parameter `n`, you can request multiple AI responses at once, so we want to re-use our response processing logic. + +That's it! I hope you've learned something new about how PromptingTools.jl works under the hood. \ No newline at end of file diff --git a/src/Experimental/AgentTools/lazy_types.jl b/src/Experimental/AgentTools/lazy_types.jl index 2984ae1ac..6201080e4 100644 --- a/src/Experimental/AgentTools/lazy_types.jl +++ b/src/Experimental/AgentTools/lazy_types.jl @@ -62,6 +62,8 @@ This can be used to "reply" to previous message / continue the stored conversati success::Union{Nothing, Bool} = nothing error::Union{Nothing, Exception} = nothing end +## main sample +## samples function AICall(func::F, args...; kwargs...) where {F <: Function} @assert length(args)<=2 "AICall takes at most 2 positional arguments (provided: $(length(args)))" diff --git a/src/Experimental/RAGTools/types.jl b/src/Experimental/RAGTools/types.jl index dab4e78ba..863a25abb 100644 --- a/src/Experimental/RAGTools/types.jl +++ b/src/Experimental/RAGTools/types.jl @@ -8,6 +8,19 @@ abstract type AbstractChunkIndex <: AbstractDocumentIndex end # More advanced index would be: HybridChunkIndex # Stores document chunks and their embeddings +""" + ChunkIndex + +Main struct for storing document chunks and their embeddings. It also stores tags and sources for each chunk. + +# Fields +- `id::Symbol`: unique identifier of each index (to ensure we're using the right index with `CandidateChunks`) +- `chunks::Vector{<:AbstractString}`: underlying document chunks / snippets +- `embeddings::Union{Nothing, Matrix{<:Real}}`: for semantic search +- `tags::Union{Nothing, AbstractMatrix{<:Bool}}`: for exact search, filtering, etc. This is often a sparse matrix indicating which chunks have the given `tag` (see `tag_vocab` for the position lookup) +- `tags_vocab::Union{Nothing, Vector{<:AbstractString}}`: vocabulary for the `tags` matrix (each column in `tags` is one item in `tags_vocab` and rows are the chunks) +- `sources::Vector{<:AbstractString}`: sources of the chunks +""" @kwdef struct ChunkIndex{ T1 <: AbstractString, T2 <: Union{Nothing, Matrix{<:Real}}, diff --git a/src/llm_interface.jl b/src/llm_interface.jl index 04f24a56c..639c188b4 100644 --- a/src/llm_interface.jl +++ b/src/llm_interface.jl @@ -250,3 +250,16 @@ function aiscan(prompt; model = MODEL_CHAT, kwargs...) schema = get(MODEL_REGISTRY, model, (; schema = PROMPT_SCHEMA)).schema aiscan(schema, prompt; model, kwargs...) 
end + +"Utility to facilitate unwrapping of HTTP response to a message type `MSG` provided. Designed to handle multi-sample completions." +function response_to_message(schema::AbstractPromptSchema, + MSG::Type{T}, + choice, + resp; + return_type = nothing, + model_id::AbstractString = "", + time::Float64 = 0.0, + run_id::Integer = rand(Int16), + sample_id::Union{Nothing, Integer} = nothing) where {T} + throw(ArgumentError("Response unwrapping not implemented for $(typeof(schema)) and $MSG")) +end diff --git a/src/llm_ollama.jl b/src/llm_ollama.jl index f03c0c2ff..ae9b95a86 100644 --- a/src/llm_ollama.jl +++ b/src/llm_ollama.jl @@ -2,6 +2,8 @@ # - llm_olama.jl works by providing messages format to /api/chat # - llm_managed_olama.jl works by providing 1 system prompt and 1 user prompt /api/generate # +# TODO: switch to OpenAI-compatible endpoint! +# ## Schema dedicated to [Ollama's models](https://ollama.ai/), which also managed the prompt templates # ## Rendering of converation history for the Ollama API (similar to OpenAI but not for the images) @@ -157,10 +159,14 @@ function aigenerate(prompt_schema::AbstractOllamaSchema, prompt::ALLOWED_PROMPT_ http_kwargs, api_kwargs...) + tokens_prompt = get(resp.response, :prompt_eval_count, 0) + tokens_completion = get(resp.response, :eval_count, 0) msg = AIMessage(; content = resp.response[:message][:content] |> strip, status = Int(resp.status), - tokens = (get(resp.response, :prompt_eval_count, 0), - get(resp.response, :eval_count, 0)), + cost = call_cost(tokens_prompt, tokens_completion, model_id), + ## not coming through yet anyway + ## finish_reason = get(resp.response, :finish_reason, nothing), + tokens = (tokens_prompt, tokens_completion), elapsed = time) ## Reporting verbose && @info _report_stats(msg, model_id) @@ -184,7 +190,7 @@ function aiembed(prompt_schema::AbstractOllamaSchema, args...; kwargs...) end """ -aiscan([prompt_schema::AbstractOllamaSchema,] prompt::ALLOWED_PROMPT_TYPE; + aiscan([prompt_schema::AbstractOllamaSchema,] prompt::ALLOWED_PROMPT_TYPE; image_url::Union{Nothing, AbstractString, Vector{<:AbstractString}} = nothing, image_path::Union{Nothing, AbstractString, Vector{<:AbstractString}} = nothing, attach_to_latest::Bool = true, @@ -314,10 +320,12 @@ function aiscan(prompt_schema::AbstractOllamaSchema, prompt::ALLOWED_PROMPT_TYPE system = nothing, messages = conv_rendered, endpoint = "chat", model = model_id, http_kwargs, api_kwargs...) + tokens_prompt = get(resp.response, :prompt_eval_count, 0) + tokens_completion = get(resp.response, :eval_count, 0) msg = AIMessage(; content = resp.response[:message][:content] |> strip, status = Int(resp.status), - tokens = (get(resp.response, :prompt_eval_count, 0), - get(resp.response, :eval_count, 0)), + cost = call_cost(tokens_prompt, tokens_completion, model_id), + tokens = (tokens_prompt, tokens_completion), elapsed = time) ## Reporting verbose && @info _report_stats(msg, model_id) diff --git a/src/llm_ollama_managed.jl b/src/llm_ollama_managed.jl index 23286c3f6..647a53cec 100644 --- a/src/llm_ollama_managed.jl +++ b/src/llm_ollama_managed.jl @@ -214,10 +214,12 @@ function aigenerate(prompt_schema::AbstractOllamaManagedSchema, prompt::ALLOWED_ time = @elapsed resp = ollama_api(prompt_schema, conv_rendered.prompt; conv_rendered.system, endpoint = "generate", model = model_id, http_kwargs, api_kwargs...) 
+ tokens_prompt = get(resp.response, :prompt_eval_count, 0) + tokens_completion = get(resp.response, :eval_count, 0) msg = AIMessage(; content = resp.response[:response] |> strip, status = Int(resp.status), - tokens = (get(resp.response, :prompt_eval_count, 0), - get(resp.response, :eval_count, 0)), + cost = call_cost(tokens_prompt, tokens_completion, model_id), + tokens = (tokens_prompt, tokens_completion), elapsed = time) ## Reporting verbose && @info _report_stats(msg, model_id) @@ -326,6 +328,7 @@ function aiembed(prompt_schema::AbstractOllamaManagedSchema, msg = DataMessage(; content = postprocess(resp.response[:embedding]), status = Int(resp.status), + cost = call_cost(0, 0, model_id), tokens = (0, 0), # token counts are not provided for embeddings elapsed = time) ## Reporting @@ -356,6 +359,7 @@ function aiembed(prompt_schema::AbstractOllamaManagedSchema, msg = DataMessage(; content = mapreduce(x -> x.content, hcat, messages), status = mapreduce(x -> x.status, max, messages), + cost = mapreduce(x -> x.cost, +, messages), tokens = (0, 0),# not tracked for embeddings in Ollama elapsed = sum(x -> x.elapsed, messages)) ## Reporting diff --git a/src/llm_openai.jl b/src/llm_openai.jl index 5fa641227..3b35bf48a 100644 --- a/src/llm_openai.jl +++ b/src/llm_openai.jl @@ -273,6 +273,66 @@ function OpenAI.create_embeddings(provider::AbstractCustomProvider, kwargs...) end +""" + response_to_message(schema::AbstractOpenAISchema, + MSG::Type{AIMessage}, + choice, + resp; + model_id::AbstractString = "", + time::Float64 = 0.0, + run_id::Integer = rand(Int16), + sample_id::Union{Nothing, Integer} = nothing) + +Utility to facilitate unwrapping of HTTP response to a message type `MSG` provided for OpenAI-like responses + +Note: Extracts `finish_reason` and `log_prob` if available in the response. + +# Arguments +- `schema::AbstractOpenAISchema`: The schema for the prompt. +- `MSG::Type{AIMessage}`: The message type to be returned. +- `choice`: The choice from the response (eg, one of the completions). +- `resp`: The response from the OpenAI API. +- `model_id::AbstractString`: The model ID to use for generating the response. Defaults to an empty string. +- `time::Float64`: The elapsed time for the response. Defaults to `0.0`. +- `run_id::Integer`: The run ID for the response. Defaults to a random integer. +- `sample_id::Union{Nothing, Integer}`: The sample ID for the response (if there are multiple completions). Defaults to `nothing`. 
+""" +function response_to_message(schema::AbstractOpenAISchema, + MSG::Type{AIMessage}, + choice, + resp; + model_id::AbstractString = "", + time::Float64 = 0.0, + run_id::Int = Int(rand(Int32)), + sample_id::Union{Nothing, Integer} = nothing) + ## extract sum log probability + has_log_prob = haskey(choice, :logprobs) && + !isnothing(get(choice, :logprobs, nothing)) && + haskey(choice[:logprobs], :content) && + !isnothing(choice[:logprobs][:content]) + log_prob = if has_log_prob + sum([get(c, :logprob, 0.0) for c in choice[:logprobs][:content]]) + else + nothing + end + ## calculate cost + tokens_prompt = resp.response[:usage][:prompt_tokens] + tokens_completion = resp.response[:usage][:completion_tokens] + cost = call_cost(tokens_prompt, tokens_completion, model_id) + ## build AIMessage object + msg = MSG(; + content = choice[:message][:content] |> strip, + status = Int(resp.status), + cost, + run_id, + sample_id, + log_prob, + finish_reason = get(choice, :finish_reason, nothing), + tokens = (tokens_prompt, + tokens_completion), + elapsed = time) +end + ## User-Facing API """ aigenerate(prompt_schema::AbstractOpenAISchema, prompt::ALLOWED_PROMPT_TYPE; @@ -296,7 +356,11 @@ Generate an AI response based on a given prompt using the OpenAI API. - `dry_run::Bool=false`: If `true`, skips sending the messages to the model (for debugging, often used with `return_all=true`). - `conversation`: An optional vector of `AbstractMessage` objects representing the conversation history. If not provided, it is initialized as an empty vector. - `http_kwargs`: A named tuple of HTTP keyword arguments. -- `api_kwargs`: A named tuple of API keyword arguments. +- `api_kwargs`: A named tuple of API keyword arguments. Useful parameters include: + - `temperature`: A float representing the temperature for sampling (ie, the amount of "creativity"). Often defaults to `0.7`. + - `logprobs`: A boolean indicating whether to return log probabilities for each token. Defaults to `false`. + - `n`: An integer representing the number of completions to generate at once (if supported). + - `stop`: A vector of strings representing the stop conditions for the conversation. Defaults to an empty vector. - `kwargs`: Prompt variables to be used to fill the prompt/template # Returns @@ -365,12 +429,26 @@ function aigenerate(prompt_schema::AbstractOpenAISchema, prompt::ALLOWED_PROMPT_ conv_rendered; http_kwargs, api_kwargs...) 
- msg = AIMessage(; - content = r.response[:choices][begin][:message][:content] |> strip, - status = Int(r.status), - tokens = (r.response[:usage][:prompt_tokens], - r.response[:usage][:completion_tokens]), - elapsed = time) + ## Process one of more samples returned + msg = if length(r.response[:choices]) > 1 + run_id = Int(rand(Int32)) # remember one run ID + ## extract all message + msgs = [response_to_message(prompt_schema, AIMessage, choice, r; + time, model_id, run_id, sample_id = i) + for (i, choice) in enumerate(r.response[:choices])] + ## Order by log probability if available + ## bigger is better, keep it last + if all(x -> !isnothing(x.log_prob), msgs) + sort(msgs, by = x -> x.log_prob) + else + msgs + end + else + ## only 1 sample / 1 completion + choice = r.response[:choices][begin] + response_to_message(prompt_schema, AIMessage, choice, r; + time, model_id) + end ## Reporting verbose && @info _report_stats(msg, model_id) else @@ -454,7 +532,6 @@ function aiembed(prompt_schema::AbstractOpenAISchema, global MODEL_ALIASES ## Find the unique ID for the model alias provided model_id = get(MODEL_ALIASES, model, model) - time = @elapsed r = create_embeddings(prompt_schema, api_key, doc_or_docs, model_id; @@ -463,6 +540,7 @@ function aiembed(prompt_schema::AbstractOpenAISchema, msg = DataMessage(; content = mapreduce(x -> postprocess(x[:embedding]), hcat, r.response[:data]), status = Int(r.status), + cost = call_cost(r.response[:usage][:prompt_tokens], 0, model_id), tokens = (r.response[:usage][:prompt_tokens], 0), elapsed = time) ## Reporting @@ -585,8 +663,15 @@ function decode_choices(schema::TestEchoOpenAISchema, end function decode_choices(schema::OpenAISchema, choices, conv::AbstractVector; kwargs...) - if length(conv) > 0 && last(conv) isa AIMessage - conv[end] = decode_choices(schema, choices, last(conv)) + if length(conv) > 0 && last(conv) isa AIMessage && hasproperty(last(conv), :run_id) + ## if it is a multi-sample response, + ## Remember its run ID and convert all samples in that run + run_id = last(conv).run_id + for i in eachindex(conv) + if conv[i].run_id == run_id + conv[i] = decode_choices(schema, choices, conv[i]) + end + end end return conv end @@ -615,7 +700,8 @@ function decode_choices(schema::OpenAISchema, ## failed decoding content = nothing end - return AIMessage(; content, msg.status, msg.tokens, msg.elapsed) + ## create a new object with all the same fields except for content + return AIMessage(; [f => getfield(msg, f) for f in fieldnames(typeof(msg))]..., content) end """ @@ -701,6 +787,53 @@ function aiclassify(prompt_schema::AbstractOpenAISchema, prompt::ALLOWED_PROMPT_ return decode_choices(prompt_schema, decode_ids, msg_or_conv) end +function response_to_message(schema::AbstractOpenAISchema, + MSG::Type{DataMessage}, + choice, + resp; + return_type = nothing, + model_id::AbstractString = "", + time::Float64 = 0.0, + run_id::Int = Int(rand(Int32)), + sample_id::Union{Nothing, Integer} = nothing) + @assert !isnothing(return_type) "You must provide a return_type for DataMessage construction" + ## extract sum log probability + has_log_prob = haskey(choice, :logprobs) && + !isnothing(get(choice, :logprobs, nothing)) && + haskey(choice[:logprobs], :content) && + !isnothing(choice[:logprobs][:content]) + log_prob = if has_log_prob + sum([get(c, :logprob, 0.0) for c in choice[:logprobs][:content]]) + else + nothing + end + ## calculate cost + tokens_prompt = resp.response[:usage][:prompt_tokens] + tokens_completion = resp.response[:usage][:completion_tokens] + 
cost = call_cost(tokens_prompt, tokens_completion, model_id) + # "Safe" parsing of the response - it still fails if JSON is invalid + content = try + choice[:message][:tool_calls][1][:function][:arguments] |> + x -> JSON3.read(x, return_type) + catch e + @warn "There was an error parsing the response: $e. Using the raw response instead." + choice[:message][:tool_calls][1][:function][:arguments] |> + JSON3.read |> copy + end + ## build DataMessage object + msg = MSG(; + content = content, + status = Int(resp.status), + cost, + run_id, + sample_id, + log_prob, + finish_reason = get(choice, :finish_reason, nothing), + tokens = (tokens_prompt, + tokens_completion), + elapsed = time) +end + """ aiextract([prompt_schema::AbstractOpenAISchema,] prompt::ALLOWED_PROMPT_TYPE; return_type::Type, @@ -833,10 +966,12 @@ function aiextract(prompt_schema::AbstractOpenAISchema, prompt::ALLOWED_PROMPT_T ## global MODEL_ALIASES ## Function calling specifics - functions = [function_call_signature(return_type)] - function_call = Dict(:name => only(functions)["name"]) + tools = [Dict(:type => "function", :function => function_call_signature(return_type))] + ## force our function to be used + tool_choice = Dict(:type => "function", + :function => Dict(:name => only(tools)[:function]["name"])) ## Add the function call signature to the api_kwargs - api_kwargs = merge(api_kwargs, (; functions, function_call)) + api_kwargs = merge(api_kwargs, (; tools, tool_choice)) ## Find the unique ID for the model alias provided model_id = get(MODEL_ALIASES, model, model) conv_rendered = render(prompt_schema, prompt; conversation, kwargs...) @@ -847,20 +982,26 @@ function aiextract(prompt_schema::AbstractOpenAISchema, prompt::ALLOWED_PROMPT_T conv_rendered; http_kwargs, api_kwargs...) - # "Safe" parsing of the response - it still fails if JSON is invalid - content = try - r.response[:choices][begin][:message][:function_call][:arguments] |> - x -> JSON3.read(x, return_type) - catch e - @warn "There was an error parsing the response: $e. Using the raw response instead." 
- r.response[:choices][begin][:message][:function_call][:arguments] |> - JSON3.read |> copy + ## Process one of more samples returned + msg = if length(r.response[:choices]) > 1 + run_id = Int(rand(Int32)) # remember one run ID + ## extract all message + msgs = [response_to_message(prompt_schema, DataMessage, choice, r; + return_type, time, model_id, run_id, sample_id = i) + for (i, choice) in enumerate(r.response[:choices])] + ## Order by log probability if available + ## bigger is better, keep it last + if all(x -> !isnothing(x.log_prob), msgs) + sort(msgs, by = x -> x.log_prob) + else + msgs + end + else + ## only 1 sample / 1 completion + choice = r.response[:choices][begin] + response_to_message(prompt_schema, DataMessage, choice, r; + return_type, time, model_id) end - msg = DataMessage(; content, - status = Int(r.status), - tokens = (r.response[:usage][:prompt_tokens], - r.response[:usage][:completion_tokens]), - elapsed = time) ## Reporting verbose && @info _report_stats(msg, model_id) else @@ -879,7 +1020,7 @@ function aiextract(prompt_schema::AbstractOpenAISchema, prompt::ALLOWED_PROMPT_T end """ -aiscan([prompt_schema::AbstractOpenAISchema,] prompt::ALLOWED_PROMPT_TYPE; + aiscan([prompt_schema::AbstractOpenAISchema,] prompt::ALLOWED_PROMPT_TYPE; image_url::Union{Nothing, AbstractString, Vector{<:AbstractString}} = nothing, image_path::Union{Nothing, AbstractString, Vector{<:AbstractString}} = nothing, image_detail::AbstractString = "auto", @@ -998,12 +1139,26 @@ function aiscan(prompt_schema::AbstractOpenAISchema, prompt::ALLOWED_PROMPT_TYPE conv_rendered; http_kwargs, api_kwargs...) - msg = AIMessage(; - content = r.response[:choices][begin][:message][:content] |> strip, - status = Int(r.status), - tokens = (r.response[:usage][:prompt_tokens], - r.response[:usage][:completion_tokens]), - elapsed = time) + ## Process one of more samples returned + msg = if length(r.response[:choices]) > 1 + run_id = Int(rand(Int32)) # remember one run ID + ## extract all message + msgs = [response_to_message(prompt_schema, AIMessage, choice, r; + time, model_id, run_id, sample_id = i) + for (i, choice) in enumerate(r.response[:choices])] + ## Order by log probability if available + ## bigger is better, keep it last + if all(x -> !isnothing(x.log_prob), msgs) + sort(msgs, by = x -> x.log_prob) + else + msgs + end + else + ## only 1 sample / 1 completion + choice = r.response[:choices][begin] + response_to_message(prompt_schema, AIMessage, choice, r; + time, model_id) + end ## Reporting verbose && @info _report_stats(msg, model_id) else diff --git a/src/llm_shared.jl b/src/llm_shared.jl index fd6401187..a611c5929 100644 --- a/src/llm_shared.jl +++ b/src/llm_shared.jl @@ -65,7 +65,7 @@ end """ finalize_outputs(prompt::ALLOWED_PROMPT_TYPE, conv_rendered::Any, - msg::Union{Nothing, AbstractMessage}; + msg::Union{Nothing, AbstractMessage, AbstractVector{<:AbstractMessage}}; return_all::Bool = false, dry_run::Bool = false, conversation::AbstractVector{<:AbstractMessage} = AbstractMessage[], @@ -81,7 +81,7 @@ Finalizes the outputs of the ai* functions by either returning the conversation - `kwargs...`: Variables to replace in the prompt template. 
""" function finalize_outputs(prompt::ALLOWED_PROMPT_TYPE, conv_rendered::Any, - msg::Union{Nothing, AbstractMessage}; + msg::Union{Nothing, AbstractMessage, AbstractVector{<:AbstractMessage}}; return_all::Bool = false, dry_run::Bool = false, conversation::AbstractVector{<:AbstractMessage} = AbstractMessage[], @@ -92,7 +92,12 @@ function finalize_outputs(prompt::ALLOWED_PROMPT_TYPE, conv_rendered::Any, # This is a duplication of work, as we already have the rendered messages in conv_rendered, # but we prioritize the user's experience over performance here (ie, render(OpenAISchema,msgs) does everything under the hood) output = render(NoSchema(), prompt; conversation, kwargs...) - push!(output, msg) + if msg isa AbstractVector + ## handle multiple messages (multi-sample) + append!(output, msg) + else + push!(output, msg) + end else output = conv_rendered end @@ -113,4 +118,4 @@ function decode_choices(schema::AbstractPromptSchema, choices, conv; kwargs...) end function decode_choices(schema::AbstractPromptSchema, choices, conv::Nothing; kwargs...) nothing -end +end \ No newline at end of file diff --git a/src/messages.jl b/src/messages.jl index 63d153b19..cda980462 100644 --- a/src/messages.jl +++ b/src/messages.jl @@ -61,18 +61,64 @@ function UserMessageWithImages(content::T, image_url::Vector{<:AbstractString}, @assert length(not_allowed_kwargs)==0 "Error: Some placeholders are invalid, as they are reserved for `ai*` functions. Change: $(join(not_allowed_kwargs,","))" return UserMessageWithImages{T}(content, string.(image_url), variables, type) end + +""" + AIMessage + +A message type for AI-generated text-based responses. +Returned by `aigenerate`, `aiclassify`, and `aiscan` functions. + +# Fields +- `content::Union{AbstractString, Nothing}`: The content of the message. +- `status::Union{Int, Nothing}`: The status of the message from the API. +- `tokens::Tuple{Int, Int}`: The number of tokens used (prompt,completion). +- `elapsed::Float64`: The time taken to generate the response in seconds. +- `cost::Union{Nothing, Float64}`: The cost of the API call (calculated with information from `MODEL_REGISTRY`). +- `log_prob::Union{Nothing, Float64}`: The log probability of the response. +- `finish_reason::Union{Nothing, String}`: The reason the response was finished. +- `run_id::Union{Nothing, Int}`: The unique ID of the run. +- `sample_id::Union{Nothing, Int}`: The unique ID of the sample (if multiple samples are generated, they will all have the same `run_id`). +""" Base.@kwdef struct AIMessage{T <: Union{AbstractString, Nothing}} <: AbstractChatMessage content::T = nothing status::Union{Int, Nothing} = nothing tokens::Tuple{Int, Int} = (-1, -1) elapsed::Float64 = -1.0 + cost::Union{Nothing, Float64} = nothing + log_prob::Union{Nothing, Float64} = nothing + finish_reason::Union{Nothing, String} = nothing + run_id::Union{Nothing, Int} = Int(rand(Int16)) + sample_id::Union{Nothing, Int} = nothing _type::Symbol = :aimessage end + +""" + DataMessage + +A message type for AI-generated data-based responses, ie, different `content` than text. +Returned by `aiextract`, and `aiextract` functions. + +# Fields +- `content::Union{AbstractString, Nothing}`: The content of the message. +- `status::Union{Int, Nothing}`: The status of the message from the API. +- `tokens::Tuple{Int, Int}`: The number of tokens used (prompt,completion). +- `elapsed::Float64`: The time taken to generate the response in seconds. 
+- `cost::Union{Nothing, Float64}`: The cost of the API call (calculated with information from `MODEL_REGISTRY`). +- `log_prob::Union{Nothing, Float64}`: The log probability of the response. +- `finish_reason::Union{Nothing, String}`: The reason the response was finished. +- `run_id::Union{Nothing, Int}`: The unique ID of the run. +- `sample_id::Union{Nothing, Int}`: The unique ID of the sample (if multiple samples are generated, they will all have the same `run_id`). +""" Base.@kwdef struct DataMessage{T <: Any} <: AbstractDataMessage content::T status::Union{Int, Nothing} = nothing tokens::Tuple{Int, Int} = (-1, -1) elapsed::Float64 = -1.0 + cost::Union{Nothing, Float64} = nothing + log_prob::Union{Nothing, Float64} = nothing + finish_reason::Union{Nothing, String} = nothing + run_id::Union{Nothing, Int} = Int(rand(Int16)) + sample_id::Union{Nothing, Int} = nothing _type::Symbol = :datamessage end @@ -83,11 +129,13 @@ end isusermessage(m::AbstractMessage) = m isa UserMessage issystemmessage(m::AbstractMessage) = m isa SystemMessage isdatamessage(m::AbstractMessage) = m isa DataMessage +isaimessage(m::AbstractMessage) = m isa AIMessage # equality check for testing, only equal if all fields are equal and type is the same Base.var"=="(m1::AbstractMessage, m2::AbstractMessage) = false function Base.var"=="(m1::T, m2::T) where {T <: AbstractMessage} - all([getproperty(m1, f) == getproperty(m2, f) for f in fieldnames(T)]) + ## except for run_id, that's random and not important for content comparison + all([getproperty(m1, f) == getproperty(m2, f) for f in fieldnames(T) if f != :run_id]) end ## Vision Models -- Constructor and Conversion diff --git a/src/precompilation.jl b/src/precompilation.jl index a1d823cc6..2a7fc9703 100644 --- a/src/precompilation.jl +++ b/src/precompilation.jl @@ -8,7 +8,10 @@ load_templates!(); # API Calls prep mock_response = Dict(:choices => [ Dict(:message => Dict(:content => "Hello!", - :function_call => Dict(:arguments => JSON3.write(Dict(:x => 1))))), + :tool_calls => [ + Dict(:function => Dict(:arguments => JSON3.write(Dict(:x => 1)))), + ]), + :finish_reason => "stop"), ], :usage => Dict(:total_tokens => 3, :prompt_tokens => 2, :completion_tokens => 1)) schema = TestEchoOpenAISchema(; response = mock_response, status = 200) diff --git a/src/user_preferences.jl b/src/user_preferences.jl index 733dc0495..c016b6392 100644 --- a/src/user_preferences.jl +++ b/src/user_preferences.jl @@ -389,7 +389,10 @@ registry = Dict{String, ModelSpec}("gpt-3.5-turbo" => ModelSpec("gpt-3.5-turbo", "Mistral AI's hosted model for embeddings."), "echo" => ModelSpec("echo", TestEchoOpenAISchema(; - response = Dict(:choices => [Dict(:message => Dict(:content => "Hello!"))], + response = Dict(:choices => [ + Dict(:message => Dict(:content => "Hello!"), + :finish_reason => "stop"), + ], :usage => Dict(:total_tokens => 3, :prompt_tokens => 2, :completion_tokens => 1)), status = 200), diff --git a/src/utils.jl b/src/utils.jl index 3170056c8..65cb94bd3 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -244,15 +244,20 @@ function _extract_handlebar_variables(vect::Vector{Dict{String, <:AbstractString end """ - call_cost(msg, model::String; - cost_of_token_prompt::Number = default_prompt_cost, - cost_of_token_generation::Number = default_generation_cost) -> Number + call_cost(prompt_tokens::Int, completion_tokens::Int, model::String; + cost_of_token_prompt::Number = get(MODEL_REGISTRY, + model, + (; cost_of_token_prompt = 0.0)).cost_of_token_prompt, + cost_of_token_generation::Number = 
get(MODEL_REGISTRY, model, + (; cost_of_token_generation = 0.0)).cost_of_token_generation) + + call_cost(msg, model::String) Calculate the cost of a call based on the number of tokens in the message and the cost per token. # Arguments -- `msg`: The message object, which should contain a `tokens` field - with two elements: [number_of_prompt_tokens, number_of_generation_tokens]. +- `prompt_tokens::Int`: The number of tokens used in the prompt. +- `completion_tokens::Int`: The number of tokens used in the completion. - `model::String`: The name of the model to use for determining token costs. If the model is not found in `MODEL_REGISTRY`, default costs are used. - `cost_of_token_prompt::Number`: The cost per prompt token. Defaults to the cost in `MODEL_REGISTRY` @@ -271,30 +276,49 @@ MODEL_REGISTRY = Dict( "model2" => (cost_of_token_prompt = 0.07, cost_of_token_generation = 0.02) ) -msg1 = AIMessage([10, 20]) # 10 prompt tokens, 20 generation tokens +cost1 = call_cost(10, 20, "model1") + +# from message +msg1 = AIMessage(;tokens=[10, 20]) # 10 prompt tokens, 20 generation tokens cost1 = call_cost(msg1, "model1") # cost1 = 10 * 0.05 + 20 * 0.10 = 2.5 -msg2 = DataMessage([15, 30]) # 15 prompt tokens, 30 generation tokens -cost2 = call_cost(msg2, "model2") -# cost2 = 15 * 0.07 + 30 * 0.02 = 1.35 - # Using custom token costs -msg3 = AIMessage([5, 10]) -cost3 = call_cost(msg3, "model3", cost_of_token_prompt = 0.08, cost_of_token_generation = 0.12) -# cost3 = 5 * 0.08 + 10 * 0.12 = 1.6 +cost2 = call_cost(10, 20, "model3"; cost_of_token_prompt = 0.08, cost_of_token_generation = 0.12) +# cost2 = 10 * 0.08 + 20 * 0.12 = 3.2 ``` """ -function call_cost(msg, model::String; +function call_cost(prompt_tokens::Int, completion_tokens::Int, model::String; cost_of_token_prompt::Number = get(MODEL_REGISTRY, model, (; cost_of_token_prompt = 0.0)).cost_of_token_prompt, cost_of_token_generation::Number = get(MODEL_REGISTRY, model, (; cost_of_token_generation = 0.0)).cost_of_token_generation) - cost = msg.tokens[1] * cost_of_token_prompt + - msg.tokens[2] * cost_of_token_generation + cost = prompt_tokens * cost_of_token_prompt + + completion_tokens * cost_of_token_generation + return cost +end +function call_cost(msg, model::String) + cost = if !isnothing(msg.cost) + msg.cost + else + call_cost(msg.tokens[1], msg.tokens[2], model) + end return cost end +## dispatch for array -> take unique messages only (eg, for multiple samples we count only once) +function call_cost(conv::AbstractVector, model::String) + sum_ = 0.0 + visited_runs = Set{Int}() + for msg in conv + if isnothing(msg.run_id) || (msg.run_id ∉ visited_runs) + sum_ += call_cost(msg, model) + push!(visited_runs, msg.run_id) + end + end + return sum_ +end + # helper to produce summary message of how many tokens were used and for how much function _report_stats(msg, model::String) @@ -302,6 +326,11 @@ function _report_stats(msg, cost_str = iszero(cost) ? "" : " @ Cost: \$$(round(cost; digits=4))" return "Tokens: $(sum(msg.tokens))$(cost_str) in $(round(msg.elapsed;digits=1)) seconds" end +## dispatch for array -> take last message +function _report_stats(msg::AbstractVector, + model::String) + _report_stats(last(msg), model) +end # Loads and encodes the provided image path as a base64 string function _encode_local_image(image_path::AbstractString; base64_only::Bool = false) @assert isfile(image_path) "`image_path` must be a valid path to an image file. File: $image_path not found." 
diff --git a/test/Experimental/RAGTools/evaluation.jl b/test/Experimental/RAGTools/evaluation.jl index 1e5e79f03..9c48eba7b 100644 --- a/test/Experimental/RAGTools/evaluation.jl +++ b/test/Experimental/RAGTools/evaluation.jl @@ -87,7 +87,9 @@ end if content[:model] == "mock-gen" user_msg = last(content[:messages]) - response = Dict(:choices => [Dict(:message => user_msg)], + response = Dict(:choices => [ + Dict(:message => user_msg, :finish_reason => "stop"), + ], :model => content[:model], :usage => Dict(:total_tokens => length(user_msg[:content]), :prompt_tokens => length(user_msg[:content]), @@ -101,10 +103,11 @@ end elseif content[:model] == "mock-meta" user_msg = last(content[:messages]) response = Dict(:choices => [ - Dict(:message => Dict(:function_call => Dict(:arguments => JSON3.write(MaybeMetadataItems([ - MetadataItem("yes", "category"), - ]))))), - ], + Dict(:finish_reason => "stop", + :message => Dict(:tool_calls => [ + Dict(:function => Dict(:arguments => JSON3.write(MaybeMetadataItems([ + MetadataItem("yes", "category"), + ]))))]))], :model => content[:model], :usage => Dict(:total_tokens => length(user_msg[:content]), :prompt_tokens => length(user_msg[:content]), @@ -112,9 +115,10 @@ end elseif content[:model] == "mock-qa" user_msg = last(content[:messages]) response = Dict(:choices => [ - Dict(:message => Dict(:function_call => Dict(:arguments => JSON3.write(QAItem("Question", - "Answer"))))), - ], + Dict(:finish_reason => "stop", + :message => Dict(:tool_calls => [ + Dict(:function => Dict(:arguments => JSON3.write(QAItem("Question", + "Answer"))))]))], :model => content[:model], :usage => Dict(:total_tokens => length(user_msg[:content]), :prompt_tokens => length(user_msg[:content]), @@ -122,14 +126,14 @@ end elseif content[:model] == "mock-judge" user_msg = last(content[:messages]) response = Dict(:choices => [ - Dict(:message => Dict(:function_call => Dict(:arguments => JSON3.write(JudgeAllScores(5, - 5, - 5, - 5, - 5, - "Some reasons", - 5.0))))), - ], + Dict(:message => Dict(:tool_calls => [ + Dict(:function => Dict(:arguments => JSON3.write(JudgeAllScores(5, + 5, + 5, + 5, + 5, + "Some reasons", + 5.0))))]))], :model => content[:model], :usage => Dict(:total_tokens => length(user_msg[:content]), :prompt_tokens => length(user_msg[:content]), diff --git a/test/Experimental/RAGTools/generation.jl b/test/Experimental/RAGTools/generation.jl index 6fbb95057..b7fecd739 100644 --- a/test/Experimental/RAGTools/generation.jl +++ b/test/Experimental/RAGTools/generation.jl @@ -1,4 +1,6 @@ -using PromptingTools.Experimental.RAGTools: MaybeMetadataItems, MetadataItem, build_context +using PromptingTools.Experimental.RAGTools: ChunkIndex, + CandidateChunks, build_context, airag +using PromptingTools.Experimental.RAGTools: MaybeMetadataItems, MetadataItem @testset "build_context" begin index = ChunkIndex(; @@ -29,7 +31,7 @@ end @testset "airag" begin # test with a mock server - PORT = rand(1000:2000) + PORT = rand(20000:30000) PT.register_model!(; name = "mock-emb", schema = PT.CustomOpenAISchema()) PT.register_model!(; name = "mock-meta", schema = PT.CustomOpenAISchema()) PT.register_model!(; name = "mock-gen", schema = PT.CustomOpenAISchema()) @@ -39,7 +41,9 @@ end if content[:model] == "mock-gen" user_msg = last(content[:messages]) - response = Dict(:choices => [Dict(:message => user_msg)], + response = Dict(:choices => [ + Dict(:message => user_msg, :finish_reason => "stop"), + ], :model => content[:model], :usage => Dict(:total_tokens => length(user_msg[:content]), 
:prompt_tokens => length(user_msg[:content]), @@ -53,10 +57,11 @@ end elseif content[:model] == "mock-meta" user_msg = last(content[:messages]) response = Dict(:choices => [ - Dict(:message => Dict(:function_call => Dict(:arguments => JSON3.write(MaybeMetadataItems([ - MetadataItem("yes", "category"), - ]))))), - ], + Dict(:finish_reason => "stop", + :message => Dict(:tool_calls => [ + Dict(:function => Dict(:arguments => JSON3.write(MaybeMetadataItems([ + MetadataItem("yes", "category"), + ]))))]))], :model => content[:model], :usage => Dict(:total_tokens => length(user_msg[:content]), :prompt_tokens => length(user_msg[:content]), diff --git a/test/Experimental/RAGTools/preparation.jl b/test/Experimental/RAGTools/preparation.jl index f781c2e05..dc14b087e 100644 --- a/test/Experimental/RAGTools/preparation.jl +++ b/test/Experimental/RAGTools/preparation.jl @@ -82,7 +82,9 @@ end if content[:model] == "mock-gen" user_msg = last(content[:messages]) - response = Dict(:choices => [Dict(:message => user_msg)], + response = Dict(:choices => [ + Dict(:message => user_msg, :finish_reason => "stop"), + ], :model => content[:model], :usage => Dict(:total_tokens => length(user_msg[:content]), :prompt_tokens => length(user_msg[:content]), @@ -96,10 +98,11 @@ end elseif content[:model] == "mock-meta" user_msg = last(content[:messages]) response = Dict(:choices => [ - Dict(:message => Dict(:function_call => Dict(:arguments => JSON3.write(MaybeMetadataItems([ - MetadataItem("yes", "category"), - ]))))), - ], + Dict(:finish_reason => "stop", + :message => Dict(:tool_calls => [ + Dict(:function => Dict(:arguments => JSON3.write(MaybeMetadataItems([ + MetadataItem("yes", "category"), + ]))))]))], :model => content[:model], :usage => Dict(:total_tokens => length(user_msg[:content]), :prompt_tokens => length(user_msg[:content]), diff --git a/test/llm_interface.jl b/test/llm_interface.jl index 8e010bc11..d54cad0f4 100644 --- a/test/llm_interface.jl +++ b/test/llm_interface.jl @@ -1,12 +1,15 @@ using PromptingTools: TestEchoOpenAISchema, render, OpenAISchema using PromptingTools: AIMessage, SystemMessage, AbstractMessage using PromptingTools: UserMessage, UserMessageWithImages, DataMessage +using PromptingTools: response_to_message, AbstractPromptSchema @testset "ai* default schema" begin OLD_PROMPT_SCHEMA = PromptingTools.PROMPT_SCHEMA ### AIGenerate # corresponds to OpenAI API v1 - response = Dict(:choices => [Dict(:message => Dict(:content => "Hello!"))], + response = Dict(:choices => [ + Dict(:message => Dict(:content => "Hello!"), :finish_reason => "stop"), + ], :usage => Dict(:total_tokens => 3, :prompt_tokens => 2, :completion_tokens => 1)) schema = TestEchoOpenAISchema(; response, status = 200) @@ -16,6 +19,9 @@ using PromptingTools: UserMessage, UserMessageWithImages, DataMessage content = "Hello!" 
|> strip, status = 200, tokens = (2, 1), + run_id = msg.run_id, + finish_reason = "stop", + cost = 0.0, elapsed = msg.elapsed) @test msg == expected_output @@ -25,13 +31,18 @@ using PromptingTools: UserMessage, UserMessageWithImages, DataMessage content = nothing, status = 200, tokens = (2, 1), + run_id = msg.run_id, + cost = 0.0, + finish_reason = "stop", elapsed = msg.elapsed) @test msg == expected_output ### AIExtract response1 = Dict(:choices => [ - Dict(:message => Dict(:function_call => Dict(:arguments => "{\"content\": \"x\"}"))), - ], + Dict(:message => Dict(:tool_calls => [ + Dict(:function => Dict(:arguments => "{\"content\": \"x\"}")), + ]), + :finish_reason => "stop")], :usage => Dict(:total_tokens => 3, :prompt_tokens => 2, :completion_tokens => 1)) schema = TestEchoOpenAISchema(; response = response1, status = 200) @@ -44,6 +55,9 @@ using PromptingTools: UserMessage, UserMessageWithImages, DataMessage content = MyType("x"), status = 200, tokens = (2, 1), + run_id = msg.run_id, + cost = 0.0, + finish_reason = "stop", elapsed = msg.elapsed) @test msg == expected_output @@ -59,9 +73,18 @@ using PromptingTools: UserMessage, UserMessageWithImages, DataMessage content = ones(128), status = 200, tokens = (2, 0), + run_id = msg.run_id, + cost = 0.0, elapsed = msg.elapsed) @test msg == expected_output ## Return things to previous PromptingTools.PROMPT_SCHEMA = OLD_PROMPT_SCHEMA + + ## Check response_to_message throws by default + struct Random123Schema <: AbstractPromptSchema end + @test_throws ArgumentError response_to_message(Random123Schema(), + AIMessage, + nothing, + nothing) end diff --git a/test/llm_openai.jl b/test/llm_openai.jl index ba59c4097..71beb0ef2 100644 --- a/test/llm_openai.jl +++ b/test/llm_openai.jl @@ -3,7 +3,7 @@ using PromptingTools: AIMessage, SystemMessage, AbstractMessage using PromptingTools: UserMessage, UserMessageWithImages, DataMessage using PromptingTools: CustomProvider, CustomOpenAISchema, MistralOpenAISchema, MODEL_EMBEDDING -using PromptingTools: encode_choices, decode_choices +using PromptingTools: encode_choices, decode_choices, response_to_message, call_cost @testset "render-OpenAI" begin schema = OpenAISchema() @@ -181,11 +181,18 @@ end @testset "OpenAI.create_chat" begin # Test CustomOpenAISchema() with a mock server - PORT = rand(1000:2000) + PORT = rand(10000:20000) echo_server = HTTP.serve!(PORT, verbose = -1) do req content = JSON3.read(req.body) user_msg = last(content[:messages]) - response = Dict(:choices => [Dict(:message => user_msg)], + response = Dict(:choices => [ + Dict(:message => user_msg, + :logprobs => Dict(:content => [ + Dict(:logprob => -0.1), + Dict(:logprob => -0.2), + ]), + :finish_reason => "stop"), + ], :model => content[:model], :usage => Dict(:total_tokens => length(user_msg[:content]), :prompt_tokens => length(user_msg[:content]), @@ -201,13 +208,18 @@ end return_all = false) @test msg.content == prompt @test msg.tokens == (length(prompt), 0) + @test msg.finish_reason == "stop" + ## single message, must be nothing + @test msg.sample_id |> isnothing + ## sum up log probs when provided + @test msg.log_prob ≈ -0.3 # clean up close(echo_server) end @testset "OpenAI.create_embeddings" begin # Test CustomOpenAISchema() with a mock server - PORT = rand(1000:2000) + PORT = rand(10000:20000) echo_server = HTTP.serve!(PORT, verbose = -1) do req content = JSON3.read(req.body) response = Dict(:data => [Dict(:embedding => ones(128))], @@ -230,9 +242,116 @@ end close(echo_server) end +@testset "response_to_message" begin + # Mock 
the response and choice data + mock_choice = Dict(:message => Dict(:content => "Hello!"), + :logprobs => Dict(:content => [Dict(:logprob => -0.5), Dict(:logprob => -0.4)]), + :finish_reason => "stop") + mock_response = (; + response = Dict(:choices => [mock_choice], + :usage => Dict(:total_tokens => 3, :prompt_tokens => 2, :completion_tokens => 1)), + status = 200) + + # Test with valid logprobs + msg = response_to_message(OpenAISchema(), + AIMessage, + mock_choice, + mock_response; + model_id = "gpt4t") + @test msg isa AIMessage + @test msg.content == "Hello!" + @test msg.tokens == (2, 1) + @test msg.log_prob ≈ -0.9 + @test msg.finish_reason == "stop" + @test msg.sample_id == nothing + @test msg.cost == call_cost(2, 1, "gpt4t") + + # Test without logprobs + choice = deepcopy(mock_choice) + delete!(choice, :logprobs) + msg = response_to_message(OpenAISchema(), AIMessage, choice, mock_response) + @test isnothing(msg.log_prob) + + # with sample_id and run_id + msg = response_to_message(OpenAISchema(), + AIMessage, + mock_choice, + mock_response; + run_id = 1, + sample_id = 2, + time = 2.0) + @test msg.run_id == 1 + @test msg.sample_id == 2 + @test msg.elapsed == 2.0 + + #### With DataMessage + # Mock the response and choice data + mock_choice = Dict(:message => Dict(:content => "Hello!", + :tool_calls => [ + Dict(:function => Dict(:arguments => JSON3.write(Dict(:x => 1)))), + ]), + :logprobs => Dict(:content => [Dict(:logprob => -0.5), Dict(:logprob => -0.4)]), + :finish_reason => "stop") + mock_response = (; + response = Dict(:choices => [mock_choice], + :usage => Dict(:total_tokens => 3, :prompt_tokens => 2, :completion_tokens => 1)), + status = 200) + struct RandomType1235 + x::Int + end + return_type = RandomType1235 + # Catch missing return_type + @test_throws AssertionError response_to_message(OpenAISchema(), + DataMessage, + mock_choice, + mock_response; + model_id = "gpt4t") + + # Test with valid logprobs + msg = response_to_message(OpenAISchema(), + DataMessage, + mock_choice, + mock_response; + return_type, + model_id = "gpt4t") + @test msg isa DataMessage + @test msg.content == RandomType1235(1) + @test msg.tokens == (2, 1) + @test msg.log_prob ≈ -0.9 + @test msg.finish_reason == "stop" + @test msg.sample_id == nothing + @test msg.cost == call_cost(2, 1, "gpt4t") + + # Test without logprobs + choice = deepcopy(mock_choice) + delete!(choice, :logprobs) + msg = response_to_message(OpenAISchema(), + DataMessage, + choice, + mock_response; + return_type) + @test isnothing(msg.log_prob) + + # with sample_id and run_id + msg = response_to_message(OpenAISchema(), + DataMessage, + mock_choice, + mock_response; + return_type, + run_id = 1, + sample_id = 2, + time = 2.0) + @test msg.run_id == 1 + @test msg.sample_id == 2 + @test msg.elapsed == 2.0 +end + @testset "aigenerate-OpenAI" begin # corresponds to OpenAI API v1 - response = Dict(:choices => [Dict(:message => Dict(:content => "Hello!"))], + response = Dict(:choices => [ + Dict(:message => Dict(:content => "Hello!"), + :finish_reason => "stop"), + ], :usage => Dict(:total_tokens => 3, :prompt_tokens => 2, :completion_tokens => 1)) # Test the monkey patch @@ -247,6 +366,8 @@ end content = "Hello!" |> strip, status = 200, tokens = (2, 1), + finish_reason = "stop", + cost = msg.cost, elapsed = msg.elapsed) @test msg == expected_output @test schema1.inputs == @@ -263,12 +384,30 @@ end content = "Hello!" 
|> strip, status = 200, tokens = (2, 1), + finish_reason = "stop", + cost = msg.cost, elapsed = msg.elapsed) @test msg == expected_output @test schema1.inputs == [Dict("role" => "system", "content" => "Act as a helpful AI assistant") Dict("role" => "user", "content" => "Hello World")] @test schema2.model_id == "gpt-4" + + ## Test multiple samples + response = Dict(:choices => [ + Dict(:message => Dict(:content => "Hello1!"), + :finish_reason => "stop"), + Dict(:message => Dict(:content => "Hello2!"), + :finish_reason => "stop"), + ], + :usage => Dict(:total_tokens => 3, :prompt_tokens => 2, :completion_tokens => 1)) + schema3 = TestEchoOpenAISchema(; response, status = 200) + conv = aigenerate(schema3, UserMessage("Hello {{name}}"), + model = "gpt4", http_kwargs = (; verbose = 3), + api_kwargs = (; temperature = 0, n = 2), + name = "World") + @test conv[end - 1].content == "Hello1!" + @test conv[end].content == "Hello2!" end @testset "aiembed-OpenAI" begin @@ -283,6 +422,7 @@ end content = ones(128), status = 200, tokens = (2, 0), + cost = msg.cost, elapsed = msg.elapsed) @test msg == expected_output @test schema1.inputs == "Hello World" @@ -298,6 +438,7 @@ end content = ones(128, 2), status = 200, tokens = (4, 0), + cost = msg.cost, elapsed = msg.elapsed) @test msg == expected_output @test schema2.inputs == ["Hello World", "Hello back"] @@ -307,6 +448,7 @@ end expected_output = DataMessage(; content = ones(128, 2), status = 200, + cost = msg.cost, tokens = (4, 0), elapsed = msg.elapsed) @test msg == expected_output @@ -381,6 +523,17 @@ end decoded_conv = decode_choices(OpenAISchema(), ["true", "false"], conv) @test decoded_conv[end].content == "true" + # Decode with multiple samples + conv = [ + AIMessage("1"), # do not touch, different run + AIMessage(; content = "1", run_id = 1, sample_id = 1), + AIMessage(; content = "1", run_id = 1, sample_id = 2), + ] + decoded_conv = decode_choices(OpenAISchema(), ["true", "false"], conv) + @test decoded_conv[1].content == "1" + @test decoded_conv[2].content == "true" + @test decoded_conv[3].content == "true" + # Nothing (when dry_run=true) @test isnothing(decode_choices(OpenAISchema(), ["true", "false"], nothing)) @@ -392,7 +545,10 @@ end @testset "aiclassify-OpenAI" begin # corresponds to OpenAI API v1 - response = Dict(:choices => [Dict(:message => Dict(:content => "1"))], + response = Dict(:choices => [ + Dict(:message => Dict(:content => "1"), + :finish_reason => "stop"), + ], :usage => Dict(:total_tokens => 3, :prompt_tokens => 2, :completion_tokens => 1)) # Real generation API @@ -407,6 +563,8 @@ end content = "A", status = 200, tokens = (2, 1), + finish_reason = "stop", + cost = msg.cost, elapsed = msg.elapsed) @test msg == expected_output @test schema1.inputs == @@ -414,3 +572,94 @@ end "content" => "You are a world-class classification specialist. \n\nYour task is to select the most appropriate label from the given choices for the given user input.\n\n**Available Choices:**\n---\n1. \"A\" for any animal or creature\n2. \"P\" for for any plant or tree\n3. \"O\" for for everything else\n---\n\n**Instructions:**\n- You must respond in one word. \n- You must respond only with the label ID (e.g., \"1\", \"2\", ...) 
that best fits the input.\n"),
            Dict("role" => "user", "content" => "User Input: pelican\n\nLabel:\n")]
 end
+
+@testset "aiextract-OpenAI" begin
+    # mock return type
+    struct RandomType1235
+        x::Int
+    end
+    return_type = RandomType1235
+
+    mock_choice = Dict(:message => Dict(:content => "Hello!",
+            :tool_calls => [
+                Dict(:function => Dict(:arguments => JSON3.write(Dict(:x => 1)))),
+            ]),
+        :logprobs => Dict(:content => [Dict(:logprob => -0.5), Dict(:logprob => -0.4)]),
+        :finish_reason => "stop")
+    ## Test with a single sample
+    response = Dict(:choices => [mock_choice],
+        :usage => Dict(:total_tokens => 3, :prompt_tokens => 2, :completion_tokens => 1))
+    schema1 = TestEchoOpenAISchema(; response, status = 200)
+    msg = aiextract(schema1, "Extract number 1"; return_type,
+        model = "gpt4",
+        api_kwargs = (; temperature = 0, n = 2))
+    @test msg.content == RandomType1235(1)
+    @test msg.log_prob ≈ -0.9
+
+    ## Test multiple samples -- mock_choice is less probable
+    mock_choice2 = Dict(:message => Dict(:content => "Hello!",
+            :tool_calls => [
+                Dict(:function => Dict(:arguments => JSON3.write(Dict(:x => 1)))),
+            ]),
+        :logprobs => Dict(:content => [Dict(:logprob => -1.2), Dict(:logprob => -0.4)]),
+        :finish_reason => "stop")
+
+    response = Dict(:choices => [mock_choice, mock_choice2],
+        :usage => Dict(:total_tokens => 3, :prompt_tokens => 2, :completion_tokens => 1))
+    schema2 = TestEchoOpenAISchema(; response, status = 200)
+    conv = aiextract(schema2, "Extract number 1"; return_type,
+        model = "gpt4",
+        api_kwargs = (; temperature = 0, n = 2))
+    @test conv[1].content == RandomType1235(1)
+    @test conv[1].log_prob ≈ -1.6 # sorted first, despite sent later
+    @test conv[2].content == RandomType1235(1)
+    @test conv[2].log_prob ≈ -0.9
+
+    ## Wrong return_type so it returns a Dict
+    struct RandomType1236
+        x::Int
+        y::Int
+    end
+    return_type = RandomType1236
+    conv = aiextract(schema2, "Extract number 1"; return_type,
+        model = "gpt4",
+        api_kwargs = (; temperature = 0, n = 2))
+    @test conv[1].content isa AbstractDict
+    @test conv[2].content isa AbstractDict
+end
+
+@testset "aiscan-OpenAI" begin
+    ## Test with a single sample and logprobs
+    response = Dict(:choices => [
+            Dict(:message => Dict(:content => "Hello1!"),
+                :finish_reason => "stop",
+                :logprobs => Dict(:content => [
+                    Dict(:logprob => -0.1),
+                    Dict(:logprob => -0.2),
+                ])),
+        ],
+        :usage => Dict(:total_tokens => 3, :prompt_tokens => 2, :completion_tokens => 1))
+    schema1 = TestEchoOpenAISchema(; response, status = 200)
+    msg = aiscan(schema1, "Describe the image";
+        image_url = "https://example.com/image.png",
+        model = "gpt4", http_kwargs = (; verbose = 3),
+        api_kwargs = (; temperature = 0))
+    @test msg.content == "Hello1!"
+    @test msg.log_prob ≈ -0.3
+
+    ## Test multiple samples
+    response = Dict(:choices => [
+            Dict(:message => Dict(:content => "Hello1!"),
+                :finish_reason => "stop"),
+            Dict(:message => Dict(:content => "Hello2!"),
+                :finish_reason => "stop"),
+        ],
+        :usage => Dict(:total_tokens => 3, :prompt_tokens => 2, :completion_tokens => 1))
+    schema1 = TestEchoOpenAISchema(; response, status = 200)
+    conv = aiscan(schema1, "Describe the image";
+        image_url = "https://example.com/image.png",
+        model = "gpt4", http_kwargs = (; verbose = 3),
+        api_kwargs = (; temperature = 0, n = 2))
+    @test conv[end - 1].content == "Hello1!"
+    @test conv[end].content == "Hello2!"
+end
\ No newline at end of file
diff --git a/test/llm_shared.jl b/test/llm_shared.jl
index f449d9e5a..9c9c2b3a9 100644
--- a/test/llm_shared.jl
+++ b/test/llm_shared.jl
@@ -267,4 +267,32 @@ end
         conversation,
         return_all = true)
     @test output == expected_output
+
+    ## With multiple samples
+    conversation = [
+        SystemMessage("System message 1"),
+        UserMessage("User message {{name}}"),
+        AIMessage("AI message"),
+    ]
+    messages = [
+        UserMessage("User message {{name}}"),
+        AIMessage("AI message 2"),
+    ]
+    msg = AIMessage("AI message 3")
+    expected_output = [
+        SystemMessage("System message 1"),
+        UserMessage("User message {{name}}"),
+        AIMessage("AI message"),
+        UserMessage("User message John", [:name], :usermessage),
+        AIMessage("AI message 2"),
+        msg,
+        msg,
+    ]
+    output = finalize_outputs(messages,
+        [],
+        [msg, msg];
+        name = "John",
+        conversation,
+        return_all = true)
+    @test output == expected_output
 end
diff --git a/test/messages.jl b/test/messages.jl
index 95cc5e15a..e90dcf467 100644
--- a/test/messages.jl
+++ b/test/messages.jl
@@ -1,7 +1,7 @@
 using PromptingTools: AIMessage, SystemMessage, MetadataMessage
 using PromptingTools: UserMessage, UserMessageWithImages, DataMessage
 using PromptingTools: _encode_local_image, attach_images_to_user_message
-using PromptingTools: isusermessage, issystemmessage, isdatamessage
+using PromptingTools: isusermessage, issystemmessage, isdatamessage, isaimessage
 
 @testset "Message constructors" begin
     # Creates an instance of MSG with the given content string.
@@ -29,8 +29,8 @@ using PromptingTools: isusermessage, issystemmessage, isdatamessage
     @test UserMessage(content) |> isusermessage
     @test SystemMessage(content) |> issystemmessage
     @test DataMessage(; content) |> isdatamessage
+    @test AIMessage(; content) |> isaimessage
 end
-
 @testset "UserMessageWithImages" begin
     content = "Hello, world!"
     image_path = joinpath(@__DIR__, "data", "julia.png")
diff --git a/test/utils.jl b/test/utils.jl
index a55310b84..5d6961cee 100644
--- a/test/utils.jl
+++ b/test/utils.jl
@@ -129,20 +129,23 @@ end
 end
 
 @testset "call_cost" begin
+    @test call_cost(1000, 100, "unknown_model";
+        cost_of_token_prompt = 1,
+        cost_of_token_generation = 1) ≈ 1100
     msg = AIMessage(; content = "", tokens = (1000, 2000))
     cost = call_cost(msg, "unknown_model")
     @test cost == 0.0
     @test call_cost(msg, "gpt-3.5-turbo") ≈ 1000 * 0.5e-6 + 1.5e-6 * 2000
 
+    # Test vector - the same message is counted only once
+    @test call_cost([msg, msg], "gpt-3.5-turbo") ≈ (1000 * 0.5e-6 + 1.5e-6 * 2000)
+    msg2 = AIMessage(; content = "", tokens = (1000, 2000))
+    @test call_cost([msg, msg2], "gpt-3.5-turbo") ≈ (1000 * 0.5e-6 + 1.5e-6 * 2000) * 2
+
     msg = DataMessage(; content = nothing, tokens = (1000, 1000))
     cost = call_cost(msg, "unknown_model")
     @test cost == 0.0
     @test call_cost(msg, "gpt-3.5-turbo") ≈ 1000 * 0.5e-6 + 1.5e-6 * 1000
-
-    @test call_cost(msg,
-        "gpt-3.5-turbo";
-        cost_of_token_prompt = 1,
-        cost_of_token_generation = 1) ≈ 1000 + 1000
 end
 
 @testset "report_stats" begin