Skip to content

Commit

Permalink
Add image support to aitools (#235)
Browse files Browse the repository at this point in the history
  • Loading branch information
svilupp authored Nov 17, 2024
1 parent 4f07dd4 commit 6470ad9
Show file tree
Hide file tree
Showing 8 changed files with 125 additions and 12 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

## [0.64.0]

### Added
- Added support for images in `aitools` to enable passing screenshots via `image_path` argument (extended to both OpenAI and Anthropic APIs, uses `?UserMessageWithImages` internally).
- Added the latest Gemini Experimental model via OpenAI compatibility mode (`gemini-exp-1114` with alias `gemexp`).

## [0.63.0]

### Added
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PromptingTools"
uuid = "670122d1-24a8-4d70-bfce-740807c42192"
authors = ["J S @svilupp and contributors"]
version = "0.63.0"
version = "0.64.0"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand Down
30 changes: 26 additions & 4 deletions src/llm_anthropic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,31 @@ function render(schema::AbstractAnthropicSchema,
conversation = Dict{String, Any}[]

for msg in messages_replaced
if msg isa SystemMessage
if issystemmessage(msg)
system = msg.content
elseif msg isa UserMessage || msg isa AIMessage
elseif isusermessage(msg) || isaimessage(msg)
content = msg.content
push!(conversation,
Dict("role" => role4render(schema, msg),
"content" => [Dict{String, Any}("type" => "text", "text" => content)]))
elseif msg isa UserMessageWithImages
error("AbstractAnthropicSchema does not yet support UserMessageWithImages. Please use OpenAISchema instead.")
elseif isusermessagewithimages(msg)
# Build message content
content = Dict{String, Any}[Dict("type" => "text",
"text" => msg.content)]
# Add images
for img in msg.image_url
# image_url = "data:image/$image_suffix;base64,$(base64_image)"
data_type, data = extract_image_attributes(img)
@assert data_type in ["image/jpeg", "image/png", "image/gif", "image/webp"] "Unsupported image type: $data_type"
push!(content,
Dict("type" => "image",
"source" => Dict("type" => "base64",
"data" => data,
## image/jpeg, image/png, image/gif, image/webp
"media_type" => data_type)))
end
push!(conversation,
Dict("role" => role4render(schema, msg), "content" => content))
end
# Note: Ignores any DataMessage or other types
end
Expand Down Expand Up @@ -766,6 +782,7 @@ end
return_all::Bool = false, dry_run::Bool = false,
conversation::AbstractVector{<:AbstractMessage} = AbstractMessage[],
no_system_message::Bool = false,
image_path::Union{Nothing, AbstractString, Vector{<:AbstractString}} = nothing,
cache::Union{Nothing, Symbol} = nothing,
betas::Union{Nothing, Vector{Symbol}} = nothing,
http_kwargs::NamedTuple = (retry_non_idempotent = true,
Expand All @@ -792,6 +809,7 @@ Differences to `aiextract`: Can provide infinitely many tools (including Functio
- `dry_run`: If `true`, skips sending the messages to the model (for debugging, often used with `return_all=true`).
- `conversation`: An optional vector of `AbstractMessage` objects representing the conversation history.
- `no_system_message::Bool = false`: Whether to exclude the system message from the conversation history.
- `image_path::Union{Nothing, AbstractString, Vector{<:AbstractString}} = nothing`: A path to a local image file, or a vector of paths to local image files. Always attaches images to the latest user message.
- `cache::Union{Nothing, Symbol} = nothing`: Whether to cache the prompt. Defaults to `nothing`.
- `betas::Union{Nothing, Vector{Symbol}} = nothing`: A vector of symbols representing the beta features to be used. See `?anthropic_extra_headers` for details.
- `http_kwargs`: A named tuple of HTTP keyword arguments.
Expand Down Expand Up @@ -865,6 +883,7 @@ function aitools(prompt_schema::AbstractAnthropicSchema, prompt::ALLOWED_PROMPT_
return_all::Bool = false, dry_run::Bool = false,
conversation::AbstractVector{<:AbstractMessage} = AbstractMessage[],
no_system_message::Bool = false,
image_path::Union{Nothing, AbstractString, Vector{<:AbstractString}} = nothing,
cache::Union{Nothing, Symbol} = nothing,
betas::Union{Nothing, Vector{Symbol}} = nothing,
http_kwargs::NamedTuple = (retry_non_idempotent = true,
Expand Down Expand Up @@ -899,6 +918,9 @@ function aitools(prompt_schema::AbstractAnthropicSchema, prompt::ALLOWED_PROMPT_
## Add the function call stopping sequence to the api_kwargs
api_kwargs = merge(api_kwargs, (; tools, tool_choice))

## Vision-specific functionality -- if `image_path` is provided, attach images to the latest user message
!isnothing(image_path) &&
(prompt = attach_images_to_user_message(prompt; image_path, attach_to_latest = true))
## We provide the tool description to the rendering engine
conv_rendered = render(
prompt_schema, prompt; tools, conversation, no_system_message, cache, kwargs...)
Expand Down
9 changes: 7 additions & 2 deletions src/llm_openai.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1547,6 +1547,7 @@ end
return_all::Bool = false, dry_run::Bool = false,
conversation::AbstractVector{<:AbstractMessage} = AbstractMessage[],
no_system_message::Bool = false,
image_path::Union{Nothing, AbstractString, Vector{<:AbstractString}} = nothing,
http_kwargs::NamedTuple = (retry_non_idempotent = true,
retries = 5,
readtimeout = 120), api_kwargs::NamedTuple = (;
Expand Down Expand Up @@ -1575,6 +1576,7 @@ Differences to `aiextract`: Can provide infinitely many tools (including Functio
- `dry_run`: If `true`, skips sending the messages to the model (for debugging, often used with `return_all=true`).
- `conversation`: An optional vector of `AbstractMessage` objects representing the conversation history.
- `no_system_message::Bool = false`: Whether to exclude the system message from the conversation history.
- `image_path`: A path to a local image file, or a vector of paths to local image files. Always attaches images to the latest user message.
- `name_user`: The name of the user in the conversation history. Defaults to "User".
- `name_assistant`: The name of the assistant in the conversation history. Defaults to "Assistant".
- `http_kwargs`: A named tuple of HTTP keyword arguments.
Expand Down Expand Up @@ -1641,8 +1643,8 @@ function aitools(prompt_schema::AbstractOpenAISchema, prompt::ALLOWED_PROMPT_TYP
model::String = MODEL_CHAT,
return_all::Bool = false, dry_run::Bool = false,
conversation::AbstractVector{<:AbstractMessage} = AbstractMessage[],
no_system_message::Bool = false,
name_user::Union{Nothing, String} = nothing,
no_system_message::Bool = false, name_user::Union{Nothing, String} = nothing,
image_path::Union{Nothing, AbstractString, Vector{<:AbstractString}} = nothing,
name_assistant::Union{Nothing, String} = nothing,
http_kwargs::NamedTuple = (retry_non_idempotent = true,
retries = 5,
Expand Down Expand Up @@ -1685,6 +1687,9 @@ function aitools(prompt_schema::AbstractOpenAISchema, prompt::ALLOWED_PROMPT_TYP

## Find the unique ID for the model alias provided
model_id = get(MODEL_ALIASES, model, model)
## Vision-specific functionality -- if `image_path` is provided, attach images to the latest user message
!isnothing(image_path) &&
(prompt = attach_images_to_user_message(prompt; image_path, attach_to_latest = true))
## Render the conversation history from messages
conv_rendered = render(
prompt_schema, prompt; conversation, no_system_message, name_user, kwargs...)
Expand Down
12 changes: 9 additions & 3 deletions src/user_preferences.jl
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,8 @@ aliases = merge(
## Gemini 1.5 Models
"gem15p" => "gemini-1.5-pro-latest",
"gem15f8" => "gemini-1.5-flash-8b-latest",
"gem15f" => "gemini-1.5-flash-latest"
"gem15f" => "gemini-1.5-flash-latest",
"gemexp" => "gemini-exp-1114" # latest experimental model from November 2024
),
## Load aliases from preferences as well
@load_preference("MODEL_ALIASES", default=Dict{String, String}()))
Expand Down Expand Up @@ -1111,7 +1112,7 @@ registry = Dict{String, ModelSpec}(
## Gemini 1.5 Models
"gemini-1.5-pro-latest" => ModelSpec("gemini-1.5-pro-latest",
GoogleOpenAISchema(),
1e-6,
1.25e-6,
5e-6,
"Gemini 1.5 Pro is Google's latest large language model with enhanced capabilities across reasoning, math, coding, and multilingual tasks. 128K context window."),
"gemini-1.5-flash-8b-latest" => ModelSpec("gemini-1.5-flash-8b-latest",
Expand All @@ -1123,7 +1124,12 @@ registry = Dict{String, ModelSpec}(
GoogleOpenAISchema(),
7.5e-8,
3.0e-7,
"Gemini 1.5 Flash is a high-performance model optimized for speed while maintaining strong capabilities across various tasks. 128K context window.")
"Gemini 1.5 Flash is a high-performance model optimized for speed while maintaining strong capabilities across various tasks. 128K context window."),
"gemini-exp-1114" => ModelSpec("gemini-exp-1114",
GoogleOpenAISchema(),
1.25e-6,
5e-6,
"Gemini Experimental Model from November 2024. Pricing assumed as per Gemini 1.5 Pro. See details [here](https://ai.google.dev/gemini-api/docs/models/experimental-models#use-an-experimental-model).")
)

"""
Expand Down
31 changes: 31 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -668,4 +668,35 @@ Returns indices of unique items in a vector `inputs`. Access the unique values a
"""
function unique_permutation(inputs::AbstractVector)
return unique(i -> inputs[i], eachindex(inputs))
end

"""
extract_image_attributes(image_url::AbstractString) -> Tuple{String, String}
Extracts the data type and base64-encoded data from a data URL.
# Arguments
- `image_url::AbstractString`: The data URL to be parsed.
# Returns
`Tuple{String, String}`: A tuple containing the data type (e.g., `"image/png"`) and the base64-encoded data.
# Example
```julia
image_url = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAABQAA"
data_type, data = extract_data_type_and_data(image_url)
# data_type == "image/png"
# data == "iVBORw0KGgoAAAANSUhEUgAABQAA"
```
"""
function extract_image_attributes(image_url::AbstractString)::Tuple{String, String}
pattern = r"^data:(.*?);base64,(.*)$"
m = match(pattern, image_url)
if m !== nothing
data_type = m.captures[1]
data = m.captures[2]
return data_type, data
else
throw(ArgumentError("Invalid data URL format"))
end
end
26 changes: 25 additions & 1 deletion test/llm_anthropic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,14 +103,38 @@ using PromptingTools: call_cost, anthropic_api, function_call_signature,
conversation = render(schema, messages)
@test conversation == expected_output

### IMAGES
# Test UserMessageWithImages -- errors for now
messages = [
SystemMessage("System message 1"),
UserMessageWithImages("User message"; image_url = "https://example.com/image.png")
]
## We don't support URL format!
@test_throws Exception render(schema, messages)

## Tool calling
## Unsupported format
messages = [
SystemMessage("System message 1"),
UserMessageWithImages(
"User message"; image_url = "data:image/svg;base64,iVBORw0KGgoAAAANSUhEUgAABQAA")
]
@test_throws AssertionError render(schema, messages)

## Base64 format
messages = [
SystemMessage("System message 1"),
UserMessageWithImages(
"User message"; image_url = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAABQAA")
]
rendered = render(schema, messages)
@test rendered.conversation[1] == Dict{String, Any}("role" => "user",
"content" => Dict{String, Any}[Dict("text" => "User message", "type" => "text"),
Dict(
"source" => Dict("media_type" => "image/png",
"data" => "iVBORw0KGgoAAAANSUhEUgAABQAA", "type" => "base64"),
"type" => "image")])

### Tool calling
"abc"
struct FruitCountX
fruit::String
Expand Down
21 changes: 20 additions & 1 deletion test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ using PromptingTools: recursive_splitter, wrap_string, replace_words,
length_longest_common_subsequence, distance_longest_common_subsequence
using PromptingTools: _extract_handlebar_variables, call_cost, call_cost_alternative,
_report_stats
using PromptingTools: _string_to_vector, _encode_local_image
using PromptingTools: _string_to_vector, _encode_local_image, extract_image_attributes
using PromptingTools: DataMessage, AIMessage, UserMessage
using PromptingTools: push_conversation!,
resize_conversation!, @timeout, preview, pprint, auth_header,
Expand Down Expand Up @@ -276,6 +276,25 @@ end
@test _encode_local_image(nothing) == String[]
end

@testset "extract_image_attributes" begin
# Test basic valid data URL
data_url = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAABQAA"
data_type, data = extract_image_attributes(data_url)
@test data_type == "image/png"
@test data == "iVBORw0KGgoAAAANSUhEUgAABQAA"

# Test different image type
data_url = "data:image/jpeg;base64,/9j/4AAQSkZJRg"
data_type, data = extract_image_attributes(data_url)
@test data_type == "image/jpeg"
@test data == "/9j/4AAQSkZJRg"

# Test invalid data URL format
@test_throws ArgumentError extract_image_attributes("not a data url")
@test_throws ArgumentError extract_image_attributes("data:image/png;")
@test_throws ArgumentError extract_image_attributes("data:image/png;base64")
end

### Conversation Management
@testset "push_conversation!,resize_conversation!" begin
# Test 1: Adding to Conversation History
Expand Down

0 comments on commit 6470ad9

Please sign in to comment.