Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ConversationMemory and AnnotationMessage implementations #239

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250"
REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StreamCallbacks = "c1b9e933-98a0-46fc-8ea7-3b58b195fb0a"
StructTypes = "856f2bd8-1eba-4b0a-8007-ebc267875bd4"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[weakdeps]
Expand Down Expand Up @@ -53,10 +54,13 @@ PrecompileTools = "1"
Preferences = "1"
REPL = "<0.0.1, 1"
Random = "<0.0.1, 1"
Snowball = "0.1"
SparseArrays = "<0.0.1, 1"
Statistics = "<0.0.1, 1"
StreamCallbacks = "0.4, 0.5"
StructTypes = "1"
Test = "<0.0.1, 1"
Unicode = "<0.0.1, 1"
julia = "1.9, 1.10"

[extras]
Expand Down
1 change: 1 addition & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[compat]
DocumenterVitepress = "0.0.7"
PromptingTools = "0.65.0"
9 changes: 7 additions & 2 deletions src/PromptingTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,14 @@ include("llm_interface.jl")
include("user_preferences.jl")

## Conversation history / Prompt elements
export AIMessage
# export UserMessage, UserMessageWithImages, SystemMessage, DataMessage # for debugging only
include("messages.jl")
include("memory.jl")

# Export message types and predicates
export SystemMessage, UserMessage, AIMessage, AnnotationMessage, issystemmessage, isusermessage, isaimessage, isabstractannotationmessage, annotate!
# Export memory-related functionality
export ConversationMemory, get_last, last_message, last_output
# export UserMessage, UserMessageWithImages, SystemMessage, DataMessage # for debugging only

export aitemplates, AITemplate
include("templates.jl")
Expand Down
5 changes: 4 additions & 1 deletion src/llm_anthropic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,13 @@ function render(schema::AbstractAnthropicSchema,
no_system_message::Bool = false,
cache::Union{Nothing, Symbol} = nothing,
kwargs...)
##
##
@assert count(issystemmessage, messages)<=1 "AbstractAnthropicSchema only supports at most 1 System message"
@assert (isnothing(cache)||cache in [:system, :tools, :last, :all]) "Currently only `:system`, `:tools`, `:last`, `:all` are supported for Anthropic Prompt Caching"

# Filter out annotation messages before any processing
messages = filter(!isabstractannotationmessage, messages)

system = nothing

## First pass: keep the message types but make the replacements provided in `kwargs`
Expand Down
3 changes: 3 additions & 0 deletions src/llm_google.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ function render(schema::AbstractGoogleSchema,
no_system_message::Bool = false,
kwargs...)
##
# Filter out annotation messages before any processing
messages = filter(!isabstractannotationmessage, messages)

## First pass: keep the message types but make the replacements provided in `kwargs`
messages_replaced = render(
NoSchema(), messages; conversation, no_system_message, kwargs...)
Expand Down
13 changes: 10 additions & 3 deletions src/llm_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,21 @@ struct OpenAISchema <: AbstractOpenAISchema end

"Echoes the user's input back to them. Used for testing the implementation"
@kwdef mutable struct TestEchoOpenAISchema <: AbstractOpenAISchema
    # Canned OpenAI-style chat-completion payload returned instead of a real API call.
    # Defaults to a minimal valid completion so tests need not build one by hand.
    response::AbstractDict = Dict(
        "choices" => [Dict("message" => Dict("content" => "Test response", "role" => "assistant"), "index" => 0, "finish_reason" => "stop")],
        "usage" => Dict("prompt_tokens" => 10, "completion_tokens" => 20, "total_tokens" => 30),
        "model" => "gpt-3.5-turbo",
        "id" => "test-id",
        "object" => "chat.completion",
        "created" => 1234567890
    )
    # Simulated HTTP status code (200 = success).
    status::Integer = 200
    # Model identifier to report; empty string means unspecified.
    model_id::String = ""
    # NOTE(review): presumably records the rendered inputs of the last call for
    # test inspection — confirm against the schema's call path.
    inputs::Any = nothing
end

