Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature :all_but_last cache #254

Merged
merged 3 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 52 additions & 33 deletions src/llm_anthropic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,23 @@ Builds a history of the conversation to provide the prompt to the API. All unspe
- `aiprefill`: A string to be used as a prefill for the AI response. This steer the AI response in a certain direction (and potentially save output tokens).
- `conversation`: Past conversation to be included in the beginning of the prompt (for continued conversations).
- `no_system_message`: If `true`, do not include the default system message in the conversation history OR convert any provided system message to a user message.
- `cache`: A symbol representing the caching strategy to be used. Currently only `nothing` (no caching), `:system`, `:tools`,`:last` and `:all` are supported.
- `cache`: A symbol representing the caching strategy to be used. Currently only `nothing` (no caching), `:system`, `:tools`,`:last`, `:all_but_last`, and `:all` are supported.
- `:system`: Mark only the system message as cacheable. Best default if you have large system message and you will be sending short conversations (no replies / multi-turn conversations).
- `:all`: Mark SYSTEM, one before last and LAST user message as cacheable. Best for multi-turn conversations (you write cache point as "last" and it will be read in the next turn as "preceding" cache mark).
- `:last`: Mark only the last message as cacheable. Use ONLY if you want to send the SAME REQUEST multiple times (and want to save upto the last USER message). This will not work for multi-turn conversations, as the "last" message keeps moving.
- `:all_but_last`: Mark SYSTEM and one before LAST USER message. Use if you have a longer conversation that you want to re-use, but you will NOT CONTINUE it (no subsequent messages/follow-ups).
- In short, use `:all` for multi-turn conversations, `:system` for repeated single-turn conversations with same system message, and `:all_but_last` for longer conversations that you want to re-use, but not continue.
"""
function render(schema::AbstractAnthropicSchema,
messages::Vector{<:AbstractMessage};
aiprefill::Union{Nothing, AbstractString} = nothing,
conversation::AbstractVector{<:AbstractMessage} = AbstractMessage[],
conversation_msgs::AbstractVector{<:AbstractMessage} = AbstractMessage[],
no_system_message::Bool = false,
cache::Union{Nothing, Symbol} = nothing,
kwargs...)
##
@assert count(issystemmessage, messages)<=1 "AbstractAnthropicSchema only supports at most 1 System message"
@assert (isnothing(cache)||cache in [:system, :tools, :last, :all]) "Currently only `:system`, `:tools`, `:last`, `:all` are supported for Anthropic Prompt Caching"
@assert (isnothing(cache)||cache in [:system, :tools, :last, :all, :all_but_last]) "Currently only `:system`, `:tools`, `:last`, `all_but_last`, `:all` are supported for Anthropic Prompt Caching (cache=$cache)"

# Filter out annotation messages before any processing
messages = filter(!isabstractannotationmessage, messages)
Expand All @@ -39,7 +44,7 @@ function render(schema::AbstractAnthropicSchema,

## First pass: keep the message types but make the replacements provided in `kwargs`
messages_replaced = render(
NoSchema(), messages; conversation, no_system_message, kwargs...)
NoSchema(), messages; conversation_msgs, no_system_message, kwargs...)

## Second pass: convert to the message-based schema
conversation = Dict{String, Any}[]
Expand Down Expand Up @@ -76,25 +81,32 @@ function render(schema::AbstractAnthropicSchema,
# Note: Ignores any DataMessage or other types
end

## Add Tool definitions to the System Prompt
# if !isempty(tools)
# ANTHROPIC_TOOL_SUFFIX = "Use the $(tools[1][:name]) tool in your response."
# ## Add to system message
# if isnothing(system)
# system = ANTHROPIC_TOOL_SUFFIX
# else
# system *= "\n\n" * ANTHROPIC_TOOL_SUFFIX
# end
# end

## Note: For cache to work, it must be marked in the same location across calls!
## Apply cache for last message
is_valid_conversation = length(conversation) > 0 &&
haskey(conversation[end], "content") &&
length(conversation[end]["content"]) > 0
if is_valid_conversation && (cache == :last || cache == :all)
conversation[end]["content"][end]["cache_control"] = Dict("type" => "ephemeral")
user_msg_counter = 0
if is_valid_conversation
for i in reverse(eachindex(conversation))
## we mark only user messages
# Cache points must be EXACTLY at the same location across calls!
if conversation[i]["role"] == "user"
if cache == :last && user_msg_counter == 0 # marks exactly once
conversation[i]["content"][end]["cache_control"] = Dict("type" => "ephemeral")
elseif cache == :all && user_msg_counter < 2 # marks twice - for 0 and 1
# Mark the last AND preceding user message!
# If we don't do this, then next time we call it with a new message, the cache points will not overlap!
conversation[i]["content"][end]["cache_control"] = Dict("type" => "ephemeral")
elseif cache == :all_but_last && user_msg_counter == 1 # marks once, only the preceding user message
conversation[i]["content"][end]["cache_control"] = Dict("type" => "ephemeral")
end
user_msg_counter += 1
end
end
end
if !no_system_message && !isnothing(system) && (cache == :system || cache == :all)
if !no_system_message && !isnothing(system) &&
(cache == :system || cache == :all || cache == :all_but_last)
## Apply cache for system message
system = [Dict("type" => "text", "text" => system,
"cache_control" => Dict("type" => "ephemeral"))]
Expand Down Expand Up @@ -261,7 +273,7 @@ Simple wrapper for a call to Anthropic API.
- `http_kwargs::NamedTuple`: Additional keyword arguments for the HTTP request. Defaults to empty `NamedTuple`.
- `stream`: A boolean indicating whether to stream the response. Defaults to `false`.
- `url`: The URL of the Ollama API. Defaults to "localhost".
- `cache`: A symbol representing the caching strategy to be used. Currently only `nothing` (no caching), `:system`, `:tools`,`:last` and `:all` are supported.
- `cache`: A symbol representing the caching strategy to be used. Currently only `nothing` (no caching), `:system`, `:tools`,`:last`, `:all_but_last`, and `:all` are supported.
- `betas`: A vector of symbols representing the beta features to be used. Currently only `:tools` and `:cache` are supported.
- `kwargs`: Prompt variables to be used to fill the prompt/template
"""
Expand Down Expand Up @@ -355,11 +367,12 @@ Generate an AI response based on a given prompt using the Anthropic API.
- `http_kwargs::NamedTuple`: Additional keyword arguments for the HTTP request. Defaults to empty `NamedTuple`.
- `api_kwargs::NamedTuple`: Additional keyword arguments for the Ollama API. Defaults to an empty `NamedTuple`.
- `max_tokens::Int`: The maximum number of tokens to generate. Defaults to 2048, because it's a required parameter for the API.
- `cache`: A symbol indicating whether to use caching for the prompt. Supported values are `nothing` (no caching), `:system`, `:tools`, `:last` and `:all`. Note that COST estimate will be wrong (ignores the caching).
- `:system`: Caches the system message
- `:tools`: Caches the tool definitions (and everything before them)
- `:last`: Caches the last message in the conversation (and everything before it)
- `:all`: Cache trigger points are inserted in all of the above places (ie, higher likelyhood of cache hit, but also slightly higher cost)
- `cache`: A symbol representing the caching strategy to be used. Currently only `nothing` (no caching), `:system`, `:tools`,`:last`, `:all_but_last` and `:all` are supported. Note that COST estimate will be wrong (ignores the caching).
- `:system`: Mark only the system message as cacheable. Best default if you have large system message and you will be sending short conversations (no replies / multi-turn conversations).
- `:all`: Mark SYSTEM, one before last and LAST user message as cacheable. Best for multi-turn conversations (you write cache point as "last" and it will be read in the next turn as "preceding" cache mark).
- `:last`: Mark only the last message as cacheable. Use ONLY if you want to send the SAME REQUEST multiple times (and want to save upto the last USER message). This will not work for multi-turn conversations, as the "last" message keeps moving.
- `:all_but_last`: Mark SYSTEM and one before LAST USER message. Use if you have a longer conversation that you want to re-use, but you will NOT CONTINUE it (no subsequent messages/follow-ups).
- In short, use `:all` for multi-turn conversations, `:system` for repeated single-turn conversations with same system message, and `:all_but_last` for longer conversations that you want to re-use, but not continue.
- `betas::Union{Nothing, Vector{Symbol}}`: A vector of symbols representing the beta features to be used. See `?anthropic_extra_headers` for details.
- `kwargs`: Prompt variables to be used to fill the prompt/template

