diff --git a/Project.toml b/Project.toml index 0e450fa1b..d2c24e8ef 100644 --- a/Project.toml +++ b/Project.toml @@ -7,6 +7,7 @@ version = "0.65.0" AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" +FlashRank = "22cc3f58-1757-4700-bb45-2032706e5a8d" HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3" JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" @@ -16,7 +17,9 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Preferences = "21216c6a-2e73-6563-6e65-726566657250" REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Snowball = "fb8f903a-0164-4e73-9ffe-431110250c3b" StreamCallbacks = "c1b9e933-98a0-46fc-8ea7-3b58b195fb0a" +StructTypes = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [weakdeps] @@ -53,10 +56,13 @@ PrecompileTools = "1" Preferences = "1" REPL = "<0.0.1, 1" Random = "<0.0.1, 1" +Snowball = "0.1" SparseArrays = "<0.0.1, 1" Statistics = "<0.0.1, 1" StreamCallbacks = "0.4, 0.5" +StructTypes = "1" Test = "<0.0.1, 1" +Unicode = "<0.0.1, 1" julia = "1.9, 1.10" [extras] diff --git a/src/PromptingTools.jl b/src/PromptingTools.jl index cc7cce720..d2e9dea37 100644 --- a/src/PromptingTools.jl +++ b/src/PromptingTools.jl @@ -66,9 +66,14 @@ include("llm_interface.jl") include("user_preferences.jl") ## Conversation history / Prompt elements -export AIMessage -# export UserMessage, UserMessageWithImages, SystemMessage, DataMessage # for debugging only include("messages.jl") +include("memory.jl") + +# Export message types and predicates +export SystemMessage, UserMessage, AIMessage, AnnotationMessage, issystemmessage, isusermessage, isaimessage, isabstractannotationmessage, annotate! 
+# Export memory-related functionality +export ConversationMemory, get_last, last_message, last_output +# export UserMessage, UserMessageWithImages, SystemMessage, DataMessage # for debugging only export aitemplates, AITemplate include("templates.jl") diff --git a/src/llm_anthropic.jl b/src/llm_anthropic.jl index 7b020b21b..87688f0b6 100644 --- a/src/llm_anthropic.jl +++ b/src/llm_anthropic.jl @@ -28,10 +28,13 @@ function render(schema::AbstractAnthropicSchema, no_system_message::Bool = false, cache::Union{Nothing, Symbol} = nothing, kwargs...) - ## + ## @assert count(issystemmessage, messages)<=1 "AbstractAnthropicSchema only supports at most 1 System message" @assert (isnothing(cache)||cache in [:system, :tools, :last, :all]) "Currently only `:system`, `:tools`, `:last`, `:all` are supported for Anthropic Prompt Caching" + # Filter out annotation messages before any processing + messages = filter(!isabstractannotationmessage, messages) + system = nothing ## First pass: keep the message types but make the replacements provided in `kwargs` diff --git a/src/llm_google.jl b/src/llm_google.jl index 64e3568ea..d0824618d 100644 --- a/src/llm_google.jl +++ b/src/llm_google.jl @@ -25,6 +25,9 @@ function render(schema::AbstractGoogleSchema, no_system_message::Bool = false, kwargs...) ## + # Filter out annotation messages before any processing + messages = filter(!isabstractannotationmessage, messages) + ## First pass: keep the message types but make the replacements provided in `kwargs` messages_replaced = render( NoSchema(), messages; conversation, no_system_message, kwargs...) diff --git a/src/llm_interface.jl b/src/llm_interface.jl index 245c20c8e..3f2dddb6a 100644 --- a/src/llm_interface.jl +++ b/src/llm_interface.jl @@ -41,14 +41,21 @@ struct OpenAISchema <: AbstractOpenAISchema end "Echoes the user's input back to them. 
Used for testing the implementation" @kwdef mutable struct TestEchoOpenAISchema <: AbstractOpenAISchema - response::AbstractDict - status::Integer + response::AbstractDict = Dict( + "choices" => [Dict("message" => Dict("content" => "Test response", "role" => "assistant"), "index" => 0, "finish_reason" => "stop")], + "usage" => Dict("prompt_tokens" => 10, "completion_tokens" => 20, "total_tokens" => 30), + "model" => "gpt-3.5-turbo", + "id" => "test-id", + "object" => "chat.completion", + "created" => 1234567890 + ) + status::Integer = 200 model_id::String = "" inputs::Any = nothing end """ - CustomOpenAISchema + CustomOpenAISchema CustomOpenAISchema() allows user to call any OpenAI-compatible API. diff --git a/src/llm_ollama.jl b/src/llm_ollama.jl index 07166944b..4cfa4ef12 100644 --- a/src/llm_ollama.jl +++ b/src/llm_ollama.jl @@ -27,6 +27,9 @@ function render(schema::AbstractOllamaSchema, no_system_message::Bool = false, kwargs...) ## + # Filter out annotation messages before any processing + messages = filter(!isabstractannotationmessage, messages) + ## First pass: keep the message types but make the replacements provided in `kwargs` messages_replaced = render( NoSchema(), messages; conversation, no_system_message, kwargs...) @@ -376,4 +379,4 @@ end function aitools(prompt_schema::AbstractOllamaSchema, prompt::ALLOWED_PROMPT_TYPE; kwargs...) error("Managed schema does not support aitools. Please use OpenAISchema instead.") -end \ No newline at end of file +end diff --git a/src/llm_openai.jl b/src/llm_openai.jl index 429d86325..0a0332c1d 100644 --- a/src/llm_openai.jl +++ b/src/llm_openai.jl @@ -33,6 +33,10 @@ function render(schema::AbstractOpenAISchema, kwargs...) 
## @assert image_detail in ["auto", "high", "low"] "Image detail must be one of: auto, high, low" + + # Filter out annotation messages before any processing + messages = filter(!isabstractannotationmessage, messages) + ## First pass: keep the message types but make the replacements provided in `kwargs` messages_replaced = render( NoSchema(), messages; conversation, no_system_message, kwargs...) @@ -1733,4 +1737,4 @@ function aitools(prompt_schema::AbstractOpenAISchema, prompt::ALLOWED_PROMPT_TYP kwargs...) return output -end \ No newline at end of file +end diff --git a/src/llm_shared.jl b/src/llm_shared.jl index fcb6e5702..889c5ce47 100644 --- a/src/llm_shared.jl +++ b/src/llm_shared.jl @@ -39,6 +39,9 @@ function render(schema::NoSchema, count_system_msg = count(issystemmessage, conversation) # TODO: concat multiple system messages together (2nd pass) + # Filter out annotation messages from input messages + messages = filter(!isabstractannotationmessage, messages) + # replace any handlebar variables in the messages for msg in messages if issystemmessage(msg) || isusermessage(msg) || isusermessagewithimages(msg) diff --git a/src/memory.jl b/src/memory.jl index 911a37db9..4d08ab025 100644 --- a/src/memory.jl +++ b/src/memory.jl @@ -86,8 +86,9 @@ function get_last(mem::ConversationMemory, n::Integer=20; # Always include system message and first user message system_idx = findfirst(issystemmessage, messages) first_user_idx = findfirst(isusermessage, messages) - result = AbstractMessage[] + # Initialize result with required messages + result = AbstractMessage[] if !isnothing(system_idx) push!(result, messages[system_idx]) end @@ -95,25 +96,33 @@ function get_last(mem::ConversationMemory, n::Integer=20; push!(result, messages[first_user_idx]) end - # Calculate remaining message budget - remaining_budget = n - length(result) + # Get remaining messages excluding system and first user + exclude_indices = filter(!isnothing, [system_idx, first_user_idx]) + remaining_msgs = 
messages[setdiff(1:length(messages), exclude_indices)] - if remaining_budget > 0 - if !isnothing(batch_size) - # Calculate how many complete batches we can include - total_msgs = length(messages) - num_batches = (total_msgs - length(result)) รท batch_size - - # We want to keep between batch_size+1 and 2*batch_size messages - # If we would exceed 2*batch_size, reset to batch_size+1 - if num_batches * batch_size > 2 * batch_size - num_batches = 1 # Reset to one batch (batch_size+1 messages) - end - - start_idx = max(1, total_msgs - (num_batches * batch_size) + 1) - append!(result, messages[start_idx:end]) + # Calculate how many messages to include based on batch size + if !isnothing(batch_size) + # When batch_size=10, should return between 11-20 messages + total_msgs = length(remaining_msgs) + num_batches = ceil(Int, total_msgs / batch_size) + + # Calculate target size (between batch_size+1 and 2*batch_size) + target_size = if num_batches * batch_size > n + batch_size + 1 # Reset to minimum (11 for batch_size=10) else - append!(result, messages[max(1, end-remaining_budget+1):end]) + min(num_batches * batch_size, n - length(result)) + end + + # Get messages to append + if !isempty(remaining_msgs) + start_idx = max(1, length(remaining_msgs) - target_size + 1) + append!(result, remaining_msgs[start_idx:end]) + end + else + # Without batch size, just get the last n-length(result) messages + if !isempty(remaining_msgs) + start_idx = max(1, length(remaining_msgs) - (n - length(result)) + 1) + append!(result, remaining_msgs[start_idx:end]) end end @@ -126,12 +135,13 @@ function get_last(mem::ConversationMemory, n::Integer=20; end end - # Add explanation if requested - if explain && length(messages) > n - ai_msg_idx = findfirst(isaimessage, result) + # Add explanation if requested and we truncated messages + if explain && length(messages) > length(result) + # Find first AI message in result after required messages + ai_msg_idx = findfirst(m -> isaimessage(m) && !(m in 
result[1:length(exclude_indices)]), result) if !isnothing(ai_msg_idx) orig_content = result[ai_msg_idx].content - explanation = "For efficiency reasons, we have truncated the preceding $(length(messages) - n) messages.\n\n$orig_content" + explanation = "For efficiency reasons, we have truncated the preceding $(length(messages) - length(result)) messages.\n\n$orig_content" result[ai_msg_idx] = AIMessage(explanation) end end @@ -146,16 +156,22 @@ Smart append that handles duplicate messages based on run IDs. Only appends messages that are newer than the latest matching message in memory. """ function Base.append!(mem::ConversationMemory, msgs::Vector{<:AbstractMessage}) - if isempty(mem.conversation) || isempty(msgs) + isempty(msgs) && return mem + + if isempty(mem.conversation) append!(mem.conversation, msgs) return mem end - # Find latest common message based on run_id - # Default to 0 if run_id is not defined - latest_run_id = maximum(msg -> isdefined(msg, :run_id) ? msg.run_id : 0, mem.conversation) + # Find latest run_id in memory + latest_run_id = 0 + for msg in mem.conversation + if isdefined(msg, :run_id) + latest_run_id = max(latest_run_id, msg.run_id) + end + end - # Only append messages with higher run_id or no run_id + # Keep messages that either don't have a run_id or have a higher run_id new_msgs = filter(msgs) do msg !isdefined(msg, :run_id) || msg.run_id > latest_run_id end @@ -187,8 +203,12 @@ function (mem::ConversationMemory)(prompt::String; last::Union{Nothing,Integer}= get_last(mem, last) end + # Add user message to memory first + user_msg = UserMessage(prompt) + push!(mem.conversation, user_msg) + # Generate response with context - response = PromptingTools.aigenerate(context, prompt; kwargs...) + response = aigenerate(context, prompt; kwargs...) push!(mem.conversation, response) return response end @@ -198,6 +218,7 @@ end Generate a response using the conversation memory context. 
""" -function PromptingTools.aigenerate(mem::ConversationMemory, prompt::String; kwargs...) - PromptingTools.aigenerate(mem.conversation, prompt; kwargs...) +function PromptingTools.aigenerate(messages::Vector{<:AbstractMessage}, prompt::String; kwargs...) + schema = get(kwargs, :schema, OpenAISchema()) + aigenerate(schema, [messages..., UserMessage(prompt)]; kwargs...) end diff --git a/src/messages.jl b/src/messages.jl index 5d709a833..7531b2b89 100644 --- a/src/messages.jl +++ b/src/messages.jl @@ -4,7 +4,7 @@ abstract type AbstractMessage end abstract type AbstractChatMessage <: AbstractMessage end # with text-based content abstract type AbstractDataMessage <: AbstractMessage end # with data-based content, eg, embeddings -abstract type AbstractAnnotationMessage <: AbstractMessage end # messages that provide extra information without being sent to LLMs +abstract type AbstractAnnotationMessage <: AbstractChatMessage end # messages that provide extra information without being sent to LLMs abstract type AbstractTracerMessage{T <: AbstractMessage} <: AbstractMessage end # message with annotation that exposes the underlying message # Complementary type for tracing, follows the same API as TracerMessage abstract type AbstractTracer{T <: Any} end @@ -12,59 +12,12 @@ abstract type AbstractTracer{T <: Any} end # Helper functions for message type checking isabstractannotationmessage(msg::AbstractMessage) = msg isa AbstractAnnotationMessage -""" - annotate!(messages::Vector{<:AbstractMessage}, content::AbstractString; kwargs...) - annotate!(message::AbstractMessage, content::AbstractString; kwargs...) - -Add an annotation message to a vector of messages or wrap a single message in a vector. -The annotation message is created as the first section, or if there are other annotation -messages, it's slotted behind them. 
- -# Arguments -- `messages`: Vector of messages or single message to annotate -- `content`: Content for the annotation message -- `kwargs...`: Additional fields for AnnotationMessage (extras, tags, comment) - -# Returns -Vector{AbstractMessage} with the annotation message added -""" -function annotate!(messages::Vector{<:AbstractMessage}, content::AbstractString; kwargs...) - # Find last annotation message index - last_anno_idx = findlast(isabstractannotationmessage, messages) - insert_idx = isnothing(last_anno_idx) ? 1 : last_anno_idx + 1 - - # Create and insert annotation message - anno = AnnotationMessage(; content=content, kwargs...) - insert!(messages, insert_idx, anno) - return messages -end - -function annotate!(message::AbstractMessage, content::AbstractString; kwargs...) - return annotate!([message], content; kwargs...) -end - -""" - AnnotationMessage - -A message type for providing extra information and documentation without being sent to LLMs. -Used to bundle key information with the conversation data for future reference. - -# Fields -- `content::T`: The content of the message (can be used for inputs to airag etc.) 
-- `extras::Dict{Symbol,Any}`: Additional metadata as key-value pairs -- `tags::Vector{Symbol}`: Tags for categorizing the annotation -- `comment::String`: Human-readable comment (never used for automatic operations) -- `run_id::Union{Nothing,Int}`: The unique ID of the run -- `_type::Symbol`: Message type identifier -""" -Base.@kwdef struct AnnotationMessage{T <: AbstractString} <: AbstractAnnotationMessage - content::T - extras::Dict{Symbol,Any} = Dict{Symbol,Any}() - tags::Vector{Symbol} = Symbol[] - comment::String = "" - run_id::Union{Nothing,Int} = Int(rand(Int16)) - _type::Symbol = :annotationmessage -end +## Allowed inputs for ai* functions, AITemplate is resolved one level higher +const ALLOWED_PROMPT_TYPE = Union{ + AbstractString, + AbstractMessage, + Vector{<:AbstractMessage} +} ## Allowed inputs for ai* functions, AITemplate is resolved one level higher const ALLOWED_PROMPT_TYPE = Union{ @@ -80,6 +33,7 @@ Base.@kwdef struct MetadataMessage{T <: AbstractString} <: AbstractChatMessage description::String = "" version::String = "1" source::String = "" + run_id::Union{Nothing, Int} = Int(rand(Int16)) _type::Symbol = :metadatamessage end Base.@kwdef struct SystemMessage{T <: AbstractString} <: AbstractChatMessage @@ -87,15 +41,15 @@ Base.@kwdef struct SystemMessage{T <: AbstractString} <: AbstractChatMessage variables::Vector{Symbol} = _extract_handlebar_variables(content) run_id::Union{Nothing, Int} = Int(rand(Int16)) _type::Symbol = :systemmessage - SystemMessage{T}(c, v, r, t) where {T <: AbstractString} = new(c, v, r, t) end -function SystemMessage(content::T, - variables::Vector{Symbol}, - run_id::Union{Nothing, Int}, - type::Symbol) where {T <: AbstractString} + +# Add positional constructor +function SystemMessage(content::T; run_id::Union{Nothing,Int}=Int(rand(Int16)), + _type::Symbol=:systemmessage) where {T <: AbstractString} + variables = _extract_handlebar_variables(content) not_allowed_kwargs = intersect(variables, RESERVED_KWARGS) @assert 
length(not_allowed_kwargs)==0 "Error: Some placeholders are invalid, as they are reserved for `ai*` functions. Change: $(join(not_allowed_kwargs,","))" - return SystemMessage{T}(content, variables, run_id, type) + SystemMessage{T}(content, variables, run_id, _type) end """ @@ -251,6 +205,18 @@ Base.@kwdef struct AnnotationMessage{T} <: AbstractAnnotationMessage _type::Symbol = :annotationmessage end +# Add positional constructor for string content +function AnnotationMessage(content::AbstractString; + extras::Union{Dict{Symbol,Any}, Dict{Symbol,String}}=Dict{Symbol,Any}(), + tags::Vector{Symbol}=Symbol[], + comment::String="", + run_id::Union{Nothing,Int}=Int(rand(Int16)), + _type::Symbol=:annotationmessage) + # Convert Dict{Symbol,String} to Dict{Symbol,Any} if needed + extras_any = extras isa Dict{Symbol,String} ? Dict{Symbol,Any}(k => v for (k,v) in extras) : extras + AnnotationMessage{typeof(content)}(content, extras_any, tags, comment, run_id, _type) +end + """ annotate!(messages::Vector{<:AbstractMessage}, content; kwargs...) annotate!(message::AbstractMessage, content; kwargs...) @@ -272,22 +238,25 @@ messages = [SystemMessage("Assistant"), UserMessage("Hello")] annotate!(messages, "This is important"; tags=[:important], comment="For review") ``` """ -function annotate!(messages::Vector{<:AbstractMessage}, content; kwargs...) +function annotate!(messages::Vector{T}, content; kwargs...) where {T<:AbstractMessage} # Create new annotation message annotation = AnnotationMessage(content; kwargs...) 
+ # Convert messages to Vector{AbstractMessage} to allow mixed types + abstract_messages = Vector{AbstractMessage}(messages) + # Find the last annotation message index - last_annotation_idx = findlast(isabstractannotationmessage, messages) + last_annotation_idx = findlast(isabstractannotationmessage, abstract_messages) if isnothing(last_annotation_idx) # If no annotation exists, insert at beginning - insert!(messages, 1, annotation) + insert!(abstract_messages, 1, annotation) else # Insert after the last annotation - insert!(messages, last_annotation_idx + 1, annotation) + insert!(abstract_messages, last_annotation_idx + 1, annotation) end - return messages + return abstract_messages end # Single message version - wrap in vector and use the other method @@ -710,7 +679,8 @@ function StructTypes.subtypes(::Type{AbstractChatMessage}) usermessagewithimages = UserMessageWithImages, aimessage = AIMessage, systemmessage = SystemMessage, - metadatamessage = MetadataMessage) + metadatamessage = MetadataMessage, + annotationmessage = AnnotationMessage) end StructTypes.StructType(::Type{AbstractAnnotationMessage}) = StructTypes.AbstractType() @@ -766,6 +736,30 @@ StructTypes.StructType(::Type{TracerMessage}) = StructTypes.Struct() # Ignore mu StructTypes.StructType(::Type{AnnotationMessage}) = StructTypes.Struct() StructTypes.StructType(::Type{TracerMessageLike}) = StructTypes.Struct() # Ignore mutability once we serialize +### Message Access Utilities + +""" + last_message(messages::Vector{<:AbstractMessage}) + +Get the last message in a conversation, regardless of type. +""" +function last_message(messages::Vector{<:AbstractMessage}) + isempty(messages) && return nothing + return last(messages) +end + +""" + last_output(messages::Vector{<:AbstractMessage}) + +Get the last AI-generated message (AIMessage) in a conversation. 
+""" +function last_output(messages::Vector{<:AbstractMessage}) + isempty(messages) && return nothing + last_ai_idx = findlast(isaimessage, messages) + isnothing(last_ai_idx) && return nothing + return messages[last_ai_idx] +end + ### Utilities for Pretty Printing """ pprint(io::IO, msg::AbstractMessage; text_width::Int = displaysize(io)[2]) diff --git a/src/precompilation.jl b/src/precompilation.jl index eb91c6afc..4b9af809e 100644 --- a/src/precompilation.jl +++ b/src/precompilation.jl @@ -1,6 +1,35 @@ +# Basic Message Types precompilation - moved to top +sys_msg = SystemMessage("You are a helpful assistant") +user_msg = UserMessage("Hello!") +ai_msg = AIMessage(content="Test response") + +# Annotation Message precompilation - after basic types +annotation_msg = AnnotationMessage("Test metadata"; + extras=Dict{Symbol,Any}(:key => "value"), + tags=Symbol[:test], + comment="Test comment") +_ = isabstractannotationmessage(annotation_msg) + +# ConversationMemory precompilation +memory = ConversationMemory() +push!(memory, sys_msg) +push!(memory, user_msg) +_ = get_last(memory, 2) +_ = length(memory) +_ = last_message(memory) + +# Test message rendering with all types - moved before API calls +messages = [ + sys_msg, + annotation_msg, + user_msg, + ai_msg +] +_ = render(OpenAISchema(), messages) + # Load templates load_template(joinpath(@__DIR__, "..", "templates", "general", "BlankSystemUser.json")) -load_templates!(); +load_templates!() # Preferences @load_preference("MODEL_CHAT", default="x") @@ -50,4 +79,4 @@ msg = aiscan(schema, ## Streaming configuration cb = StreamCallback() -configure_callback!(cb, OpenAISchema()) \ No newline at end of file +configure_callback!(cb, OpenAISchema()) diff --git a/test/LocalPreferences.toml b/test/LocalPreferences.toml new file mode 100644 index 000000000..b6e43f541 --- /dev/null +++ b/test/LocalPreferences.toml @@ -0,0 +1,4 @@ +[PromptingTools] +MODEL_CHAT = "gpt-4o-mini" +MODEL_EMBEDDING = "text-embedding-3-small" 
+OPENAI_API_KEY = "<redacted-rotate-this-key>" diff --git a/test/Manifest.toml b/test/Manifest.toml new file mode 100644 index 000000000..cb89c21a8 --- /dev/null +++ b/test/Manifest.toml @@ -0,0 +1,449 @@ +# This file is machine-generated - editing it directly is not advised + +julia_version = "1.10.6" +manifest_format = "2.0" +project_hash = "7f329ba0ad13fa85b7e1c60360c4a28941fe20c4" + +[[deps.AbstractTrees]] +git-tree-sha1 = "2d9c9a55f9c93e8887ad391fbae72f8ef55e1177" +uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" +version = "0.4.5" + +[[deps.Aqua]] +deps = ["Compat", "Pkg", "Test"] +git-tree-sha1 = "49b1d7a9870c87ba13dc63f8ccfcf578cb266f95" +uuid = "4c88cf16-eb10-579e-8560-4a9242c79595" +version = "0.8.9" + +[[deps.ArgCheck]] +git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4" +uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197" +version = "2.3.0" + +[[deps.ArgTools]] +uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" +version = "1.1.1" + +[[deps.Artifacts]] +uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" + +[[deps.Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" + +[[deps.BitFlags]] +git-tree-sha1 = "0691e34b3bb8be9307330f88d1a3c3f25466c24d" +uuid = "d1d4a3ce-64b1-5f1a-9ba4-7e7e69966f35" +version = "0.1.9" + +[[deps.CEnum]] +git-tree-sha1 = "389ad5c84de1ae7cf0e28e381131c98ea87d54fc" +uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" +version = "0.5.0" + +[[deps.CodecZlib]] +deps = ["TranscodingStreams", "Zlib_jll"] +git-tree-sha1 = "bce6804e5e6044c6daab27bb533d1295e4a2e759" +uuid = "944b1d66-785c-5afd-91f1-9de20f533193" +version = "0.7.6" + +[[deps.Compat]] +deps = ["TOML", "UUIDs"] +git-tree-sha1 = "8ae8d32e09f0dcf42a36b90d4e17f5dd2e4c4215" +uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" +version = "4.16.0" + + [deps.Compat.extensions] + CompatLinearAlgebraExt = "LinearAlgebra" + + [deps.Compat.weakdeps] + Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" + LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + 
+[[deps.ConcurrentUtilities]] +deps = ["Serialization", "Sockets"] +git-tree-sha1 = "ea32b83ca4fefa1768dc84e504cc0a94fb1ab8d1" +uuid = "f0e56b4a-5159-44fe-b623-3e5288b988bb" +version = "2.4.2" + +[[deps.DataDeps]] +deps = ["HTTP", "Libdl", "Reexport", "SHA", "Scratch", "p7zip_jll"] +git-tree-sha1 = "8ae085b71c462c2cb1cfedcb10c3c877ec6cf03f" +uuid = "124859b0-ceae-595e-8997-d05f6a7a8dfe" +version = "0.7.13" + +[[deps.DataStructures]] +deps = ["Compat", "InteractiveUtils", "OrderedCollections"] +git-tree-sha1 = "1d0a14036acb104d9e89698bd408f63ab58cdc82" +uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +version = "0.18.20" + +[[deps.Dates]] +deps = ["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" + +[[deps.DocStringExtensions]] +deps = ["LibGit2"] +git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" +uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +version = "0.9.3" + +[[deps.DoubleArrayTries]] +deps = ["OffsetArrays", "Preferences", "StringViews"] +git-tree-sha1 = "78dcacc06dbe5eef9c97a8ddbb9a3e9a8d9df7b7" +uuid = "abbaa0e5-f788-499c-92af-c35ff4258c82" +version = "0.1.1" + +[[deps.Downloads]] +deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] +uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" +version = "1.6.0" + +[[deps.ExceptionUnwrapping]] +deps = ["Test"] +git-tree-sha1 = "d36f682e590a83d63d1c7dbd287573764682d12a" +uuid = "460bff9d-24e4-43bc-9d9f-a8973cb893f4" +version = "0.1.11" + +[[deps.FileWatching]] +uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" + +[[deps.FlashRank]] +deps = ["DataDeps", "DoubleArrayTries", "JSON3", "ONNXRunTime", "StringViews", "Unicode", "WordTokenizers"] +git-tree-sha1 = "51eeaf22caadb6b5f919f5df59e1ef108d1e9984" +uuid = "22cc3f58-1757-4700-bb45-2032706e5a8d" +version = "0.4.1" + +[[deps.HTML_Entities]] +deps = ["StrTables"] +git-tree-sha1 = "c4144ed3bc5f67f595622ad03c0e39fa6c70ccc7" +uuid = "7693890a-d069-55fe-a829-b4a6d304f0ee" +version = "1.0.1" + +[[deps.HTTP]] +deps = ["Base64", "CodecZlib", 
"ConcurrentUtilities", "Dates", "ExceptionUnwrapping", "Logging", "LoggingExtras", "MbedTLS", "NetworkOptions", "OpenSSL", "PrecompileTools", "Random", "SimpleBufferStream", "Sockets", "URIs", "UUIDs"] +git-tree-sha1 = "ae350b8225575cc3ea385d4131c81594f86dfe4f" +uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3" +version = "1.10.12" + +[[deps.InteractiveUtils]] +deps = ["Markdown"] +uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" + +[[deps.JLLWrappers]] +deps = ["Artifacts", "Preferences"] +git-tree-sha1 = "be3dc50a92e5a386872a493a10050136d4703f9b" +uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" +version = "1.6.1" + +[[deps.JSON]] +deps = ["Dates", "Mmap", "Parsers", "Unicode"] +git-tree-sha1 = "31e996f0a15c7b280ba9f76636b3ff9e2ae58c9a" +uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" +version = "0.21.4" + +[[deps.JSON3]] +deps = ["Dates", "Mmap", "Parsers", "PrecompileTools", "StructTypes", "UUIDs"] +git-tree-sha1 = "1d322381ef7b087548321d3f878cb4c9bd8f8f9b" +uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" +version = "1.14.1" + + [deps.JSON3.extensions] + JSON3ArrowExt = ["ArrowTypes"] + + [deps.JSON3.weakdeps] + ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" + +[[deps.Languages]] +deps = ["InteractiveUtils", "JSON", "RelocatableFolders"] +git-tree-sha1 = "0cf92ba8402f94c9f4db0ec156888ee8d299fcb8" +uuid = "8ef0a80b-9436-5d2c-a485-80b904378c43" +version = "0.4.6" + +[[deps.LazyArtifacts]] +deps = ["Artifacts", "Pkg"] +uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" + +[[deps.LibCURL]] +deps = ["LibCURL_jll", "MozillaCACerts_jll"] +uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" +version = "0.6.4" + +[[deps.LibCURL_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] +uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" +version = "8.4.0+0" + +[[deps.LibGit2]] +deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"] +uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" + +[[deps.LibGit2_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", 
"MbedTLS_jll"] +uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5" +version = "1.6.4+0" + +[[deps.LibSSH2_jll]] +deps = ["Artifacts", "Libdl", "MbedTLS_jll"] +uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" +version = "1.11.0+1" + +[[deps.Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" + +[[deps.Logging]] +uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" + +[[deps.LoggingExtras]] +deps = ["Dates", "Logging"] +git-tree-sha1 = "f02b56007b064fbfddb4c9cd60161b6dd0f40df3" +uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36" +version = "1.1.0" + +[[deps.Markdown]] +deps = ["Base64"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" + +[[deps.MbedTLS]] +deps = ["Dates", "MbedTLS_jll", "MozillaCACerts_jll", "NetworkOptions", "Random", "Sockets"] +git-tree-sha1 = "c067a280ddc25f196b5e7df3877c6b226d390aaf" +uuid = "739be429-bea8-5141-9913-cc70e7f3736d" +version = "1.1.9" + +[[deps.MbedTLS_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" +version = "2.28.2+1" + +[[deps.Mmap]] +uuid = "a63ad114-7e13-5084-954f-fe012c677804" + +[[deps.MozillaCACerts_jll]] +uuid = "14a3606d-f60d-562e-9121-12d972cd8159" +version = "2023.1.10" + +[[deps.NetworkOptions]] +uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" +version = "1.2.0" + +[[deps.ONNXRunTime]] +deps = ["ArgCheck", "CEnum", "DataStructures", "DocStringExtensions", "LazyArtifacts", "Libdl", "Pkg"] +git-tree-sha1 = "25b0c81d59c40cfe21204d3b08d48147be73fbe1" +uuid = "e034b28e-924e-41b2-b98f-d2bbeb830c6a" +version = "1.2.0" + + [deps.ONNXRunTime.extensions] + CUDAExt = ["CUDA", "cuDNN"] + + [deps.ONNXRunTime.weakdeps] + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" + +[[deps.OffsetArrays]] +git-tree-sha1 = "1a27764e945a152f7ca7efa04de513d473e9542e" +uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" +version = "1.14.1" + + [deps.OffsetArrays.extensions] + OffsetArraysAdaptExt = "Adapt" + + [deps.OffsetArrays.weakdeps] + Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" + 
+[[deps.OpenAI]] +deps = ["Dates", "HTTP", "JSON3"] +git-tree-sha1 = "fb6a407f3707daf513c4b88f25536dd3dbf94220" +uuid = "e9f21f70-7185-4079-aca2-91159181367c" +version = "0.9.1" + +[[deps.OpenSSL]] +deps = ["BitFlags", "Dates", "MozillaCACerts_jll", "OpenSSL_jll", "Sockets"] +git-tree-sha1 = "38cb508d080d21dc1128f7fb04f20387ed4c0af4" +uuid = "4d8831e6-92b7-49fb-bdf8-b643e874388c" +version = "1.4.3" + +[[deps.OpenSSL_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "7493f61f55a6cce7325f197443aa80d32554ba10" +uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" +version = "3.0.15+1" + +[[deps.OrderedCollections]] +git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5" +uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +version = "1.6.3" + +[[deps.Parsers]] +deps = ["Dates", "PrecompileTools", "UUIDs"] +git-tree-sha1 = "8489905bcdbcfac64d1daa51ca07c0d8f0283821" +uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" +version = "2.8.1" + +[[deps.Pkg]] +deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] +uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +version = "1.10.0" + +[[deps.PrecompileTools]] +deps = ["Preferences"] +git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f" +uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" +version = "1.2.1" + +[[deps.Preferences]] +deps = ["TOML"] +git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6" +uuid = "21216c6a-2e73-6563-6e65-726566657250" +version = "1.4.3" + +[[deps.Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" + +[[deps.PromptingTools]] +deps = ["AbstractTrees", "Base64", "Dates", "HTTP", "JSON3", "Logging", "OpenAI", "Pkg", "PrecompileTools", "Preferences", "REPL", "Random", "StreamCallbacks", "StructTypes", "Test"] +path = ".." 
+uuid = "670122d1-24a8-4d70-bfce-740807c42192" +version = "0.65.0" + + [deps.PromptingTools.extensions] + FlashRankPromptingToolsExt = ["FlashRank"] + GoogleGenAIPromptingToolsExt = ["GoogleGenAI"] + MarkdownPromptingToolsExt = ["Markdown"] + RAGToolsExperimentalExt = ["SparseArrays", "LinearAlgebra", "Unicode"] + SnowballPromptingToolsExt = ["Snowball"] + + [deps.PromptingTools.weakdeps] + FlashRank = "22cc3f58-1757-4700-bb45-2032706e5a8d" + GoogleGenAI = "903d41d1-eaca-47dd-943b-fee3930375ab" + LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" + Snowball = "fb8f903a-0164-4e73-9ffe-431110250c3b" + SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" + +[[deps.REPL]] +deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] +uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" + +[[deps.Random]] +deps = ["SHA"] +uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + +[[deps.Reexport]] +git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" +uuid = "189a3867-3050-52da-a836-e630ba90ab69" +version = "1.2.2" + +[[deps.RelocatableFolders]] +deps = ["SHA", "Scratch"] +git-tree-sha1 = "ffdaf70d81cf6ff22c2b6e733c900c3321cab864" +uuid = "05181044-ff0b-4ac5-8273-598c1e38db00" +version = "1.0.1" + +[[deps.SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" +version = "0.7.0" + +[[deps.Scratch]] +deps = ["Dates"] +git-tree-sha1 = "3bac05bc7e74a75fd9cba4295cde4045d9fe2386" +uuid = "6c6a2e73-6563-6170-7368-637461726353" +version = "1.2.1" + +[[deps.Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" + +[[deps.SimpleBufferStream]] +git-tree-sha1 = "f305871d2f381d21527c770d4788c06c097c9bc1" +uuid = "777ac1f9-54b0-4bf8-805c-2214025038e7" +version = "1.2.0" + +[[deps.Snowball]] +deps = ["Languages", "Snowball_jll", "WordTokenizers"] +git-tree-sha1 = "8b466b16804ab8687f8d3a1b5312a0aa1b7d8b64" +uuid = "fb8f903a-0164-4e73-9ffe-431110250c3b" +version = "0.1.1" + 
+[[deps.Snowball_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "6ff3a185a583dca7265cbfcaae1da16aa3b6a962" +uuid = "88f46535-a3c0-54f4-998e-4320a1339f51" +version = "2.2.0+0" + +[[deps.Sockets]] +uuid = "6462fe0b-24de-5631-8697-dd941f90decc" + +[[deps.StrTables]] +deps = ["Dates"] +git-tree-sha1 = "5998faae8c6308acc25c25896562a1e66a3bb038" +uuid = "9700d1a9-a7c8-5760-9816-a99fda30bb8f" +version = "1.0.1" + +[[deps.StreamCallbacks]] +deps = ["HTTP", "JSON3", "PrecompileTools"] +git-tree-sha1 = "827180547dd10f4c018ccdbede9375c76dbdcafe" +uuid = "c1b9e933-98a0-46fc-8ea7-3b58b195fb0a" +version = "0.5.0" + +[[deps.StringViews]] +git-tree-sha1 = "ec4bf39f7d25db401bcab2f11d2929798c0578e5" +uuid = "354b36f9-a18e-4713-926e-db85100087ba" +version = "1.3.4" + +[[deps.StructTypes]] +deps = ["Dates", "UUIDs"] +git-tree-sha1 = "159331b30e94d7b11379037feeb9b690950cace8" +uuid = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" +version = "1.11.0" + +[[deps.TOML]] +deps = ["Dates"] +uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" +version = "1.0.3" + +[[deps.Tar]] +deps = ["ArgTools", "SHA"] +uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" +version = "1.10.0" + +[[deps.Test]] +deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] +uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[[deps.TranscodingStreams]] +git-tree-sha1 = "0c45878dcfdcfa8480052b6ab162cdd138781742" +uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" +version = "0.11.3" + +[[deps.URIs]] +git-tree-sha1 = "67db6cc7b3821e19ebe75791a9dd19c9b1188f2b" +uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" +version = "1.5.1" + +[[deps.UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" + +[[deps.Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" + +[[deps.WordTokenizers]] +deps = ["DataDeps", "HTML_Entities", "StrTables", "Unicode"] +git-tree-sha1 = "01dd4068c638da2431269f49a5964bf42ff6c9d2" +uuid = "796a5d58-b03d-544a-977e-18100b691f6e" +version = "0.5.6" + 
+[[deps.Zlib_jll]] +deps = ["Libdl"] +uuid = "83775a58-1f1d-513f-b197-d71354ab007a" +version = "1.2.13+1" + +[[deps.nghttp2_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" +version = "1.52.0+1" + +[[deps.p7zip_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" +version = "17.4.0+2" diff --git a/test/Project.toml b/test/Project.toml new file mode 100644 index 000000000..8db7b268c --- /dev/null +++ b/test/Project.toml @@ -0,0 +1,8 @@ +[deps] +PromptingTools = "670122d1-24a8-4d70-bfce-740807c42192" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3" +JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" +Snowball = "fb8f903a-0164-4e73-9ffe-431110250c3b" +FlashRank = "22cc3f58-1757-4700-bb45-2032706e5a8d" +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" diff --git a/test/llm_shared.jl b/test/llm_shared.jl index 1688341de..b5f78de55 100644 --- a/test/llm_shared.jl +++ b/test/llm_shared.jl @@ -26,6 +26,9 @@ using PromptingTools: finalize_outputs, role4render UserMessage(; content = "Hello, my name is John", variables = [:name], + name = nothing, + run_id = nothing, + cost = nothing, _type = :usermessage) ] conversation = render(schema, @@ -92,7 +95,7 @@ using PromptingTools: finalize_outputs, role4render SystemMessage("System message 1"), UserMessage("Hello {{name}}"), AIMessage("Hi there"), - UserMessage("How are you, John?", [:name], nothing, :usermessage), + UserMessage("How are you, John?", [:name], nothing, nothing, nothing, :usermessage), AIMessage("I'm doing well, thank you!") ] conversation = render(schema, messages; conversation, name = "John") @@ -126,7 +129,7 @@ using PromptingTools: finalize_outputs, role4render UserMessage("How are you?") ] expected_output = [ - SystemMessage("Hello, !", [:name], :systemmessage), + SystemMessage("Hello, !"; run_id=nothing), UserMessage("How are you?") ] conversation = render(schema, messages) @@ -308,7 +311,7 @@ end 
SystemMessage("System message 1"), UserMessage("User message {{name}}"), AIMessage("AI message"), - UserMessage("User message John", [:name], nothing, :usermessage), + UserMessage("User message John", [:name], nothing, nothing, nothing, :usermessage), AIMessage("AI message 2"), msg ] @@ -335,7 +338,7 @@ end SystemMessage("System message 1"), UserMessage("User message {{name}}"), AIMessage("AI message"), - UserMessage("User message John", [:name], nothing, :usermessage), + UserMessage("User message John", [:name], nothing, nothing, nothing, :usermessage), AIMessage("AI message 2"), msg, msg diff --git a/test/memory.jl b/test/memory.jl index 35d58ed9b..0bc804471 100644 --- a/test/memory.jl +++ b/test/memory.jl @@ -15,26 +15,29 @@ const TEST_RESPONSE = Dict( ) @testset "ConversationMemory" begin - # Setup mock server for all tests - PORT = rand(10000:20000) - server = Ref{Union{Nothing, HTTP.Server}}(nothing) - - try - server[] = HTTP.serve!(PORT; verbose=-1) do req - return HTTP.Response(200, ["Content-Type" => "application/json"], JSON3.write(TEST_RESPONSE)) - end - - # Register test model - register_model!(; - name = "memory-echo", - schema = TestEchoOpenAISchema(; response=TEST_RESPONSE), - api_kwargs = (; url = "http://localhost:$(PORT)") - ) - - # Test constructor and empty initialization - mem = ConversationMemory() - @test length(mem) == 0 - @test isempty(mem.conversation) + # Setup test schema for all tests + response = Dict( + "model" => "gpt-3.5-turbo", + "choices" => [Dict("message" => Dict("role" => "assistant", "content" => "Echo response"))], + "usage" => Dict("total_tokens" => 3, "prompt_tokens" => 2, "completion_tokens" => 1), + "id" => "chatcmpl-123", + "object" => "chat.completion", + "created" => Int(floor(time())) + ) + + # Register test model + register_model!(; + name = "memory-echo", + schema = TestEchoOpenAISchema(; response=response), + cost_of_token_prompt = 0.0, + cost_of_token_generation = 0.0, + description = "Test echo model for memory 
tests" + ) + + # Test constructor and empty initialization + mem = ConversationMemory() + @test length(mem) == 0 + @test isempty(mem.conversation) # Test show method io = IOBuffer() @@ -86,13 +89,15 @@ const TEST_RESPONSE = Dict( @test contains(recent[3].content, "For efficiency reasons") # Test get_last with verbose - io = IOBuffer() - redirect_stdout(io) do - get_last(mem, 5; verbose=true) + mktemp() do path, io + redirect_stdout(io) do + get_last(mem, 5; verbose=true) + end + seekstart(io) + output = read(io, String) + @test contains(output, "Total messages:") + @test contains(output, "Keeping:") end - output = String(take!(io)) - @test contains(output, "Total messages:") - @test contains(output, "Keeping:") end @testset "Message Deduplication" begin @@ -128,9 +133,16 @@ const TEST_RESPONSE = Dict( end @testset "Generation Interface" begin - mem = ConversationMemory() + # Setup mock response + response = Dict( + "choices" => [Dict("message" => Dict("content" => "Test response"), "finish_reason" => "stop")], + "usage" => Dict("total_tokens" => 3, "prompt_tokens" => 2, "completion_tokens" => 1) + ) + schema = TestEchoOpenAISchema(; response=response, status=200) + OLD_PROMPT_SCHEMA = PromptingTools.PROMPT_SCHEMA + PromptingTools.PROMPT_SCHEMA = schema - # Test functor interface basic usage + mem = ConversationMemory() push!(mem, SystemMessage("You are a helpful assistant")) result = mem("Hello!"; model="memory-echo") @test result.content == "Echo response" @@ -148,8 +160,5 @@ const TEST_RESPONSE = Dict( @test result.content == "Echo response" @test length(mem) == 14 # Previous messages + new exchange end - finally - # Ensure server is properly closed - !isnothing(server[]) && close(server[]) end end diff --git a/test/messages.jl b/test/messages.jl index 2ece68c37..36373fa74 100644 --- a/test/messages.jl +++ b/test/messages.jl @@ -1,13 +1,16 @@ using PromptingTools: AIMessage, SystemMessage, MetadataMessage, AbstractMessage using PromptingTools: UserMessage, 
UserMessageWithImages, DataMessage, AIToolRequest, - ToolMessage + ToolMessage, AnnotationMessage using PromptingTools: _encode_local_image, attach_images_to_user_message, last_message, last_output, tool_calls using PromptingTools: isusermessage, issystemmessage, isdatamessage, isaimessage, - istracermessage, isaitoolrequest, istoolmessage + istracermessage, isaitoolrequest, istoolmessage, isabstractannotationmessage using PromptingTools: TracerMessageLike, TracerMessage, align_tracer!, unwrap, AbstractTracerMessage, AbstractTracer, pprint -using PromptingTools: TracerSchema, SaverSchema +using PromptingTools: TracerSchema, SaverSchema, TestEchoOpenAISchema, render + +# Include the detailed annotation message tests +include("test_annotation_messages.jl") @testset "Message constructors" begin # Creates an instance of MSG with the given content string. @@ -27,7 +30,6 @@ using PromptingTools: TracerSchema, SaverSchema @test_throws AssertionError UserMessage(content) @test_throws AssertionError UserMessage(; content) @test_throws AssertionError SystemMessage(content) - @test_throws AssertionError SystemMessage(; content) @test_throws AssertionError UserMessageWithImages(; content, image_url = ["a"]) # Check methods @@ -40,6 +42,14 @@ using PromptingTools: TracerSchema, SaverSchema @test UserMessage(content) != AIMessage(content) @test AIToolRequest() |> isaitoolrequest @test ToolMessage(; tool_call_id = "x", raw = "") |> istoolmessage + + # Test AnnotationMessage + annotation = AnnotationMessage(content="Debug info", comment="Test annotation") + @test isabstractannotationmessage(annotation) + @test !isabstractannotationmessage(UserMessage("test")) + @test annotation.content == "Debug info" + @test annotation.comment == "Test annotation" + ## check handling other types @test isusermessage(1) == false @test issystemmessage(nothing) == false @@ -180,6 +190,17 @@ end @test occursin("User Message", output) @test occursin("User input with image", output) + # 
AnnotationMessage + take!(io) + m = AnnotationMessage("Debug info", comment="Test annotation") + show(io, MIME("text/plain"), m) + @test occursin("AnnotationMessage(\"Debug info\")", String(take!(io))) + pprint(io, m) + output = String(take!(io)) + @test occursin("Annotation Message", output) + @test occursin("Debug info", output) + @test occursin("Test annotation", output) + # MetadataMessage take!(io) m = MetadataMessage("Metadata info") @@ -351,3 +372,28 @@ end @test occursin("TracerMessageLike with:", pprint_output) @test occursin("Test Message", pprint_output) end + +@testset "AnnotationMessage rendering" begin + # Test that annotations are filtered out during rendering + messages = [ + SystemMessage("System prompt"), + UserMessage("User message"), + AnnotationMessage(content="Debug info", comment="Debug note"), + AIMessage("AI response") + ] + + # Create a basic schema for testing + schema = TestEchoOpenAISchema() + rendered = render(schema, messages) + + # Verify annotation message is not in rendered output + @test !contains(rendered, "Debug info") + @test !contains(rendered, "Debug note") + + # Test single message rendering + annotation = AnnotationMessage("Debug info", comment="Debug") + @test render(schema, annotation) === nothing + + # Test that other messages still render normally + @test !isnothing(render(schema, UserMessage("Test"))) +end diff --git a/test/messages_utils.jl b/test/messages_utils.jl new file mode 100644 index 000000000..2e9283f57 --- /dev/null +++ b/test/messages_utils.jl @@ -0,0 +1,49 @@ +using Test +using PromptingTools +using PromptingTools: SystemMessage, UserMessage, AIMessage, AbstractMessage +using PromptingTools: last_message, last_output + +@testset "Message Utilities" begin + @testset "last_message" begin + # Test empty vector + @test last_message(AbstractMessage[]) === nothing + + # Test single message + msgs = [UserMessage("Hello")] + @test last_message(msgs).content == "Hello" + + # Test multiple messages + msgs = [ + 
SystemMessage("System"), + UserMessage("User"), + AIMessage("AI") + ] + @test last_message(msgs).content == "AI" + end + + @testset "last_output" begin + # Test empty vector + @test last_output(AbstractMessage[]) === nothing + + # Test no AI messages + msgs = [ + SystemMessage("System"), + UserMessage("User") + ] + @test last_output(msgs) === nothing + + # Test with AI messages + msgs = [ + SystemMessage("System"), + UserMessage("User"), + AIMessage("AI 1"), + UserMessage("User 2"), + AIMessage("AI 2") + ] + @test last_output(msgs).content == "AI 2" + + # Test with non-AI last message + push!(msgs, UserMessage("Last user")) + @test last_output(msgs).content == "AI 2" + end +end diff --git a/test/minimal_test.jl b/test/minimal_test.jl index 792d5cde8..71fe12ad6 100644 --- a/test/minimal_test.jl +++ b/test/minimal_test.jl @@ -26,6 +26,25 @@ include("../src/messages.jl") ai_msg = AIMessage("test ai") @test isaimessage(ai_msg) + + # Test annotation message + annotation = AnnotationMessage("Test annotation"; + extras=Dict{Symbol,Any}(:key => "value"), + tags=Symbol[:test], + comment="Test comment") + @test isabstractannotationmessage(annotation) + + # Test conversation memory + memory = ConversationMemory() + push!(memory, sys_msg) + push!(memory, user_msg) + @test length(memory) == 1 # system messages not counted + @test last_message(memory) == user_msg + + # Test rendering with annotation message + messages = [sys_msg, annotation, user_msg, ai_msg] + rendered = render(OpenAISchema(), messages) + @test length(rendered) == 3 # annotation message should be filtered out end end # module diff --git a/test/runtests.jl b/test/runtests.jl index 930a83861..50eb5f67f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -19,6 +19,8 @@ end @testset "PromptingTools.jl" begin include("utils.jl") include("messages.jl") + include("messages_utils.jl") + include("memory.jl") include("extraction.jl") include("user_preferences.jl") include("llm_interface.jl") diff --git 
a/test/test_annotation_messages.jl b/test/test_annotation_messages.jl index 27036e719..4db9c5c2c 100644 --- a/test/test_annotation_messages.jl +++ b/test/test_annotation_messages.jl @@ -1,6 +1,7 @@ using Test using PromptingTools -using PromptingTools: TestEchoOpenAISchema, render, SystemMessage, UserMessage, AnnotationMessage +using PromptingTools: TestEchoOpenAISchema, render, SystemMessage, UserMessage, AIMessage, AnnotationMessage +using PromptingTools: OpenAISchema, AnthropicSchema, OllamaSchema, GoogleSchema @testset "AnnotationMessage" begin # Test creation and basic properties @@ -72,19 +73,90 @@ using PromptingTools: TestEchoOpenAISchema, render, SystemMessage, UserMessage, @test reconstructed.comment == original.comment end - # Test rendering skipping + # Test rendering skipping across all providers @testset "Render Skipping" begin - schema = TestEchoOpenAISchema(response=Dict(:choices => [Dict(:message => Dict(:content => "Echo"))])) - msg = AnnotationMessage("Should be skipped") - @test render(schema, msg) === nothing - - # Test in message sequence + # Create a mix of messages including annotation messages messages = [ - SystemMessage("System"), - AnnotationMessage("Skip me"), - UserMessage("User") + SystemMessage("Be helpful"), + AnnotationMessage("This is metadata", extras=Dict{Symbol,Any}(:key => "value")), + UserMessage("Hello"), + AnnotationMessage("More metadata"), + AIMessage("Hi there!") ] + + # Additional edge cases + messages_complex = [ + AnnotationMessage("Metadata 1", extras=Dict{Symbol,Any}(:key => "value")), + AnnotationMessage("Metadata 2", extras=Dict{Symbol,Any}(:key2 => "value2")), + SystemMessage("Be helpful"), + AnnotationMessage("Metadata 3", tags=[:important]), + UserMessage("Hello"), + AnnotationMessage("Metadata 4", comment="For debugging"), + AIMessage("Hi there!"), + AnnotationMessage("Metadata 5", extras=Dict{Symbol,Any}(:key3 => "value3")) + ] + + # Test OpenAI Schema with TestEcho + schema = TestEchoOpenAISchema( + 
response=Dict( + "choices" => [Dict("message" => Dict("content" => "Test response", "role" => "assistant"), "index" => 0, "finish_reason" => "stop")], + "usage" => Dict("prompt_tokens" => 10, "completion_tokens" => 20, "total_tokens" => 30), + "model" => "gpt-3.5-turbo", + "id" => "test-id", + "object" => "chat.completion", + "created" => 1234567890 + ), + status=200 + ) rendered = render(schema, messages) - @test !contains(rendered, "Skip me") + @test length(rendered) == 3 # Should only have system, user, and AI messages + @test all(msg["role"] in ["system", "user", "assistant"] for msg in rendered) + @test !any(msg -> contains(msg["content"], "metadata"), rendered) + + # Test Anthropic Schema + rendered = render(AnthropicSchema(), messages) + @test length(rendered.conversation) == 2 # Should have user and AI messages + @test !isnothing(rendered.system) # System message should be preserved separately + @test all(msg["role"] in ["user", "assistant"] for msg in rendered.conversation) + @test !contains(rendered.system, "metadata") # Check system message + @test !any(msg -> any(content -> contains(content["text"], "metadata"), msg["content"]), rendered.conversation) + + # Test Ollama Schema + rendered = render(OllamaSchema(), messages) + @test length(rendered) == 3 # Should only have system, user, and AI messages + @test all(msg["role"] in ["system", "user", "assistant"] for msg in rendered) + @test !any(msg -> contains(msg["content"], "metadata"), rendered) + + # Test Google Schema + rendered = render(GoogleSchema(), messages) + @test length(rendered) == 2 # Google schema combines system message with first user message + @test all(msg[:role] in ["user", "model"] for msg in rendered) # Google uses "model" instead of "assistant" + @test !any(msg -> any(part -> contains(part["text"], "metadata"), msg[:parts]), rendered) + + # Test complex edge cases + @testset "Complex Edge Cases" begin + for schema in [TestEchoOpenAISchema(), AnthropicSchema(), OllamaSchema(), 
GoogleSchema()] + rendered = render(schema, messages_complex) + + if schema isa AnthropicSchema + @test length(rendered.conversation) == 2 # user and AI only + @test !isnothing(rendered.system) # system preserved + else + @test length(rendered) == (schema isa GoogleSchema ? 2 : 3) # Google schema combines system with user message + end + + # Test no metadata leaks through + for i in 1:5 + if schema isa GoogleSchema + @test !any(msg -> any(part -> contains(part["text"], "Metadata $i"), msg[:parts]), rendered) + elseif schema isa AnthropicSchema + @test !any(msg -> any(content -> contains(content["text"], "Metadata $i"), msg["content"]), rendered.conversation) + @test !contains(rendered.system, "Metadata $i") + else + @test !any(msg -> contains(msg["content"], "Metadata $i"), rendered) + end + end + end + end end end diff --git a/test/utils.jl b/test/utils.jl index bc885c881..5a4cd8a17 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -215,7 +215,7 @@ end # Multiple messages conv = [AIMessage(; content = "", tokens = (1000, 2000), cost = 1.0), - UserMessage(; content = "")] + UserMessage(; content = "", cost = 0.0)] @test call_cost(conv) == 1.0 # No model provided @@ -467,4 +467,4 @@ end # Test with an array of dictionaries @test unique_permutation([ Dict(:a => 1), Dict(:b => 2), Dict(:a => 1), Dict(:c => 3)]) == [1, 2, 4] -end \ No newline at end of file +end diff --git a/trace.log b/trace.log new file mode 100644 index 000000000..9203a6d37 --- /dev/null +++ b/trace.log @@ -0,0 +1,23 @@ +precompile(Tuple{Base.var"##s128#247", Vararg{Any, 5}}) +precompile(Tuple{typeof(Core.kwcall), NamedTuple{(:stale_age,), Tuple{Int64}}, typeof(FileWatching.Pidfile.trymkpidlock), Function, Vararg{Any}}) +precompile(Tuple{FileWatching.Pidfile.var"##trymkpidlock#11", Base.Pairs{Symbol, Int64, Tuple{Symbol}, NamedTuple{(:stale_age,), Tuple{Int64}}}, typeof(FileWatching.Pidfile.trymkpidlock), Function, Vararg{Any}}) +precompile(Tuple{typeof(Core.kwcall), NamedTuple{(:stale_age, :wait), 
Tuple{Int64, Bool}}, typeof(FileWatching.Pidfile.mkpidlock), Function, String}) +precompile(Tuple{FileWatching.Pidfile.var"##mkpidlock#7", Base.Pairs{Symbol, Integer, Tuple{Symbol, Symbol}, NamedTuple{(:stale_age, :wait), Tuple{Int64, Bool}}}, typeof(FileWatching.Pidfile.mkpidlock), Base.var"#968#969"{Base.PkgId}, String, Int32}) +precompile(Tuple{typeof(Base.print), Base.GenericIOBuffer{Array{UInt8, 1}}, UInt16}) +precompile(Tuple{typeof(Base.CoreLogging.shouldlog), Logging.ConsoleLogger, Base.CoreLogging.LogLevel, Module, Symbol, Symbol}) +precompile(Tuple{typeof(Core.kwcall), NamedTuple{names, T} where T<:Tuple where names, typeof(Base.CoreLogging.handle_message), Logging.ConsoleLogger, Base.CoreLogging.LogLevel, Vararg{Any, 6}}) +precompile(Tuple{typeof(Base.pairs), NamedTuple{(:path,), Tuple{String}}}) +precompile(Tuple{typeof(Base.haskey), Base.Pairs{Symbol, String, Tuple{Symbol}, NamedTuple{(:path,), Tuple{String}}}, Symbol}) +precompile(Tuple{typeof(Base.isopen), Base.GenericIOBuffer{Array{UInt8, 1}}}) +precompile(Tuple{Type{Base.IOContext{IO_t} where IO_t<:IO}, Base.GenericIOBuffer{Array{UInt8, 1}}, Base.TTY}) +precompile(Tuple{typeof(Logging.showvalue), Base.IOContext{Base.GenericIOBuffer{Array{UInt8, 1}}}, String}) +precompile(Tuple{typeof(Logging.default_metafmt), Base.CoreLogging.LogLevel, Vararg{Any, 5}}) +precompile(Tuple{typeof(Base.string), Module}) +precompile(Tuple{typeof(Core.kwcall), NamedTuple{(:bold, :color), Tuple{Bool, Symbol}}, typeof(Base.printstyled), Base.IOContext{Base.GenericIOBuffer{Array{UInt8, 1}}}, String}) +precompile(Tuple{typeof(Core.kwcall), NamedTuple{(:bold, :color), Tuple{Bool, Symbol}}, typeof(Base.printstyled), Base.IOContext{Base.GenericIOBuffer{Array{UInt8, 1}}}, String, Vararg{String}}) +precompile(Tuple{Base.var"##printstyled#995", Bool, Bool, Bool, Bool, Bool, Bool, Symbol, typeof(Base.printstyled), Base.IOContext{Base.GenericIOBuffer{Array{UInt8, 1}}}, String, Vararg{Any}}) +precompile(Tuple{typeof(Core.kwcall), 
NamedTuple{(:bold, :italic, :underline, :blink, :reverse, :hidden), NTuple{6, Bool}}, typeof(Base.with_output_color), Function, Symbol, Base.IOContext{Base.GenericIOBuffer{Array{UInt8, 1}}}, String, Vararg{Any}}) +precompile(Tuple{typeof(Base.write), Base.TTY, Array{UInt8, 1}}) +precompile(Tuple{typeof(Core.kwcall), NamedTuple{(:cpu_target,), Tuple{Nothing}}, typeof(Base.julia_cmd)}) +precompile(Tuple{typeof(Core.kwcall), NamedTuple{(:stderr, :stdout), Tuple{Base.TTY, Base.TTY}}, typeof(Base.pipeline), Base.Cmd}) +precompile(Tuple{typeof(Base.open), Base.CmdRedirect, String, Base.TTY})