"""
CustomOpenAISchema
CustomOpenAISchema

CustomOpenAISchema() allows user to call any OpenAI-compatible API.

Expand Down
5 changes: 4 additions & 1 deletion src/llm_ollama.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ function render(schema::AbstractOllamaSchema,
no_system_message::Bool = false,
kwargs...)
##
# Filter out annotation messages before any processing
messages = filter(!isabstractannotationmessage, messages)

## First pass: keep the message types but make the replacements provided in `kwargs`
messages_replaced = render(
NoSchema(), messages; conversation, no_system_message, kwargs...)
Expand Down Expand Up @@ -376,4 +379,4 @@ end
function aitools(prompt_schema::AbstractOllamaSchema, prompt::ALLOWED_PROMPT_TYPE;
kwargs...)
error("Managed schema does not support aitools. Please use OpenAISchema instead.")
end
end
6 changes: 5 additions & 1 deletion src/llm_openai.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ function render(schema::AbstractOpenAISchema,
kwargs...)
##
@assert image_detail in ["auto", "high", "low"] "Image detail must be one of: auto, high, low"

# Filter out annotation messages before any processing
messages = filter(!isabstractannotationmessage, messages)

## First pass: keep the message types but make the replacements provided in `kwargs`
messages_replaced = render(
NoSchema(), messages; conversation, no_system_message, kwargs...)
Expand Down Expand Up @@ -1733,4 +1737,4 @@ function aitools(prompt_schema::AbstractOpenAISchema, prompt::ALLOWED_PROMPT_TYP
kwargs...)

return output
end
end
3 changes: 3 additions & 0 deletions src/llm_shared.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ function render(schema::NoSchema,
count_system_msg = count(issystemmessage, conversation)
# TODO: concat multiple system messages together (2nd pass)

# Filter out annotation messages from input messages
messages = filter(!isabstractannotationmessage, messages)

# replace any handlebar variables in the messages
for msg in messages
if issystemmessage(msg) || isusermessage(msg) || isusermessagewithimages(msg)
Expand Down
227 changes: 227 additions & 0 deletions src/memory.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
# Conversation Memory Implementation
import PromptingTools: AbstractMessage, SystemMessage, AIMessage, UserMessage
import PromptingTools: issystemmessage, isusermessage, isaimessage
import PromptingTools: aigenerate, last_message, last_output

"""
    ConversationMemory

A structured container for managing conversation history with intelligent truncation
and caching capabilities.

The memory supports batched retrieval with deterministic truncation points for optimal
caching behavior.

See also: [`get_last`](@ref), [`last_message`](@ref), [`last_output`](@ref)
"""
Base.@kwdef mutable struct ConversationMemory
    # Full message history in order, including any system message; starts empty.
    conversation::Vector{AbstractMessage} = AbstractMessage[]
end

# Basic interface extensions
import Base: show, push!, append!, length

"""
    show(io::IO, mem::ConversationMemory)

Print a compact summary of `mem`, reporting how many non-system messages it holds.
"""
function Base.show(io::IO, mem::ConversationMemory)
    visible = count(msg -> !issystemmessage(msg), mem.conversation)
    print(io, "ConversationMemory(", visible, " messages)")
end

"""
    length(mem::ConversationMemory)

Return the number of stored messages, not counting system messages.
"""
function Base.length(mem::ConversationMemory)
    total = length(mem.conversation)
    return total - count(issystemmessage, mem.conversation)
end

"""
    last_message(mem::ConversationMemory)

Return the final message of the stored conversation by delegating to
`PromptingTools.last_message`.
"""
last_message(mem::ConversationMemory) = PromptingTools.last_message(mem.conversation)

"""
    last_output(mem::ConversationMemory)

Return the last AI output of the stored conversation by delegating to
`PromptingTools.last_output`.
"""
last_output(mem::ConversationMemory) = PromptingTools.last_output(mem.conversation)

"""
    get_last(mem::ConversationMemory, n::Integer=20;
             batch_size::Union{Nothing,Integer}=nothing,
             verbose::Bool=false,
             explain::Bool=false)

Get the last n messages with intelligent batching and caching support.

# Arguments
- `n::Integer`: Maximum number of messages to return (default: 20)
- `batch_size::Union{Nothing,Integer}`: If provided, ensures messages are truncated
  in fixed batches so the truncation point is deterministic (cache-friendly)
- `verbose::Bool`: Print detailed information about truncation
- `explain::Bool`: Prepend a truncation notice to the first AI message in the result