Expand Down Expand Up @@ -457,7 +470,7 @@ function aigenerate(
kwargs...)
##
global MODEL_ALIASES
@assert (isnothing(cache)||cache in [:system, :tools, :last, :all]) "Currently only `:system`, `:tools`, `:last` and `:all` are supported for Anthropic Prompt Caching"
@assert (isnothing(cache)||cache in [:system, :tools, :last, :all, :all_but_last]) "Currently only `:system`, `:tools`, `:last`, `all_but_last` and `:all` are supported for Anthropic Prompt Caching (cache=$cache)"
@assert (isnothing(aiprefill)||!isempty(strip(aiprefill))) "`aiprefill` must not be empty`"
## Find the unique ID for the model alias provided
model_id = get(MODEL_ALIASES, model, model)
Expand Down Expand Up @@ -552,11 +565,12 @@ It's effectively a light wrapper around `aigenerate` call, which requires additi
- `http_kwargs`: A named tuple of HTTP keyword arguments.
- `api_kwargs`: A named tuple of API keyword arguments.
- `:tool_choice`: A string indicating which tool to use. Supported values are `nothing`, `"auto"`, `"any"` and `"exact"`. `nothing` will use the default tool choice.
- `cache`: A symbol indicating whether to use caching for the prompt. Supported values are `nothing` (no caching), `:system`, `:tools`, `:last` and `:all`. Note that COST estimate will be wrong (ignores the caching).
- `:system`: Caches the system message
- `:tools`: Caches the tool definitions (and everything before them)
- `:last`: Caches the last message in the conversation (and everything before it)
- `:all`: Cache trigger points are inserted in all of the above places (ie, higher likelyhood of cache hit, but also slightly higher cost)
- `cache`: A symbol representing the caching strategy to be used. Currently only `nothing` (no caching), `:system`, `:tools`,`:last`, `:all_but_last`, and `:all` are supported. Note: COST estimate will be wrong (ignores the caching).
- `:system`: Mark only the system message as cacheable. Best default if you have large system message and you will be sending short conversations (no replies / multi-turn conversations).
- `:all`: Mark SYSTEM, one before last and LAST user message as cacheable. Best for multi-turn conversations (you write cache point as "last" and it will be read in the next turn as "preceding" cache mark).
- `:last`: Mark only the last message as cacheable. Use ONLY if you want to send the SAME REQUEST multiple times (and want to save upto the last USER message). This will not work for multi-turn conversations, as the "last" message keeps moving.
- `:all_but_last`: Mark SYSTEM and one before LAST USER message. Use if you have a longer conversation that you want to re-use, but you will NOT CONTINUE it (no subsequent messages/follow-ups).
- In short, use `:all` for multi-turn conversations, `:system` for repeated single-turn conversations with same system message, and `:all_but_last` for longer conversations that you want to re-use, but not continue.
- `betas::Union{Nothing, Vector{Symbol}}`: A vector of symbols representing the beta features to be used. See `?anthropic_extra_headers` for details.
- `kwargs`: Prompt variables to be used to fill the prompt/template

