Fix AnnotationMessage tests and implementation
devin-ai-integration[bot] committed Nov 25, 2024
1 parent 2ebb69e commit 3e9e4f1
Showing 23 changed files with 920 additions and 158 deletions.
6 changes: 6 additions & 0 deletions Project.toml
@@ -7,6 +7,7 @@ version = "0.65.0"
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
FlashRank = "22cc3f58-1757-4700-bb45-2032706e5a8d"
HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
@@ -16,7 +17,9 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Snowball = "fb8f903a-0164-4e73-9ffe-431110250c3b"
StreamCallbacks = "c1b9e933-98a0-46fc-8ea7-3b58b195fb0a"
StructTypes = "856f2bd8-1eba-4b0a-8007-ebc267875bd4"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[weakdeps]
@@ -53,10 +56,13 @@ PrecompileTools = "1"
Preferences = "1"
REPL = "<0.0.1, 1"
Random = "<0.0.1, 1"
Snowball = "0.1"
SparseArrays = "<0.0.1, 1"
Statistics = "<0.0.1, 1"
StreamCallbacks = "0.4, 0.5"
StructTypes = "1"
Test = "<0.0.1, 1"
Unicode = "<0.0.1, 1"
julia = "1.9, 1.10"

[extras]
9 changes: 7 additions & 2 deletions src/PromptingTools.jl
@@ -66,9 +66,14 @@ include("llm_interface.jl")
include("user_preferences.jl")

## Conversation history / Prompt elements
export AIMessage
# export UserMessage, UserMessageWithImages, SystemMessage, DataMessage # for debugging only
include("messages.jl")
include("memory.jl")

# Export message types and predicates
export SystemMessage, UserMessage, AIMessage, AnnotationMessage, issystemmessage, isusermessage, isaimessage, isabstractannotationmessage, annotate!
# Export memory-related functionality
export ConversationMemory, get_last, last_message, last_output
# export UserMessage, UserMessageWithImages, SystemMessage, DataMessage # for debugging only

export aitemplates, AITemplate
include("templates.jl")
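
For orientation, a minimal sketch of how the newly exported names fit together (the message contents, the zero-argument ConversationMemory constructor, and the return shapes are assumptions based on this diff, not verified against the full API):

using PromptingTools

conv = [
    SystemMessage("You are a terse assistant."),
    UserMessage("What is 2 + 2?"),
    AIMessage("4"),
]
issystemmessage(conv[1])   # true
isaimessage(conv[end])     # true

mem = ConversationMemory()   # assumed zero-argument constructor
append!(mem, conv)
last_message(mem)            # the final AIMessage
get_last(mem, 5)             # the system and first user messages are always kept
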
5 changes: 4 additions & 1 deletion src/llm_anthropic.jl
@@ -28,10 +28,13 @@ function render(schema::AbstractAnthropicSchema,
no_system_message::Bool = false,
cache::Union{Nothing, Symbol} = nothing,
kwargs...)
##
##
@assert count(issystemmessage, messages)<=1 "AbstractAnthropicSchema only supports at most 1 System message"
@assert (isnothing(cache)||cache in [:system, :tools, :last, :all]) "Currently only `:system`, `:tools`, `:last`, `:all` are supported for Anthropic Prompt Caching"

# Filter out annotation messages before any processing
messages = filter(!isabstractannotationmessage, messages)

system = nothing

## First pass: keep the message types but make the replacements provided in `kwargs`
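
The same filtering step is repeated in the Google, Ollama, OpenAI, and shared renderers below. A minimal sketch of the intended behaviour, assuming AnnotationMessage is constructed from a content string:

msgs = [
    SystemMessage("You are helpful."),
    AnnotationMessage("reviewer note: verify this prompt"),  # metadata only, never sent to the provider
    UserMessage("Hello!"),
]
filter(!isabstractannotationmessage, msgs)  # drops the annotation, keeps system + user
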
3 changes: 3 additions & 0 deletions src/llm_google.jl
@@ -25,6 +25,9 @@ function render(schema::AbstractGoogleSchema,
no_system_message::Bool = false,
kwargs...)
##
# Filter out annotation messages before any processing
messages = filter(!isabstractannotationmessage, messages)

## First pass: keep the message types but make the replacements provided in `kwargs`
messages_replaced = render(
NoSchema(), messages; conversation, no_system_message, kwargs...)
13 changes: 10 additions & 3 deletions src/llm_interface.jl
@@ -41,14 +41,21 @@ struct OpenAISchema <: AbstractOpenAISchema end

"Echoes the user's input back to them. Used for testing the implementation"
@kwdef mutable struct TestEchoOpenAISchema <: AbstractOpenAISchema
response::AbstractDict
status::Integer
response::AbstractDict = Dict(
"choices" => [Dict("message" => Dict("content" => "Test response", "role" => "assistant"), "index" => 0, "finish_reason" => "stop")],
"usage" => Dict("prompt_tokens" => 10, "completion_tokens" => 20, "total_tokens" => 30),
"model" => "gpt-3.5-turbo",
"id" => "test-id",
"object" => "chat.completion",
"created" => 1234567890
)
status::Integer = 200
model_id::String = ""
inputs::Any = nothing
end

"""
CustomOpenAISchema
CustomOpenAISchema() allows user to call any OpenAI-compatible API.
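
With the new field defaults, the echo schema can presumably be constructed without arguments; a sketch under that assumption (field access paths follow the defaults shown above):

schema = TestEchoOpenAISchema()  # canned response, status 200
schema.status                                         # 200
schema.response["choices"][1]["message"]["content"]   # "Test response"

# Individual fields can still be overridden via keywords (@kwdef constructor):
schema2 = TestEchoOpenAISchema(; model_id = "my-test-model")
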
5 changes: 4 additions & 1 deletion src/llm_ollama.jl
@@ -27,6 +27,9 @@ function render(schema::AbstractOllamaSchema,
no_system_message::Bool = false,
kwargs...)
##
# Filter out annotation messages before any processing
messages = filter(!isabstractannotationmessage, messages)

## First pass: keep the message types but make the replacements provided in `kwargs`
messages_replaced = render(
NoSchema(), messages; conversation, no_system_message, kwargs...)
@@ -376,4 +379,4 @@ end
function aitools(prompt_schema::AbstractOllamaSchema, prompt::ALLOWED_PROMPT_TYPE;
kwargs...)
error("Managed schema does not support aitools. Please use OpenAISchema instead.")
end
6 changes: 5 additions & 1 deletion src/llm_openai.jl
@@ -33,6 +33,10 @@ function render(schema::AbstractOpenAISchema,
kwargs...)
##
@assert image_detail in ["auto", "high", "low"] "Image detail must be one of: auto, high, low"

# Filter out annotation messages before any processing
messages = filter(!isabstractannotationmessage, messages)

## First pass: keep the message types but make the replacements provided in `kwargs`
messages_replaced = render(
NoSchema(), messages; conversation, no_system_message, kwargs...)
@@ -1733,4 +1737,4 @@ function aitools(prompt_schema::AbstractOpenAISchema, prompt::ALLOWED_PROMPT_TYP
kwargs...)

return output
end
3 changes: 3 additions & 0 deletions src/llm_shared.jl
@@ -39,6 +39,9 @@ function render(schema::NoSchema,
count_system_msg = count(issystemmessage, conversation)
# TODO: concat multiple system messages together (2nd pass)

# Filter out annotation messages from input messages
messages = filter(!isabstractannotationmessage, messages)

# replace any handlebar variables in the messages
for msg in messages
if issystemmessage(msg) || isusermessage(msg) || isusermessagewithimages(msg)
81 changes: 51 additions & 30 deletions src/memory.jl
@@ -86,34 +86,43 @@ function get_last(mem::ConversationMemory, n::Integer=20;
# Always include system message and first user message
system_idx = findfirst(issystemmessage, messages)
first_user_idx = findfirst(isusermessage, messages)
result = AbstractMessage[]

# Initialize result with required messages
result = AbstractMessage[]
if !isnothing(system_idx)
push!(result, messages[system_idx])
end
if !isnothing(first_user_idx)
push!(result, messages[first_user_idx])
end

# Calculate remaining message budget
remaining_budget = n - length(result)
# Get remaining messages excluding system and first user
exclude_indices = filter(!isnothing, [system_idx, first_user_idx])
remaining_msgs = messages[setdiff(1:length(messages), exclude_indices)]

if remaining_budget > 0
if !isnothing(batch_size)
# Calculate how many complete batches we can include
total_msgs = length(messages)
num_batches = (total_msgs - length(result)) ÷ batch_size

# We want to keep between batch_size+1 and 2*batch_size messages
# If we would exceed 2*batch_size, reset to batch_size+1
if num_batches * batch_size > 2 * batch_size
num_batches = 1 # Reset to one batch (batch_size+1 messages)
end

start_idx = max(1, total_msgs - (num_batches * batch_size) + 1)
append!(result, messages[start_idx:end])
# Calculate how many messages to include based on batch size
if !isnothing(batch_size)
# When batch_size=10, should return between 11-20 messages
total_msgs = length(remaining_msgs)
num_batches = ceil(Int, total_msgs / batch_size)

# Calculate target size (between batch_size+1 and 2*batch_size)
target_size = if num_batches * batch_size > n
batch_size + 1 # Reset to minimum (11 for batch_size=10)
else
append!(result, messages[max(1, end-remaining_budget+1):end])
min(num_batches * batch_size, n - length(result))
end

# Get messages to append
if !isempty(remaining_msgs)
start_idx = max(1, length(remaining_msgs) - target_size + 1)
append!(result, remaining_msgs[start_idx:end])
end
else
# Without batch size, just get the last n-length(result) messages
if !isempty(remaining_msgs)
start_idx = max(1, length(remaining_msgs) - (n - length(result)) + 1)
append!(result, remaining_msgs[start_idx:end])
end
end

@@ -126,12 +135,13 @@ end
end
end

# Add explanation if requested
if explain && length(messages) > n
ai_msg_idx = findfirst(isaimessage, result)
# Add explanation if requested and we truncated messages
if explain && length(messages) > length(result)
# Find first AI message in result after required messages
ai_msg_idx = findfirst(m -> isaimessage(m) && !(m in result[1:length(exclude_indices)]), result)
if !isnothing(ai_msg_idx)
orig_content = result[ai_msg_idx].content
explanation = "For efficiency reasons, we have truncated the preceding $(length(messages) - n) messages.\n\n$orig_content"
explanation = "For efficiency reasons, we have truncated the preceding $(length(messages) - length(result)) messages.\n\n$orig_content"
result[ai_msg_idx] = AIMessage(explanation)
end
end
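
A worked sketch of the batching semantics described in the comments above, assuming ConversationMemory() creates an empty memory whose conversation field accepts push! (the diff's own comment states that batch_size=10 should return between 11 and 20 messages):

mem = ConversationMemory()
push!(mem.conversation, SystemMessage("sys"))
push!(mem.conversation, UserMessage("first question"))
for i in 1:28
    push!(mem.conversation, AIMessage("reply $i"))
end

get_last(mem, 20)                    # system + first user, then the newest messages up to the n=20 budget
get_last(mem, 20; batch_size = 10)   # target size lands in the batch_size+1 to 2*batch_size band (11-20 here)
get_last(mem, 20; explain = true)    # first kept AIMessage is prefixed with a truncation notice
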
@@ -146,16 +156,22 @@ Smart append that handles duplicate messages based on run IDs.
Only appends messages that are newer than the latest matching message in memory.
"""
function Base.append!(mem::ConversationMemory, msgs::Vector{<:AbstractMessage})
if isempty(mem.conversation) || isempty(msgs)
isempty(msgs) && return mem

if isempty(mem.conversation)
append!(mem.conversation, msgs)
return mem
end

# Find latest common message based on run_id
# Default to 0 if run_id is not defined
latest_run_id = maximum(msg -> isdefined(msg, :run_id) ? msg.run_id : 0, mem.conversation)
# Find latest run_id in memory
latest_run_id = 0
for msg in mem.conversation
if isdefined(msg, :run_id)
latest_run_id = max(latest_run_id, msg.run_id)
end
end

# Only append messages with higher run_id or no run_id
# Keep messages that either don't have a run_id or have a higher run_id
new_msgs = filter(msgs) do msg
!isdefined(msg, :run_id) || msg.run_id > latest_run_id
end
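
A sketch of the deduplication, assuming each AIMessage is stamped with a run_id when it is constructed (treat that as an assumption here):

mem = ConversationMemory()
conv = [UserMessage("Q1"), AIMessage("A1")]
append!(mem, conv)  # empty memory: everything is stored
append!(mem, conv)  # filtered: only messages with a run_id above the stored maximum, or with no run_id field, are appended
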
@@ -187,8 +203,12 @@ function (mem::ConversationMemory)(prompt::String; last::Union{Nothing,Integer}=
get_last(mem, last)
end

# Add user message to memory first
user_msg = UserMessage(prompt)
push!(mem.conversation, user_msg)

# Generate response with context
response = PromptingTools.aigenerate(context, prompt; kwargs...)
response = aigenerate(context, prompt; kwargs...)
push!(mem.conversation, response)
return response
end
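
A sketch of the call path added here (the model name is an illustrative assumption, and a configured API key is required):

mem = ConversationMemory()
response = mem("What is the capital of France?"; model = "gpt-4o-mini")
# 1. builds the context window (get_last when `last` is given, else the full conversation)
# 2. records the UserMessage in memory
# 3. calls aigenerate on the context and stores the AIMessage before returning it
last_output(mem)  # text of the stored response
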
@@ -198,6 +218,7 @@ end
Generate a response using the conversation memory context.
"""
function PromptingTools.aigenerate(mem::ConversationMemory, prompt::String; kwargs...)
PromptingTools.aigenerate(mem.conversation, prompt; kwargs...)
function PromptingTools.aigenerate(messages::Vector{<:AbstractMessage}, prompt::String; kwargs...)
schema = get(kwargs, :schema, OpenAISchema())
aigenerate(schema, [messages..., UserMessage(prompt)]; kwargs...)
end
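
Equivalently, the memory can be passed straight to aigenerate; a sketch (the :schema keyword default is taken from the method shown above):

msg = aigenerate(mem, "And its population?")  # delegates to aigenerate(mem.conversation, prompt)
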