# Returns
`Vector{AbstractMessage}` with the selected messages, always including:
1. The system message (if present)
2. The first user message (if present)
3. The most recent remaining messages up to `n`, respecting `batch_size` boundaries
"""
function get_last(mem::ConversationMemory, n::Integer=20;
                  batch_size::Union{Nothing,Integer}=nothing,
                  verbose::Bool=false,
                  explain::Bool=false)
    messages = mem.conversation
    isempty(messages) && return AbstractMessage[]

    # The system message and the first user message are always retained.
    system_idx = findfirst(issystemmessage, messages)
    first_user_idx = findfirst(isusermessage, messages)

    result = AbstractMessage[]
    if !isnothing(system_idx)
        push!(result, messages[system_idx])
    end
    if !isnothing(first_user_idx)
        push!(result, messages[first_user_idx])
    end

    # Remaining pool: everything except the always-kept messages, in order.
    exclude_indices = filter(!isnothing, [system_idx, first_user_idx])
    remaining_msgs = messages[setdiff(1:length(messages), exclude_indices)]

    # Budget left after the required messages.
    target_n = n - length(result)

    if !isnothing(batch_size)
        total_msgs = length(remaining_msgs)
        # Once the pool exceeds two full batches, reset to the minimum window
        # (batch_size + 1) so the truncation point only moves in batch steps;
        # otherwise keep everything up to the remaining budget.
        target_size = if total_msgs > 2 * batch_size
            batch_size + 1
        else
            min(total_msgs, target_n)
        end
        if !isempty(remaining_msgs)
            start_idx = max(1, total_msgs - target_size + 1)
            append!(result, remaining_msgs[start_idx:end])
        end
    else
        # Without batching, simply take the most recent `target_n` messages.
        if !isempty(remaining_msgs)
            start_idx = max(1, length(remaining_msgs) - target_n + 1)
            append!(result, remaining_msgs[start_idx:end])
        end
    end

    if verbose
        println("Total messages: ", length(messages))
        println("Keeping: ", length(result))
        # Guard the first-user comparison: `first_user_idx` may be `nothing`
        # (indexing with it would throw).
        required = count(result) do m
            issystemmessage(m) ||
                (!isnothing(first_user_idx) && m === messages[first_user_idx])
        end
        println("Required messages: ", required)
        if !isnothing(batch_size)
            println("Using batch size: ", batch_size)
        end
    end

    # Optionally annotate the first AI message with a truncation notice.
    if explain && length(messages) > length(result)
        ai_msg_idx = findfirst(isaimessage, result)
        if !isnothing(ai_msg_idx)
            orig_content = result[ai_msg_idx].content
            explanation = "For efficiency reasons, we have truncated the preceding $(length(messages) - length(result)) messages.\n\n$orig_content"
            result[ai_msg_idx] = AIMessage(explanation)
        end
    end

    return result
end

"""
    append!(mem::ConversationMemory, msgs::Vector{<:AbstractMessage})

Smart append that handles duplicate messages based on run IDs.
Only appends messages that are newer than the latest matching message in memory.
"""
function Base.append!(memory::ConversationMemory, batch::Vector{<:AbstractMessage})
    # Trivial cases: nothing to add, or nothing stored yet.
    isempty(batch) && return memory

    history = memory.conversation
    if isempty(history)
        append!(history, batch)
        return memory
    end

    # Highest run_id already in memory; messages without a run_id contribute 0.
    newest_run = mapreduce(max, history; init = 0) do msg
        isdefined(msg, :run_id) ? msg.run_id : 0
    end

    # Admit messages lacking a run_id, or those strictly newer than what we hold.
    fresh = [msg for msg in batch
             if !isdefined(msg, :run_id) || msg.run_id > newest_run]

    append!(history, fresh)
    return memory
end

"""
    push!(mem::ConversationMemory, msg::AbstractMessage)

Append a single message to the conversation memory and return the memory.
"""
function Base.push!(memory::ConversationMemory, message::AbstractMessage)
    push!(memory.conversation, message)
    return memory
end

"""
    (mem::ConversationMemory)(prompt::String; last::Union{Nothing,Integer}=nothing, kwargs...)

Functor interface for direct generation using the conversation memory.

The `prompt` is appended to the memory as a user message, a response is generated
against the prior context (optionally truncated to the `last` messages via
[`get_last`](@ref)), and the response is recorded in memory and returned.
"""
function (mem::ConversationMemory)(prompt::String; last::Union{Nothing,Integer}=nothing, kwargs...)
    # Snapshot the context BEFORE mutating memory. Without the copy,
    # `context` aliases `mem.conversation`, so pushing the user message below
    # would also insert it into `context` — and `aigenerate` appends the prompt
    # again, sending the user message twice.
    context = if isnothing(last)
        copy(mem.conversation)
    else
        get_last(mem, last)
    end

    # Record the user turn in memory.
    push!(mem.conversation, UserMessage(prompt))

    # Generate against the pre-prompt context; `aigenerate` appends the prompt.
    response = aigenerate(context, prompt; kwargs...)
    push!(mem.conversation, response)
    return response
end

"""
    aigenerate(messages::Vector{<:AbstractMessage}, prompt::String; kwargs...)

Generate a response for `messages` extended with `prompt` as a new user message.

The LLM schema is selected via the `schema` keyword (default `OpenAISchema()`);
it is consumed here and not forwarded to the underlying `aigenerate` call.
"""
function PromptingTools.aigenerate(messages::Vector{<:AbstractMessage}, prompt::String; kwargs...)
    schema = get(kwargs, :schema, OpenAISchema())
    # Strip :schema before forwarding — passing it through `kwargs...` would hand
    # the inner call an unexpected/duplicate keyword.
    forwarded = (; (k => v for (k, v) in kwargs if k !== :schema)...)
    return aigenerate(schema, [messages..., UserMessage(prompt)]; forwarded...)
end
Loading
Loading