Expand Down Expand Up @@ -690,7 +704,7 @@ function aiextract(prompt_schema::AbstractAnthropicSchema, prompt::ALLOWED_PROMP
kwargs...)
##
global MODEL_ALIASES
@assert (isnothing(cache)||cache in [:system, :tools, :last, :all]) "Currently only `:system`, `:tools`, `:last` and `:all` are supported for Anthropic Prompt Caching"
@assert (isnothing(cache)||cache in [:system, :tools, :last, :all_but_last, :all]) "Currently only `:system`, `:tools`, `:last`, `:all_but_last` and `:all` are supported for Anthropic Prompt Caching"

## Check that no functions or methods are provided, that is not supported
@assert !(return_type isa Vector)||!any(x -> x isa Union{Function, Method}, return_type) "Functions and Methods are not supported in `aiextract`!"
Expand Down Expand Up @@ -815,7 +829,12 @@ Differences to `aiextract`: Can provide infinitely many tools (including Functio
- `conversation`: An optional vector of `AbstractMessage` objects representing the conversation history.
- `no_system_message::Bool = false`: Whether to exclude the system message from the conversation history.
- `image_path::Union{Nothing, AbstractString, Vector{<:AbstractString}} = nothing`: A path to a local image file, or a vector of paths to local image files. Always attaches images to the latest user message.
- `cache::Union{Nothing, Symbol} = nothing`: Whether to cache the prompt. Defaults to `nothing`.
- `cache`: A symbol representing the caching strategy to be used. Currently only `nothing` (no caching), `:system`, `:tools`,`:last`, `:all_but_last`, and `:all` are supported. Note: COST estimate will be wrong (ignores the caching).
- `:system`: Mark only the system message as cacheable. Best default if you have large system message and you will be sending short conversations (no replies / multi-turn conversations).
- `:all`: Mark SYSTEM, one before last and LAST user message as cacheable. Best for multi-turn conversations (you write cache point as "last" and it will be read in the next turn as "preceding" cache mark).
- `:last`: Mark only the last message as cacheable. Use ONLY if you want to send the SAME REQUEST multiple times (and want to save upto the last USER message). This will not work for multi-turn conversations, as the "last" message keeps moving.
- `:all_but_last`: Mark SYSTEM and one before LAST USER message. Use if you have a longer conversation that you want to re-use, but you will NOT CONTINUE it (no subsequent messages/follow-ups).
- In short, use `:all` for multi-turn conversations, `:system` for repeated single-turn conversations with same system message, and `:all_but_last` for longer conversations that you want to re-use, but not continue.
- `betas::Union{Nothing, Vector{Symbol}} = nothing`: A vector of symbols representing the beta features to be used. See `?anthropic_extra_headers` for details.
- `http_kwargs`: A named tuple of HTTP keyword arguments.
- `api_kwargs`: A named tuple of API keyword arguments. Several important arguments are highlighted below:
Expand Down Expand Up @@ -897,7 +916,7 @@ function aitools(prompt_schema::AbstractAnthropicSchema, prompt::ALLOWED_PROMPT_
tool_choice = nothing),
kwargs...)
global MODEL_ALIASES
@assert (isnothing(cache)||cache in [:system, :tools, :last, :all]) "Currently only `:system`, `:tools`, `:last` and `:all` are supported for Anthropic Prompt Caching"
@assert (isnothing(cache)||cache in [:system, :tools, :last, :all_but_last, :all]) "Currently only `:system`, `:tools`, `:last`, `:all_but_last` and `:all` are supported for Anthropic Prompt Caching"

## Find the unique ID for the model alias provided
model_id = get(MODEL_ALIASES, model, model)
Expand Down
60 changes: 59 additions & 1 deletion test/llm_anthropic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,23 @@ using PromptingTools: call_cost, anthropic_api, function_call_signature,
"cache_control" => Dict("type" => "ephemeral"))])])
@test conversation == expected_output

## We mark only user messages
messages_with_ai = [
SystemMessage("Act as a helpful AI assistant"),
UserMessage("Hello, my name is {{name}}"),
AIMessage("Hi there")
]
conversation = render(schema, messages_with_ai; name = "John", cache = :last)
expected_output = (;
system = "Act as a helpful AI assistant",
conversation = [
Dict("role" => "user",
"content" => [Dict("type" => "text", "text" => "Hello, my name is John",
"cache_control" => Dict("type" => "ephemeral"))]),
Dict("role" => "assistant",
"content" => [Dict("type" => "text", "text" => "Hi there")])])
@test conversation == expected_output

conversation = render(schema, messages; name = "John", cache = :all)
expected_output = (;
system = Dict{String, Any}[Dict("cache_control" => Dict("type" => "ephemeral"),
Expand All @@ -183,7 +200,48 @@ using PromptingTools: call_cost, anthropic_api, function_call_signature,
"cache_control" => Dict("type" => "ephemeral"))])])
@test conversation == expected_output

# Test aiprefill functionality
conversation = render(schema, messages_with_ai; name = "John", cache = :all)
expected_output = (;
system = Dict{String, Any}[Dict("cache_control" => Dict("type" => "ephemeral"),
"text" => "Act as a helpful AI assistant", "type" => "text")],
conversation = [
Dict("role" => "user",
"content" => [Dict("type" => "text", "text" => "Hello, my name is John",
"cache_control" => Dict("type" => "ephemeral"))]),
Dict("role" => "assistant",
"content" => [Dict("type" => "text", "text" => "Hi there")])])
@test conversation == expected_output

## Longer conversation
messages_longer = [
SystemMessage("Act as a helpful AI assistant"),
UserMessage("Hello, my name is {{name}}"),
AIMessage("Hi there"),
UserMessage("How are you?"),
AIMessage("I'm doing well, thank you!")
]
system, conversation = render(schema, messages_longer; name = "John", cache = :all)
## marks last user message
@test conversation[end - 1]["content"][end]["cache_control"] ==
Dict("type" => "ephemeral")
## marks one before last user message
@test conversation[end - 3]["content"][end]["cache_control"] ==
Dict("type" => "ephemeral")
## marks system message
@test system[1]["cache_control"] == Dict("type" => "ephemeral")

## all_but_last
system, conversation = render(
schema, messages_longer; name = "John", cache = :all_but_last)
## does not mark last user message
@test !haskey(conversation[end - 1]["content"][end], "cache_control")
## marks one before last user message
@test conversation[end - 3]["content"][end]["cache_control"] ==
Dict("type" => "ephemeral")
## marks system message
@test system[1]["cache_control"] == Dict("type" => "ephemeral")

### aiprefill functionality
messages = [
SystemMessage("Act as a helpful AI assistant"),
UserMessage("Hello, what's your name?")
Expand Down
Loading